Update auth.ts

This commit is contained in:
Nicolas 2024-08-12 13:42:09 -04:00
parent 25a899eae3
commit 0bd1a820ee

View File

@ -1,34 +1,41 @@
import { parseApi } from "../../src/lib/parseApi"; import { parseApi } from "../../src/lib/parseApi";
import { getRateLimiter, } from "../../src/services/rate-limiter"; import { getRateLimiter } from "../../src/services/rate-limiter";
import { AuthResponse, NotificationType, RateLimiterMode } from "../../src/types"; import {
AuthResponse,
NotificationType,
RateLimiterMode,
} from "../../src/types";
import { supabase_service } from "../../src/services/supabase"; import { supabase_service } from "../../src/services/supabase";
import { withAuth } from "../../src/lib/withAuth"; import { withAuth } from "../../src/lib/withAuth";
import { RateLimiterRedis } from "rate-limiter-flexible"; import { RateLimiterRedis } from "rate-limiter-flexible";
import { setTraceAttributes } from '@hyperdx/node-opentelemetry'; import { setTraceAttributes } from "@hyperdx/node-opentelemetry";
import { sendNotification } from "../services/notification/email_notification"; import { sendNotification } from "../services/notification/email_notification";
import { Logger } from "../lib/logger"; import { Logger } from "../lib/logger";
import { redlock } from "../../src/services/redlock"; import { redlock } from "../../src/services/redlock";
import { getValue } from "../../src/services/redis"; import { getValue } from "../../src/services/redis";
import { setValue } from "../../src/services/redis"; import { setValue } from "../../src/services/redis";
import { validate } from 'uuid'; import { validate } from "uuid";
function normalizedApiIsUuid(potentialUuid: string): boolean { function normalizedApiIsUuid(potentialUuid: string): boolean {
// Check if the string is a valid UUID // Check if the string is a valid UUID
return validate(potentialUuid); return validate(potentialUuid);
} }
export async function authenticateUser(req, res, mode?: RateLimiterMode): Promise<AuthResponse> { export async function authenticateUser(
req,
res,
mode?: RateLimiterMode
): Promise<AuthResponse> {
return withAuth(supaAuthenticateUser)(req, res, mode); return withAuth(supaAuthenticateUser)(req, res, mode);
} }
function setTrace(team_id: string, api_key: string) { function setTrace(team_id: string, api_key: string) {
try { try {
setTraceAttributes({ setTraceAttributes({
team_id, team_id,
api_key api_key,
}); });
} catch (error) { } catch (error) {
Logger.error(`Error setting trace attributes: ${error.message}`); Logger.error(`Error setting trace attributes: ${error.message}`);
} }
} }
export async function supaAuthenticateUser( export async function supaAuthenticateUser(
req, req,
@ -59,7 +66,7 @@ export async function supaAuthenticateUser(
const iptoken = incomingIP + token; const iptoken = incomingIP + token;
let rateLimiter: RateLimiterRedis; let rateLimiter: RateLimiterRedis;
let subscriptionData: { team_id: string, plan: string } | null = null; let subscriptionData: { team_id: string; plan: string } | null = null;
let normalizedApi: string; let normalizedApi: string;
let cacheKey = ""; let cacheKey = "";
@ -87,39 +94,44 @@ export async function supaAuthenticateUser(
const lock = await redlock.acquire([redLockKey], lockTTL); const lock = await redlock.acquire([redLockKey], lockTTL);
try { try {
const teamIdPriceId = await getValue(cacheKey); const teamIdPriceId = await getValue(cacheKey);
if (teamIdPriceId) { if (teamIdPriceId) {
const { team_id, price_id } = JSON.parse(teamIdPriceId); const { team_id, price_id } = JSON.parse(teamIdPriceId);
teamId = team_id; teamId = team_id;
priceId = price_id; priceId = price_id;
} } else {
else{
const { data, error } = await supabase_service.rpc( const { data, error } = await supabase_service.rpc(
'get_key_and_price_id_2', { api_key: normalizedApi } "get_key_and_price_id_2",
{ api_key: normalizedApi }
); );
if (error) { if (error) {
Logger.error(`RPC ERROR (get_key_and_price_id_2): ${error.message}`); Logger.error(
`RPC ERROR (get_key_and_price_id_2): ${error.message}`
);
return { return {
success: false, success: false,
error: "The server seems overloaded. Please contact hello@firecrawl.com if you aren't sending too many requests at once.", error:
"The server seems overloaded. Please contact hello@firecrawl.com if you aren't sending too many requests at once.",
status: 500, status: 500,
}; };
} }
if (!data || data.length === 0) { if (!data || data.length === 0) {
Logger.warn(`Error fetching api key: ${error.message} or data is empty`); Logger.warn(
`Error fetching api key: ${error.message} or data is empty`
);
// TODO: change this error code ? // TODO: change this error code ?
return { return {
success: false, success: false,
error: "Unauthorized: Invalid token", error: "Unauthorized: Invalid token",
status: 401, status: 401,
}; };
} } else {
else {
teamId = data[0].team_id; teamId = data[0].team_id;
priceId = data[0].price_id; priceId = data[0].price_id;
} }
} }
} catch (error) {
Logger.error(`Error with auth function: ${error.message}`);
} finally { } finally {
await lock.release(); await lock.release();
} }
@ -127,7 +139,6 @@ export async function supaAuthenticateUser(
Logger.error(`Error acquiring the rate limiter lock: ${error}`); Logger.error(`Error acquiring the rate limiter lock: ${error}`);
} }
// get_key_and_price_id_2 rpc definition: // get_key_and_price_id_2 rpc definition:
// create or replace function get_key_and_price_id_2(api_key uuid) // create or replace function get_key_and_price_id_2(api_key uuid)
// returns table(key uuid, team_id uuid, price_id text) as $$ // returns table(key uuid, team_id uuid, price_id text) as $$
@ -145,23 +156,34 @@ export async function supaAuthenticateUser(
// end; // end;
// $$ language plpgsql; // $$ language plpgsql;
const plan = getPlanByPriceId(priceId); const plan = getPlanByPriceId(priceId);
// HyperDX Logging // HyperDX Logging
setTrace(teamId, normalizedApi); setTrace(teamId, normalizedApi);
subscriptionData = { subscriptionData = {
team_id: teamId, team_id: teamId,
plan: plan plan: plan,
} };
switch (mode) { switch (mode) {
case RateLimiterMode.Crawl: case RateLimiterMode.Crawl:
rateLimiter = getRateLimiter(RateLimiterMode.Crawl, token, subscriptionData.plan); rateLimiter = getRateLimiter(
RateLimiterMode.Crawl,
token,
subscriptionData.plan
);
break; break;
case RateLimiterMode.Scrape: case RateLimiterMode.Scrape:
rateLimiter = getRateLimiter(RateLimiterMode.Scrape, token, subscriptionData.plan); rateLimiter = getRateLimiter(
RateLimiterMode.Scrape,
token,
subscriptionData.plan
);
break; break;
case RateLimiterMode.Search: case RateLimiterMode.Search:
rateLimiter = getRateLimiter(RateLimiterMode.Search, token, subscriptionData.plan); rateLimiter = getRateLimiter(
RateLimiterMode.Search,
token,
subscriptionData.plan
);
break; break;
case RateLimiterMode.CrawlStatus: case RateLimiterMode.CrawlStatus:
rateLimiter = getRateLimiter(RateLimiterMode.CrawlStatus, token); rateLimiter = getRateLimiter(RateLimiterMode.CrawlStatus, token);
@ -179,7 +201,8 @@ export async function supaAuthenticateUser(
} }
} }
const team_endpoint_token = token === "this_is_just_a_preview_token" ? iptoken : teamId; const team_endpoint_token =
token === "this_is_just_a_preview_token" ? iptoken : teamId;
try { try {
await rateLimiter.consume(team_endpoint_token); await rateLimiter.consume(team_endpoint_token);
@ -196,7 +219,11 @@ export async function supaAuthenticateUser(
// await sendNotification(team_id, NotificationType.RATE_LIMIT_REACHED, startDate.toISOString(), endDate.toISOString()); // await sendNotification(team_id, NotificationType.RATE_LIMIT_REACHED, startDate.toISOString(), endDate.toISOString());
// TODO: cache 429 for a few minuts // TODO: cache 429 for a few minuts
if (teamId && priceId && mode !== RateLimiterMode.Preview) { if (teamId && priceId && mode !== RateLimiterMode.Preview) {
await setValue(cacheKey, JSON.stringify({team_id: teamId, price_id: priceId}), 60 * 5); await setValue(
cacheKey,
JSON.stringify({ team_id: teamId, price_id: priceId }),
60 * 5
);
} }
return { return {
@ -208,7 +235,9 @@ export async function supaAuthenticateUser(
if ( if (
token === "this_is_just_a_preview_token" && token === "this_is_just_a_preview_token" &&
(mode === RateLimiterMode.Scrape || mode === RateLimiterMode.Preview || mode === RateLimiterMode.Search) (mode === RateLimiterMode.Scrape ||
mode === RateLimiterMode.Preview ||
mode === RateLimiterMode.Search)
) { ) {
return { success: true, team_id: "preview" }; return { success: true, team_id: "preview" };
// check the origin of the request and make sure its from firecrawl.dev // check the origin of the request and make sure its from firecrawl.dev
@ -232,8 +261,6 @@ export async function supaAuthenticateUser(
.select("*") .select("*")
.eq("key", normalizedApi); .eq("key", normalizedApi);
if (error || !data || data.length === 0) { if (error || !data || data.length === 0) {
Logger.warn(`Error fetching api key: ${error.message} or data is empty`); Logger.warn(`Error fetching api key: ${error.message} or data is empty`);
return { return {
@ -246,26 +273,30 @@ export async function supaAuthenticateUser(
subscriptionData = data[0]; subscriptionData = data[0];
} }
return { success: true, team_id: subscriptionData.team_id, plan: subscriptionData.plan ?? ""}; return {
success: true,
team_id: subscriptionData.team_id,
plan: subscriptionData.plan ?? "",
};
} }
function getPlanByPriceId(price_id: string) { function getPlanByPriceId(price_id: string) {
switch (price_id) { switch (price_id) {
case process.env.STRIPE_PRICE_ID_STARTER: case process.env.STRIPE_PRICE_ID_STARTER:
return 'starter'; return "starter";
case process.env.STRIPE_PRICE_ID_STANDARD: case process.env.STRIPE_PRICE_ID_STANDARD:
return 'standard'; return "standard";
case process.env.STRIPE_PRICE_ID_SCALE: case process.env.STRIPE_PRICE_ID_SCALE:
return 'scale'; return "scale";
case process.env.STRIPE_PRICE_ID_HOBBY: case process.env.STRIPE_PRICE_ID_HOBBY:
case process.env.STRIPE_PRICE_ID_HOBBY_YEARLY: case process.env.STRIPE_PRICE_ID_HOBBY_YEARLY:
return 'hobby'; return "hobby";
case process.env.STRIPE_PRICE_ID_STANDARD_NEW: case process.env.STRIPE_PRICE_ID_STANDARD_NEW:
case process.env.STRIPE_PRICE_ID_STANDARD_NEW_YEARLY: case process.env.STRIPE_PRICE_ID_STANDARD_NEW_YEARLY:
return 'standardnew'; return "standardnew";
case process.env.STRIPE_PRICE_ID_GROWTH: case process.env.STRIPE_PRICE_ID_GROWTH:
case process.env.STRIPE_PRICE_ID_GROWTH_YEARLY: case process.env.STRIPE_PRICE_ID_GROWTH_YEARLY:
return 'growth'; return "growth";
default: default:
return 'free'; return "free";
} }
} }