From 1f6abf95e8afeea19cbee524315e92b2381e6117 Mon Sep 17 00:00:00 2001 From: Nicolas Date: Fri, 17 Jan 2025 20:59:44 -0300 Subject: [PATCH] Nick: extract billing works --- apps/api/src/controllers/auth.ts | 6 ++-- .../api/src/lib/extract/extraction-service.ts | 30 +++++++++++++------ apps/api/src/lib/extract/usage/llm-cost.ts | 6 ++++ .../src/services/billing/credit_billing.ts | 6 +++- 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/apps/api/src/controllers/auth.ts b/apps/api/src/controllers/auth.ts index 6bda3039..52a87145 100644 --- a/apps/api/src/controllers/auth.ts +++ b/apps/api/src/controllers/auth.ts @@ -77,6 +77,7 @@ export async function getACUC( api_key: string, cacheOnly = false, useCache = true, + mode?: RateLimiterMode, ): Promise { const cacheKeyACUC = `acuc_${api_key}`; @@ -93,9 +94,10 @@ export async function getACUC( let retries = 0; const maxRetries = 5; + let rpcName = mode === RateLimiterMode.Extract ? "auth_credit_usage_chunk_extract" : "auth_credit_usage_chunk_test_21_credit_pack"; while (retries < maxRetries) { ({ data, error } = await supabase_service.rpc( - "auth_credit_usage_chunk_test_21_credit_pack", + rpcName, { input_key: api_key }, { get: true }, )); @@ -203,7 +205,7 @@ export async function supaAuthenticateUser( }; } - chunk = await getACUC(normalizedApi); + chunk = await getACUC(normalizedApi,false, true, mode); if (chunk === null) { return { diff --git a/apps/api/src/lib/extract/extraction-service.ts b/apps/api/src/lib/extract/extraction-service.ts index 5d7a653f..0933d347 100644 --- a/apps/api/src/lib/extract/extraction-service.ts +++ b/apps/api/src/lib/extract/extraction-service.ts @@ -32,7 +32,7 @@ import { ExtractStep, updateExtract } from "./extract-redis"; import { deduplicateObjectsArray } from "./helpers/deduplicate-objs-array"; import { mergeNullValObjs } from "./helpers/merge-null-val-objs"; import { CUSTOM_U_TEAMS } from "./config"; -import { estimateCost, estimateTotalCost } from "./usage/llm-cost"; +import { calculateFinalResultCost, estimateCost, estimateTotalCost } from "./usage/llm-cost"; interface ExtractServiceOptions { request: ExtractRequest; @@ -50,6 +50,7 @@ interface ExtractResult { error?: string; tokenUsageBreakdown?: TokenUsage[]; llmUsage?: number; + totalUrlsScraped?: number; } async function analyzeSchemaAndPrompt( @@ -178,6 +179,7 @@ export async function performExtraction( let multiEntityCompletions: completions[] = []; let multiEntityResult: any = {}; let singleAnswerResult: any = {}; + let totalUrlsScraped = 0; // Token tracking @@ -238,6 +240,7 @@ export async function performExtraction( "No valid URLs found to scrape. Try adjusting your search criteria or including more URLs.", extractId, urlTrace: urlTraces, + totalUrlsScraped: 0 }; } @@ -334,6 +337,8 @@ export async function performExtraction( (doc): doc is Document => doc !== null, ); + totalUrlsScraped += multyEntityDocs.length; + let endScrape = Date.now(); await updateExtract(extractId, { @@ -529,6 +534,7 @@ export async function performExtraction( "An unexpected error occurred. Please contact help@firecrawl.com for help.", extractId, urlTrace: urlTraces, + totalUrlsScraped }; } } @@ -580,15 +586,17 @@ export async function performExtraction( } } - singleAnswerDocs.push( - ...results.filter((doc): doc is Document => doc !== null), - ); + const validResults = results.filter((doc): doc is Document => doc !== null); + singleAnswerDocs.push(...validResults); + totalUrlsScraped += validResults.length; + } catch (error) { return { success: false, error: error.message, extractId, urlTrace: urlTraces, + totalUrlsScraped }; } @@ -600,6 +608,7 @@ export async function performExtraction( "All provided URLs are invalid. Please check your input and try again.", extractId, urlTrace: request.urlTrace ? urlTraces : undefined, + totalUrlsScraped: 0 }; } @@ -663,20 +672,23 @@ export async function performExtraction( ? await mixSchemaObjects(reqSchema, singleAnswerResult, multiEntityResult) : singleAnswerResult || multiEntityResult; + + const totalTokensUsed = tokenUsage.reduce((a, b) => a + b.totalTokens, 0); + const llmUsage = estimateTotalCost(tokenUsage); + const tokensToBill = calculateFinalResultCost(finalResult); + let linksBilled = links.length * 5; if (CUSTOM_U_TEAMS.includes(teamId)) { linksBilled = 1; } // Bill team for usage - billTeam(teamId, subId, linksBilled).catch((error) => { + billTeam(teamId, subId, tokensToBill, logger, true).catch((error) => { logger.error( - `Failed to bill team ${teamId} for ${linksBilled} credits: ${error}`, + `Failed to bill team ${teamId} for ${tokensToBill} tokens: ${error}`, ); }); - const totalTokensUsed = tokenUsage.reduce((a, b) => a + b.totalTokens, 0); - const llmUsage = estimateTotalCost(tokenUsage); // Log job with token usage logJob({ @@ -710,6 +722,6 @@ export async function performExtraction( warning: undefined, // TODO FIX urlTrace: request.urlTrace ? urlTraces : undefined, llmUsage, + totalUrlsScraped }; } - diff --git a/apps/api/src/lib/extract/usage/llm-cost.ts b/apps/api/src/lib/extract/usage/llm-cost.ts index 7cba201d..6452a078 100644 --- a/apps/api/src/lib/extract/usage/llm-cost.ts +++ b/apps/api/src/lib/extract/usage/llm-cost.ts @@ -8,6 +8,12 @@ interface ModelPricing { input_cost_per_request?: number; mode: string; } +const tokenPerCharacter = 4; +const baseTokenCost = 200; + +export function calculateFinalResultCost(data: any): number { + return JSON.stringify(data).length / tokenPerCharacter + baseTokenCost; +} export function estimateTotalCost(tokenUsage: TokenUsage[]): number { return tokenUsage.reduce((total, usage) => { diff --git a/apps/api/src/services/billing/credit_billing.ts b/apps/api/src/services/billing/credit_billing.ts index 5eb541fd..4aa8b4cf 100644 --- a/apps/api/src/services/billing/credit_billing.ts +++ b/apps/api/src/services/billing/credit_billing.ts @@ -22,12 +22,14 @@ export async function billTeam( subscription_id: string | null | undefined, credits: number, logger?: Logger, + is_extract: boolean = false, ) { return withAuth(supaBillTeam, { success: true, message: "No DB, bypassed." })( team_id, subscription_id, credits, logger, + is_extract, ); } export async function supaBillTeam( @@ -35,6 +37,7 @@ export async function supaBillTeam( subscription_id: string | null | undefined, credits: number, __logger?: Logger, + is_extract: boolean = false, ) { const _logger = (__logger ?? logger).child({ module: "credit_billing", @@ -49,11 +52,12 @@ export async function supaBillTeam( credits, }); - const { data, error } = await supabase_service.rpc("bill_team", { + const { data, error } = await supabase_service.rpc("bill_team_w_extract", { _team_id: team_id, sub_id: subscription_id ?? null, fetch_subscription: subscription_id === undefined, credits, + is_extract, }); if (error) {