feat(extract): cost limit (#1473)

This commit is contained in:
Gergő Móricz 2025-04-17 21:44:28 +02:00 committed by GitHub
parent 7df557e59c
commit 9bea877eb1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 38 additions and 11 deletions

View File

@ -75,11 +75,11 @@ export async function setCachedACUC(
const mockPreviewACUC: (team_id: string, is_extract: boolean) => AuthCreditUsageChunk = (team_id, is_extract) => ({ const mockPreviewACUC: (team_id: string, is_extract: boolean) => AuthCreditUsageChunk = (team_id, is_extract) => ({
api_key: "preview", api_key: "preview",
team_id, team_id,
sub_id: "bypass", sub_id: null,
sub_current_period_start: new Date().toISOString(), sub_current_period_start: null,
sub_current_period_end: new Date(new Date().getTime() + 30 * 24 * 60 * 60 * 1000).toISOString(), sub_current_period_end: null,
sub_user_id: "bypass", sub_user_id: null,
price_id: "bypass", price_id: null,
rate_limits: { rate_limits: {
crawl: 2, crawl: 2,
scrape: 10, scrape: 10,

View File

@ -32,7 +32,7 @@ export async function analyzeSchemaAndPrompt(
const schemaString = JSON.stringify(schema); const schemaString = JSON.stringify(schema);
const model = getModel("gpt-4o"); const model = getModel("gpt-4o", "openai");
const checkSchema = z const checkSchema = z
.object({ .object({

View File

@ -10,7 +10,7 @@ import {
buildBatchExtractSystemPrompt, buildBatchExtractSystemPrompt,
} from "../build-prompts"; } from "../build-prompts";
import { getModel } from "../../generic-ai"; import { getModel } from "../../generic-ai";
import { CostTracking } from "../extraction-service"; import { CostTracking, CostLimitExceededError } from "../extraction-service";
import fs from "fs/promises"; import fs from "fs/promises";
import { extractData } from "../../../scraper/scrapeURL/lib/extractSmartScrape"; import { extractData } from "../../../scraper/scrapeURL/lib/extractSmartScrape";
import type { Logger } from "winston"; import type { Logger } from "winston";
@ -102,6 +102,9 @@ export async function batchExtractPromise(options: BatchExtractOptions, logger:
extractedDataArray = e; extractedDataArray = e;
warning = w; warning = w;
} catch (error) { } catch (error) {
if (error instanceof CostLimitExceededError) {
throw error;
}
logger.error("extractData failed", { error }); logger.error("extractData failed", { error });
} }

View File

@ -33,7 +33,7 @@ export async function checkShouldExtract(
}, },
markdown: buildDocument(doc), markdown: buildDocument(doc),
isExtractEndpoint: true, isExtractEndpoint: true,
model: getModel("gpt-4o-mini"), model: getModel("gpt-4o-mini", "openai"),
costTrackingOptions: { costTrackingOptions: {
costTracking, costTracking,
metadata: { metadata: {

View File

@ -67,6 +67,14 @@ type completions = {
sources?: string[]; sources?: string[];
}; };
export class CostLimitExceededError extends Error {
constructor() {
super("Cost limit exceeded");
this.message = "Cost limit exceeded";
this.name = "CostLimitExceededError";
}
}
export class CostTracking { export class CostTracking {
calls: { calls: {
type: "smartScrape" | "other", type: "smartScrape" | "other",
@ -79,14 +87,21 @@ export class CostTracking {
}, },
stack: string, stack: string,
}[] = []; }[] = [];
limit: number | null = null;
constructor() {} constructor(limit: number | null = null) {
this.limit = limit;
}
public addCall(call: Omit<typeof this.calls[number], "stack">) { public addCall(call: Omit<typeof this.calls[number], "stack">) {
this.calls.push({ this.calls.push({
...call, ...call,
stack: new Error().stack!.split("\n").slice(2).join("\n"), stack: new Error().stack!.split("\n").slice(2).join("\n"),
}); });
if (this.limit !== null && this.toJSON().totalCost > this.limit) {
throw new CostLimitExceededError();
}
} }
public toJSON() { public toJSON() {
@ -115,7 +130,8 @@ export async function performExtraction(
let singleAnswerResult: any = {}; let singleAnswerResult: any = {};
let totalUrlsScraped = 0; let totalUrlsScraped = 0;
let sources: Record<string, string[]> = {}; let sources: Record<string, string[]> = {};
let costTracking: CostTracking = new CostTracking();
let costTracking = new CostTracking(subId ? null : 1.5);
let log = { let log = {
extractId, extractId,
@ -532,6 +548,10 @@ export async function performExtraction(
return null; return null;
} catch (error) { } catch (error) {
if (error instanceof CostLimitExceededError) {
throw error;
}
logger.error(`Failed to process document.`, { logger.error(`Failed to process document.`, {
error, error,
url: doc.metadata.url ?? doc.metadata.sourceURL!, url: doc.metadata.url ?? doc.metadata.sourceURL!,

View File

@ -10,7 +10,7 @@ import { parseMarkdown } from "../../../lib/html-to-markdown";
import { getModel } from "../../../lib/generic-ai"; import { getModel } from "../../../lib/generic-ai";
import { TokenUsage } from "../../../controllers/v1/types"; import { TokenUsage } from "../../../controllers/v1/types";
import type { SmartScrapeResult } from "./smartScrape"; import type { SmartScrapeResult } from "./smartScrape";
import { CostTracking } from "../../../lib/extract/extraction-service"; import { CostLimitExceededError, CostTracking } from "../../../lib/extract/extraction-service";
const commonSmartScrapeProperties = { const commonSmartScrapeProperties = {
shouldUseSmartscrape: { shouldUseSmartscrape: {
type: "boolean", type: "boolean",
@ -282,6 +282,10 @@ export async function extractData({
warning = w; warning = w;
totalUsage = t; totalUsage = t;
} catch (error) { } catch (error) {
if (error instanceof CostLimitExceededError) {
throw error;
}
logger.error( logger.error(
"failed during extractSmartScrape.ts:generateCompletions", "failed during extractSmartScrape.ts:generateCompletions",
{ error }, { error },