diff --git a/apps/api/src/controllers/auth.ts b/apps/api/src/controllers/auth.ts index d0392428..e6a16cd7 100644 --- a/apps/api/src/controllers/auth.ts +++ b/apps/api/src/controllers/auth.ts @@ -75,11 +75,11 @@ export async function setCachedACUC( const mockPreviewACUC: (team_id: string, is_extract: boolean) => AuthCreditUsageChunk = (team_id, is_extract) => ({ api_key: "preview", team_id, - sub_id: "bypass", - sub_current_period_start: new Date().toISOString(), - sub_current_period_end: new Date(new Date().getTime() + 30 * 24 * 60 * 60 * 1000).toISOString(), - sub_user_id: "bypass", - price_id: "bypass", + sub_id: null, + sub_current_period_start: null, + sub_current_period_end: null, + sub_user_id: null, + price_id: null, rate_limits: { crawl: 2, scrape: 10, diff --git a/apps/api/src/lib/extract/completions/analyzeSchemaAndPrompt.ts b/apps/api/src/lib/extract/completions/analyzeSchemaAndPrompt.ts index 2626c50b..922ebeb9 100644 --- a/apps/api/src/lib/extract/completions/analyzeSchemaAndPrompt.ts +++ b/apps/api/src/lib/extract/completions/analyzeSchemaAndPrompt.ts @@ -32,7 +32,7 @@ export async function analyzeSchemaAndPrompt( const schemaString = JSON.stringify(schema); - const model = getModel("gpt-4o"); + const model = getModel("gpt-4o", "openai"); const checkSchema = z .object({ diff --git a/apps/api/src/lib/extract/completions/batchExtract.ts b/apps/api/src/lib/extract/completions/batchExtract.ts index b7075c6e..f35f3f78 100644 --- a/apps/api/src/lib/extract/completions/batchExtract.ts +++ b/apps/api/src/lib/extract/completions/batchExtract.ts @@ -10,7 +10,7 @@ import { buildBatchExtractSystemPrompt, } from "../build-prompts"; import { getModel } from "../../generic-ai"; -import { CostTracking } from "../extraction-service"; +import { CostTracking, CostLimitExceededError } from "../extraction-service"; import fs from "fs/promises"; import { extractData } from "../../../scraper/scrapeURL/lib/extractSmartScrape"; import type { Logger } from "winston"; @@ -102,6 +102,9 @@ export async function batchExtractPromise(options: BatchExtractOptions, logger: extractedDataArray = e; warning = w; } catch (error) { + if (error instanceof CostLimitExceededError) { + throw error; + } logger.error("extractData failed", { error }); } diff --git a/apps/api/src/lib/extract/completions/checkShouldExtract.ts b/apps/api/src/lib/extract/completions/checkShouldExtract.ts index e2c10ade..6f678ee8 100644 --- a/apps/api/src/lib/extract/completions/checkShouldExtract.ts +++ b/apps/api/src/lib/extract/completions/checkShouldExtract.ts @@ -33,7 +33,7 @@ export async function checkShouldExtract( }, markdown: buildDocument(doc), isExtractEndpoint: true, - model: getModel("gpt-4o-mini"), + model: getModel("gpt-4o-mini", "openai"), costTrackingOptions: { costTracking, metadata: { diff --git a/apps/api/src/lib/extract/extraction-service.ts b/apps/api/src/lib/extract/extraction-service.ts index 523c9c8c..6a3e6756 100644 --- a/apps/api/src/lib/extract/extraction-service.ts +++ b/apps/api/src/lib/extract/extraction-service.ts @@ -67,6 +67,14 @@ type completions = { sources?: string[]; }; +export class CostLimitExceededError extends Error { + constructor() { + super("Cost limit exceeded"); + this.message = "Cost limit exceeded"; + this.name = "CostLimitExceededError"; + } +} + export class CostTracking { calls: { type: "smartScrape" | "other", @@ -79,14 +87,21 @@ export class CostTracking { }, stack: string, }[] = []; + limit: number | null = null; - constructor() {} + constructor(limit: number | null = null) { + this.limit = limit; + } public addCall(call: Omit) { this.calls.push({ ...call, stack: new Error().stack!.split("\n").slice(2).join("\n"), }); + + if (this.limit !== null && this.toJSON().totalCost > this.limit) { + throw new CostLimitExceededError(); + } } public toJSON() { @@ -115,7 +130,8 @@ export async function performExtraction( let singleAnswerResult: any = {}; let totalUrlsScraped = 0; let sources: Record = {}; - let costTracking: CostTracking = new CostTracking(); + + let costTracking = new CostTracking(subId ? null : 1.5); let log = { extractId, @@ -532,6 +548,10 @@ export async function performExtraction( return null; } catch (error) { + if (error instanceof CostLimitExceededError) { + throw error; + } + logger.error(`Failed to process document.`, { error, url: doc.metadata.url ?? doc.metadata.sourceURL!, diff --git a/apps/api/src/scraper/scrapeURL/lib/extractSmartScrape.ts b/apps/api/src/scraper/scrapeURL/lib/extractSmartScrape.ts index 52ad2fa1..2dffe047 100644 --- a/apps/api/src/scraper/scrapeURL/lib/extractSmartScrape.ts +++ b/apps/api/src/scraper/scrapeURL/lib/extractSmartScrape.ts @@ -10,7 +10,7 @@ import { parseMarkdown } from "../../../lib/html-to-markdown"; import { getModel } from "../../../lib/generic-ai"; import { TokenUsage } from "../../../controllers/v1/types"; import type { SmartScrapeResult } from "./smartScrape"; -import { CostTracking } from "../../../lib/extract/extraction-service"; +import { CostLimitExceededError, CostTracking } from "../../../lib/extract/extraction-service"; const commonSmartScrapeProperties = { shouldUseSmartscrape: { type: "boolean", @@ -282,6 +282,10 @@ export async function extractData({ warning = w; totalUsage = t; } catch (error) { + if (error instanceof CostLimitExceededError) { + throw error; + } + logger.error( "failed during extractSmartScrape.ts:generateCompletions", { error },