feat(extract): add thinking tokens

This commit is contained in:
Gergő Móricz 2025-04-19 01:35:17 -07:00
parent 653a0207c3
commit 438ea19f16
2 changed files with 79 additions and 6 deletions

View File

@ -27,7 +27,7 @@ import { deduplicateObjectsArray } from "./helpers/deduplicate-objs-array";
import { mergeNullValObjs } from "./helpers/merge-null-val-objs"; import { mergeNullValObjs } from "./helpers/merge-null-val-objs";
import { areMergeable } from "./helpers/merge-null-val-objs"; import { areMergeable } from "./helpers/merge-null-val-objs";
import { CUSTOM_U_TEAMS } from "./config"; import { CUSTOM_U_TEAMS } from "./config";
import { calculateFinalResultCost, estimateTotalCost } from "./usage/llm-cost"; import { calculateFinalResultCost, calculateThinkingCost, estimateTotalCost } from "./usage/llm-cost";
import { analyzeSchemaAndPrompt } from "./completions/analyzeSchemaAndPrompt"; import { analyzeSchemaAndPrompt } from "./completions/analyzeSchemaAndPrompt";
import { batchExtractPromise } from "./completions/batchExtract"; import { batchExtractPromise } from "./completions/batchExtract";
import { singleAnswerCompletion } from "./completions/singleAnswer"; import { singleAnswerCompletion } from "./completions/singleAnswer";
@ -168,6 +168,8 @@ export async function performExtraction(
logger.error("No search results found", { logger.error("No search results found", {
query: request.prompt, query: request.prompt,
}); });
const tokens_billed = 300 + calculateThinkingCost(costTracking);
logJob({ logJob({
job_id: extractId, job_id: extractId,
success: false, success: false,
@ -181,10 +183,17 @@ export async function performExtraction(
scrapeOptions: request, scrapeOptions: request,
origin: request.origin ?? "api", origin: request.origin ?? "api",
num_tokens: 0, num_tokens: 0,
tokens_billed: 0, tokens_billed,
sources, sources,
cost_tracking: costTracking, cost_tracking: costTracking,
}); });
await billTeam(teamId, subId, tokens_billed, logger, true).catch((error) => {
logger.error(
`Failed to bill team ${teamId} for thinking tokens: ${error}`,
);
});
return { return {
success: false, success: false,
error: "No search results found", error: "No search results found",
@ -654,6 +663,7 @@ export async function performExtraction(
multiEntityCompletions: JSON.stringify(multiEntityCompletions), multiEntityCompletions: JSON.stringify(multiEntityCompletions),
multiEntitySchema: JSON.stringify(multiEntitySchema) multiEntitySchema: JSON.stringify(multiEntitySchema)
}); });
const tokens_billed = 300 + calculateThinkingCost(costTracking);
logJob({ logJob({
job_id: extractId, job_id: extractId,
success: false, success: false,
@ -667,10 +677,15 @@ export async function performExtraction(
scrapeOptions: request, scrapeOptions: request,
origin: request.origin ?? "api", origin: request.origin ?? "api",
num_tokens: 0, num_tokens: 0,
tokens_billed: 0, tokens_billed,
sources, sources,
cost_tracking: costTracking, cost_tracking: costTracking,
}); });
await billTeam(teamId, subId, tokens_billed, logger, true).catch((error) => {
logger.error(
`Failed to bill team ${teamId} for thinking tokens: ${error}`,
);
});
return { return {
success: false, success: false,
error: error:
@ -755,6 +770,29 @@ export async function performExtraction(
logger.debug("Scrapes finished.", { docCount: validResults.length }); logger.debug("Scrapes finished.", { docCount: validResults.length });
} catch (error) { } catch (error) {
const tokens_billed = 300 + calculateThinkingCost(costTracking);
logJob({
job_id: extractId,
success: false,
message: error.message,
num_docs: 1,
docs: [],
time_taken: (new Date().getTime() - Date.now()) / 1000,
team_id: teamId,
mode: "extract",
url: request.urls?.join(", ") || "",
scrapeOptions: request,
origin: request.origin ?? "api",
num_tokens: 0,
tokens_billed,
sources,
cost_tracking: costTracking,
});
await billTeam(teamId, subId, tokens_billed, logger, true).catch((error) => {
logger.error(
`Failed to bill team ${teamId} for thinking tokens: ${error}`,
);
});
return { return {
success: false, success: false,
error: error.message, error: error.message,
@ -767,6 +805,29 @@ export async function performExtraction(
if (docsMap.size == 0) { if (docsMap.size == 0) {
// All urls are invalid // All urls are invalid
logger.error("All provided URLs are invalid!"); logger.error("All provided URLs are invalid!");
const tokens_billed = 300 + calculateThinkingCost(costTracking);
await billTeam(teamId, subId, tokens_billed, logger, true).catch((error) => {
logger.error(
`Failed to bill team ${teamId} for thinking tokens: ${error}`,
);
});
logJob({
job_id: extractId,
success: false,
message: "All provided URLs are invalid. Please check your input and try again.",
num_docs: 1,
docs: [],
time_taken: (new Date().getTime() - Date.now()) / 1000,
team_id: teamId,
mode: "extract",
url: request.urls?.join(", ") || "",
scrapeOptions: request,
origin: request.origin ?? "api",
num_tokens: 0,
tokens_billed,
sources,
cost_tracking: costTracking,
});
return { return {
success: false, success: false,
error: error:
@ -920,14 +981,14 @@ export async function performExtraction(
const totalTokensUsed = tokenUsage.reduce((a, b) => a + b.totalTokens, 0); const totalTokensUsed = tokenUsage.reduce((a, b) => a + b.totalTokens, 0);
const llmUsage = estimateTotalCost(tokenUsage); const llmUsage = estimateTotalCost(tokenUsage);
let tokensToBill = calculateFinalResultCost(finalResult); let tokensToBill = calculateFinalResultCost(finalResult) + calculateThinkingCost(costTracking);
if (CUSTOM_U_TEAMS.includes(teamId)) { if (CUSTOM_U_TEAMS.includes(teamId)) {
tokensToBill = 1; tokensToBill = 1;
} }
// Bill team for usage // Bill team for usage
billTeam(teamId, subId, tokensToBill, logger, true).catch((error) => { await billTeam(teamId, subId, tokensToBill, logger, true).catch((error) => {
logger.error( logger.error(
`Failed to bill team ${teamId} for ${tokensToBill} tokens: ${error}`, `Failed to bill team ${teamId} for ${tokensToBill} tokens: ${error}`,
); );
@ -996,6 +1057,12 @@ export async function performExtraction(
sources, sources,
}; };
} catch (error) { } catch (error) {
const tokens_billed = 300 + calculateThinkingCost(costTracking);
await billTeam(teamId, subId, tokens_billed, logger, true).catch((error) => {
logger.error(
`Failed to bill team ${teamId} for thinking tokens: ${error}`,
);
});
await logJob({ await logJob({
job_id: extractId, job_id: extractId,
success: false, success: false,
@ -1009,10 +1076,11 @@ export async function performExtraction(
scrapeOptions: request, scrapeOptions: request,
origin: request.origin ?? "api", origin: request.origin ?? "api",
num_tokens: 0, num_tokens: 0,
tokens_billed: 0, tokens_billed,
sources, sources,
cost_tracking: costTracking, cost_tracking: costTracking,
}); });
throw error; throw error;
} }
} }

View File

@ -1,5 +1,6 @@
import { TokenUsage } from "../../../controllers/v1/types"; import { TokenUsage } from "../../../controllers/v1/types";
import { logger } from "../../../lib/logger"; import { logger } from "../../../lib/logger";
import { CostTracking } from "../extraction-service";
import { modelPrices } from "./model-prices"; import { modelPrices } from "./model-prices";
interface ModelPricing { interface ModelPricing {
@ -11,6 +12,10 @@ interface ModelPricing {
const tokenPerCharacter = 0.5; const tokenPerCharacter = 0.5;
const baseTokenCost = 300; const baseTokenCost = 300;
/**
 * Converts the LLM spend accumulated in `costTracking` into a billable token
 * count.
 *
 * Reads `totalCost` from `costTracking.toJSON()` and multiplies it by a fixed
 * conversion factor. The product is floored so that, like
 * `calculateFinalResultCost`, the result is always a whole number of tokens
 * that can be added to other token counts and passed on for billing.
 *
 * @param costTracking - Tracker whose `toJSON().totalCost` is the spend to
 *   convert. NOTE(review): presumably denominated in USD — confirm against
 *   `CostTracking`.
 * @returns A non-negative integer number of tokens. Returns 0 when
 *   `totalCost` is missing or not a finite number, so `NaN` never reaches
 *   billing or job logs.
 */
export function calculateThinkingCost(costTracking: CostTracking): number {
  // NOTE(review): the derivation of this conversion factor is not visible in
  // this file — confirm with the pricing owner before changing it.
  const tokensPerCostUnit = 267000;

  const rawTokens = costTracking.toJSON().totalCost * tokensPerCostUnit;
  if (!Number.isFinite(rawTokens)) {
    return 0;
  }

  // Floor to an integer (matching calculateFinalResultCost) and never go
  // negative, so callers can safely do `300 + calculateThinkingCost(...)`.
  return Math.max(0, Math.floor(rawTokens));
}
export function calculateFinalResultCost(data: any): number { export function calculateFinalResultCost(data: any): number {
return Math.floor( return Math.floor(
JSON.stringify(data).length / tokenPerCharacter + baseTokenCost, JSON.stringify(data).length / tokenPerCharacter + baseTokenCost,