From 438ea19f16355efa30533e32e56483836f725085 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gerg=C5=91=20M=C3=B3ricz?= Date: Sat, 19 Apr 2025 01:35:17 -0700 Subject: [PATCH] feat(extract): add thinking tokens --- .../api/src/lib/extract/extraction-service.ts | 80 +++++++++++++++++-- apps/api/src/lib/extract/usage/llm-cost.ts | 5 ++ 2 files changed, 79 insertions(+), 6 deletions(-) diff --git a/apps/api/src/lib/extract/extraction-service.ts b/apps/api/src/lib/extract/extraction-service.ts index 6a3e6756..652cc5c0 100644 --- a/apps/api/src/lib/extract/extraction-service.ts +++ b/apps/api/src/lib/extract/extraction-service.ts @@ -27,7 +27,7 @@ import { deduplicateObjectsArray } from "./helpers/deduplicate-objs-array"; import { mergeNullValObjs } from "./helpers/merge-null-val-objs"; import { areMergeable } from "./helpers/merge-null-val-objs"; import { CUSTOM_U_TEAMS } from "./config"; -import { calculateFinalResultCost, estimateTotalCost } from "./usage/llm-cost"; +import { calculateFinalResultCost, calculateThinkingCost, estimateTotalCost } from "./usage/llm-cost"; import { analyzeSchemaAndPrompt } from "./completions/analyzeSchemaAndPrompt"; import { batchExtractPromise } from "./completions/batchExtract"; import { singleAnswerCompletion } from "./completions/singleAnswer"; @@ -168,6 +168,8 @@ export async function performExtraction( logger.error("No search results found", { query: request.prompt, }); + + const tokens_billed = 300 + calculateThinkingCost(costTracking); logJob({ job_id: extractId, success: false, @@ -181,10 +183,17 @@ export async function performExtraction( scrapeOptions: request, origin: request.origin ?? "api", num_tokens: 0, - tokens_billed: 0, + tokens_billed, sources, cost_tracking: costTracking, }); + + await billTeam(teamId, subId, tokens_billed, logger, true).catch((error) => { + logger.error( + `Failed to bill team ${teamId} for thinking tokens: ${error}`, + ); + }); + return { success: false, error: "No search results found", @@ -654,6 +663,7 @@ export async function performExtraction( multiEntityCompletions: JSON.stringify(multiEntityCompletions), multiEntitySchema: JSON.stringify(multiEntitySchema) }); + const tokens_billed = 300 + calculateThinkingCost(costTracking); logJob({ job_id: extractId, success: false, @@ -667,10 +677,15 @@ export async function performExtraction( scrapeOptions: request, origin: request.origin ?? "api", num_tokens: 0, - tokens_billed: 0, + tokens_billed, sources, cost_tracking: costTracking, }); + await billTeam(teamId, subId, tokens_billed, logger, true).catch((error) => { + logger.error( + `Failed to bill team ${teamId} for thinking tokens: ${error}`, + ); + }); return { success: false, error: @@ -755,6 +770,29 @@ export async function performExtraction( logger.debug("Scrapes finished.", { docCount: validResults.length }); } catch (error) { + const tokens_billed = 300 + calculateThinkingCost(costTracking); + logJob({ + job_id: extractId, + success: false, + message: error.message, + num_docs: 1, + docs: [], + time_taken: (new Date().getTime() - Date.now()) / 1000, + team_id: teamId, + mode: "extract", + url: request.urls?.join(", ") || "", + scrapeOptions: request, + origin: request.origin ?? "api", + num_tokens: 0, + tokens_billed, + sources, + cost_tracking: costTracking, + }); + await billTeam(teamId, subId, tokens_billed, logger, true).catch((error) => { + logger.error( + `Failed to bill team ${teamId} for thinking tokens: ${error}`, + ); + }); return { success: false, error: error.message, @@ -767,6 +805,29 @@ export async function performExtraction( if (docsMap.size == 0) { // All urls are invalid logger.error("All provided URLs are invalid!"); + const tokens_billed = 300 + calculateThinkingCost(costTracking); + await billTeam(teamId, subId, tokens_billed, logger, true).catch((error) => { + logger.error( + `Failed to bill team ${teamId} for thinking tokens: ${error}`, + ); + }); + logJob({ + job_id: extractId, + success: false, + message: "All provided URLs are invalid. Please check your input and try again.", + num_docs: 1, + docs: [], + time_taken: (new Date().getTime() - Date.now()) / 1000, + team_id: teamId, + mode: "extract", + url: request.urls?.join(", ") || "", + scrapeOptions: request, + origin: request.origin ?? "api", + num_tokens: 0, + tokens_billed, + sources, + cost_tracking: costTracking, + }); return { success: false, error: @@ -920,14 +981,14 @@ export async function performExtraction( const totalTokensUsed = tokenUsage.reduce((a, b) => a + b.totalTokens, 0); const llmUsage = estimateTotalCost(tokenUsage); - let tokensToBill = calculateFinalResultCost(finalResult); + let tokensToBill = calculateFinalResultCost(finalResult) + calculateThinkingCost(costTracking); if (CUSTOM_U_TEAMS.includes(teamId)) { tokensToBill = 1; } // Bill team for usage - billTeam(teamId, subId, tokensToBill, logger, true).catch((error) => { + await billTeam(teamId, subId, tokensToBill, logger, true).catch((error) => { logger.error( `Failed to bill team ${teamId} for ${tokensToBill} tokens: ${error}`, ); @@ -996,6 +1057,12 @@ export async function performExtraction( sources, }; } catch (error) { + const tokens_billed = 300 + calculateThinkingCost(costTracking); + await billTeam(teamId, subId, tokens_billed, logger, true).catch((error) => { + logger.error( + `Failed to bill team ${teamId} for thinking tokens: ${error}`, + ); + }); await logJob({ job_id: extractId, success: false, @@ -1009,10 +1076,11 @@ export async function performExtraction( scrapeOptions: request, origin: request.origin ?? "api", num_tokens: 0, - tokens_billed: 0, + tokens_billed, sources, cost_tracking: costTracking, }); + throw error; } } diff --git a/apps/api/src/lib/extract/usage/llm-cost.ts b/apps/api/src/lib/extract/usage/llm-cost.ts index 58ab53cd..784a5749 100644 --- a/apps/api/src/lib/extract/usage/llm-cost.ts +++ b/apps/api/src/lib/extract/usage/llm-cost.ts @@ -1,5 +1,6 @@ import { TokenUsage } from "../../../controllers/v1/types"; import { logger } from "../../../lib/logger"; +import { CostTracking } from "../extraction-service"; import { modelPrices } from "./model-prices"; interface ModelPricing { @@ -11,6 +12,10 @@ interface ModelPricing { const tokenPerCharacter = 0.5; const baseTokenCost = 300; +export function calculateThinkingCost(costTracking: CostTracking): number { + return costTracking.toJSON().totalCost * 267000; +} + export function calculateFinalResultCost(data: any): number { return Math.floor( JSON.stringify(data).length / tokenPerCharacter + baseTokenCost,