rafaelmmiller 2025-04-03 15:31:05 -03:00
parent fd58e782b1
commit f189dfb9d7
3 changed files with 89 additions and 29 deletions

View File

@@ -1,5 +1,5 @@
 import { logger } from "../../../lib/logger";
-import { generateCompletions } from "../../../scraper/scrapeURL/transformers/llmExtract";
+import { generateCompletions, GenerateCompletionsOptions } from "../../../scraper/scrapeURL/transformers/llmExtract";
 import { buildDocument } from "../build-document";
 import { ExtractResponse, TokenUsage } from "../../../controllers/v1/types";
 import { Document } from "../../../controllers/v1/types";
@@ -10,6 +10,7 @@ import {
 import { getModel } from "../../generic-ai";
 import fs from "fs/promises";
+import { extractData } from "../../../scraper/scrapeURL/lib/extractSmartScrape";
 
 /**
  * Batch extract information from a list of URLs using a multi-entity schema.
  * @param multiEntitySchema - The schema for the multi-entity extraction
@@ -26,14 +27,13 @@ export async function batchExtractPromise(
   systemPrompt: string,
   doc: Document,
 ): Promise<{
-  extract: any;
+  extract: any; // array of extracted data
   numTokens: number;
   totalUsage: TokenUsage;
   warning?: string;
   sources: string[];
 }> {
-  const gemini = getModel("gemini-2.0-flash", "google");
-  const completion = await generateCompletions({
+  const generationOptions: GenerateCompletionsOptions = {
     logger: logger.child({
       method: "extractService/generateCompletions",
     }),
@@ -49,17 +49,30 @@ export async function batchExtractPromise(
     },
     markdown: buildDocument(doc),
     isExtractEndpoint: true,
-    model: gemini("gemini-2.0-flash"),
-  });
+    model: getModel("gemini-2.0-flash", "google"),
+  };
+
+  const { extractedDataArray, warning } = await extractData({
+    extractOptions: generationOptions,
+    url: doc.metadata.sourceURL || doc.metadata.url || "",
+  });
+
   await fs.writeFile(
-    `logs/batchExtract-${crypto.randomUUID()}.json`,
-    JSON.stringify(completion, null, 2),
+    `logs/extractedDataArray-${crypto.randomUUID()}.json`,
+    JSON.stringify(extractedDataArray, null, 2),
   );
+
+  // TODO: fix this
   return {
-    extract: completion.extract,
-    numTokens: completion.numTokens,
-    totalUsage: completion.totalUsage,
+    extract: extractedDataArray,
+    numTokens: 0,
+    totalUsage: {
+      promptTokens: 0,
+      completionTokens: 0,
+      totalTokens: 0,
+      model: "gemini-2.0-flash",
+    },
+    warning: warning,
     sources: [doc.metadata.url || doc.metadata.sourceURL || ""],
   };
 }
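The hunk above hardcodes numTokens and the token usage fields to zero behind a "TODO: fix this". If extractData is later extended to report per-call usage, a small reducer could replace those zeros. The TokenUsage shape below is copied from the literal in the diff; sumTokenUsage itself is a hypothetical sketch, not code from this commit:

type TokenUsage = {
  promptTokens: number;
  completionTokens: number;
  totalTokens: number;
  model: string;
};

// Hypothetical helper: fold several per-call usage reports into one total.
function sumTokenUsage(usages: TokenUsage[], model: string): TokenUsage {
  return usages.reduce(
    (acc, u) => ({
      promptTokens: acc.promptTokens + u.promptTokens,
      completionTokens: acc.completionTokens + u.completionTokens,
      totalTokens: acc.totalTokens + u.totalTokens,
      model,
    }),
    { promptTokens: 0, completionTokens: 0, totalTokens: 0, model },
  );
}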

View File

@@ -1,9 +1,10 @@
 import { logger } from "../../../lib/logger";
-import { generateCompletions } from "../../../scraper/scrapeURL/transformers/llmExtract";
+import { generateCompletions, GenerateCompletionsOptions } from "../../../scraper/scrapeURL/transformers/llmExtract";
 import { buildDocument } from "../build-document";
 import { Document, TokenUsage } from "../../../controllers/v1/types";
 import { getModel } from "../../../lib/generic-ai";
 import fs from "fs/promises";
+import { extractData } from "../../../scraper/scrapeURL/lib/extractSmartScrape";
 
 export async function singleAnswerCompletion({
   singleAnswerDocs,
@@ -11,35 +12,71 @@ export async function singleAnswerCompletion({
   links,
   prompt,
   systemPrompt,
+  urls,
 }: {
   singleAnswerDocs: Document[];
   rSchema: any;
   links: string[];
   prompt: string;
   systemPrompt: string;
+  urls: string[];
 }): Promise<{
   extract: any;
   tokenUsage: TokenUsage;
   sources: string[];
 }> {
-  const completion = await generateCompletions({
+  const generationOptions: GenerateCompletionsOptions = {
     logger: logger.child({ module: "extract", method: "generateCompletions" }),
     options: {
       mode: "llm",
       systemPrompt:
         (systemPrompt ? `${systemPrompt}\n` : "") +
         "Always prioritize using the provided content to answer the question. Do not make up an answer. Do not hallucinate. In case you can't find the information and the string is required, instead of 'N/A' or 'Not speficied', return an empty string: '', if it's not a string and you can't find the information, return null. Be concise and follow the schema always if provided.",
-      prompt: "Today is: " + new Date().toISOString() + "\n" + prompt,
-      schema: rSchema,
-    },
-    markdown: singleAnswerDocs.map((x) => buildDocument(x)).join("\n"),
-    isExtractEndpoint: true,
-    model: getModel("gemini-2.0-flash", "google"),
-  });
+      prompt: "Today is: " + new Date().toISOString() + ".\n" + prompt,
+      schema: rSchema,
+    },
+    markdown: singleAnswerDocs.map((x, i) => `[ID: ${i}]` + buildDocument(x)).join("\n"),
+    isExtractEndpoint: true,
+    model: getModel("gemini-2.0-flash", "google"),
+  };
+
+  const { extractedDataArray, warning } = await extractData({
+    extractOptions: generationOptions,
+    urls,
+  });
+
-  await fs.writeFile(
-    `logs/singleAnswer-${crypto.randomUUID()}.json`,
-    JSON.stringify(completion, null, 2),
-  );
+  const completion = {
+    extract: extractedDataArray,
+    tokenUsage: {
+      promptTokens: 0,
+      completionTokens: 0,
+      totalTokens: 0,
+      model: "gemini-2.0-flash",
+    },
+    sources: singleAnswerDocs.map(
+      (doc) => doc.metadata.url || doc.metadata.sourceURL || "",
+    ),
+  };
+
+  // const completion = await generateCompletions({
+  //   logger: logger.child({ module: "extract", method: "generateCompletions" }),
+  //   options: {
+  //     mode: "llm",
+  //     systemPrompt:
+  //       (systemPrompt ? `${systemPrompt}\n` : "") +
+  //       "Always prioritize using the provided content to answer the question. Do not make up an answer. Do not hallucinate. In case you can't find the information and the string is required, instead of 'N/A' or 'Not speficied', return an empty string: '', if it's not a string and you can't find the information, return null. Be concise and follow the schema always if provided.",
+  //     prompt: "Today is: " + new Date().toISOString() + "\n" + prompt,
+  //     schema: rSchema,
+  //   },
+  //   markdown: singleAnswerDocs.map((x) => buildDocument(x)).join("\n"),
+  //   isExtractEndpoint: true,
+  //   model: getModel("gemini-2.0-flash", "google"),
+  // });
+  // await fs.writeFile(
+  //   `logs/singleAnswer-${crypto.randomUUID()}.json`,
+  //   JSON.stringify(completion, null, 2),
+  // );
+
   return {
     extract: completion.extract,
     tokenUsage: completion.totalUsage,
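The new markdown assembly prefixes each document with "[ID: i]" before joining, presumably so the smart-scrape path can attribute extracted fields back to a source document. A minimal sketch of that tagging plus a reverse lookup, using a pared-down document type (the real Document carries far more fields):

type TaggedDoc = { markdown: string; sourceURL: string };

// Mirror of the diff's map: label each document before concatenation.
function tagDocuments(docs: TaggedDoc[]): string {
  return docs.map((d, i) => `[ID: ${i}]` + d.markdown).join("\n");
}

// If the model echoes an ID back, resolve it to the originating URL.
function resolveSource(docs: TaggedDoc[], id: number): string {
  return docs[id]?.sourceURL ?? "";
}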

View File

@@ -455,10 +455,10 @@ export async function performExtraction(
         );
 
         // Race between timeout and completion
-        const multiEntityCompletion = (await Promise.race([
-          completionPromise,
-          timeoutPromise,
-        ])) as Awaited<ReturnType<typeof generateCompletions>>;
+        const multiEntityCompletion = await completionPromise as Awaited<ReturnType<typeof batchExtractPromise>>;
+
+        // TODO: merge multiEntityCompletion.extract to fit the multiEntitySchema
 
         // Track multi-entity extraction tokens
         if (multiEntityCompletion) {
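This hunk drops the Promise.race against timeoutPromise, so the multi-entity path now awaits completionPromise without an upper bound. If a timeout is still wanted, a generic wrapper along these lines (hypothetical, not part of this commit) would restore it:

function withTimeout<T>(promise: Promise<T>, ms: number, label: string): Promise<T> {
  return Promise.race([
    promise,
    new Promise<T>((_, reject) =>
      setTimeout(() => reject(new Error(`${label} timed out after ${ms}ms`)), ms),
    ),
  ]);
}

// e.g. await withTimeout(completionPromise, 300_000, "batchExtract")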
@@ -513,14 +513,16 @@ export async function performExtraction(
             return null;
           }
         });
 
         // Wait for current chunk to complete before processing next chunk
         const chunkResults = await Promise.all(chunkPromises);
         const validResults = chunkResults.filter(
           (result): result is { extract: any; url: string } => result !== null,
         );
         extractionResults.push(...validResults);
-        multiEntityCompletions.push(...validResults.map((r) => r.extract));
+
+        // Merge all extracts from valid results into a single array
+        const extractArrays = validResults.map(r => Array.isArray(r.extract) ? r.extract : [r.extract]);
+        const mergedExtracts = extractArrays.flat();
+        multiEntityCompletions.push(...mergedExtracts);
 
         logger.debug("All multi-entity completion chunks finished.", {
           completionCount: multiEntityCompletions.length,
         });
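The merge step above first normalizes each result's extract to an array, because batchExtractPromise now returns an array while a result may still hold a single object, and then flattens. The same normalization reads as one flatMap; a sketch equivalent to the map-then-flat in the diff (the declare stands in for the diff's validResults):

declare const validResults: { extract: unknown; url: string }[];

const toArray = <T>(value: T | T[]): T[] => (Array.isArray(value) ? value : [value]);

// Equivalent to: validResults.map(r => Array.isArray(r.extract) ? r.extract : [r.extract]).flat()
const mergedExtracts = validResults.flatMap((r) => toArray(r.extract));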
@@ -682,6 +684,7 @@ export async function performExtraction(
         tokenUsage: singleAnswerTokenUsage,
         sources: singleAnswerSources,
       } = await singleAnswerCompletion({
+        url: request.urls?.[0] || "",
         singleAnswerDocs,
         rSchema,
         links,
@@ -690,6 +693,13 @@ export async function performExtraction(
       });
       logger.debug("Done generating singleAnswer completions.");
 
+      singleAnswerResult = transformArrayToObject(
+        rSchema,
+        completionResult,
+      );
+      singleAnswerResult = deduplicateObjectsArray(singleAnswerResult);
+
       // Track single answer extraction tokens and sources
       if (completionResult) {
         tokenUsage.push(singleAnswerTokenUsage);
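deduplicateObjectsArray's implementation is not shown in this diff; a plausible minimal version deduplicates by JSON serialization. Key-order sensitivity is a known limitation of this approach, and the real helper may well differ:

function dedupeByJson<T>(items: T[]): T[] {
  const seen = new Set<string>();
  return items.filter((item) => {
    const key = JSON.stringify(item);
    if (seen.has(key)) return false;
    seen.add(key);
    return true;
  });
}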