rafaelmmiller 2025-04-03 15:31:05 -03:00
parent fd58e782b1
commit f189dfb9d7
3 changed files with 89 additions and 29 deletions

View File

@@ -1,5 +1,5 @@
 import { logger } from "../../../lib/logger";
-import { generateCompletions } from "../../../scraper/scrapeURL/transformers/llmExtract";
+import { generateCompletions, GenerateCompletionsOptions } from "../../../scraper/scrapeURL/transformers/llmExtract";
 import { buildDocument } from "../build-document";
 import { ExtractResponse, TokenUsage } from "../../../controllers/v1/types";
 import { Document } from "../../../controllers/v1/types";
@@ -10,6 +10,7 @@ import {
 import { getModel } from "../../generic-ai";
 import fs from "fs/promises";
+import { extractData } from "../../../scraper/scrapeURL/lib/extractSmartScrape";

 /**
  * Batch extract information from a list of URLs using a multi-entity schema.
  * @param multiEntitySchema - The schema for the multi-entity extraction
@@ -26,14 +27,13 @@ export async function batchExtractPromise(
   systemPrompt: string,
   doc: Document,
 ): Promise<{
-  extract: any;
+  extract: any; // array of extracted data
   numTokens: number;
   totalUsage: TokenUsage;
   warning?: string;
   sources: string[];
 }> {
-  const gemini = getModel("gemini-2.0-flash", "google");
-  const completion = await generateCompletions({
+  const generationOptions: GenerateCompletionsOptions = {
     logger: logger.child({
       method: "extractService/generateCompletions",
     }),
@@ -49,17 +49,30 @@ export async function batchExtractPromise(
     },
     markdown: buildDocument(doc),
     isExtractEndpoint: true,
-    model: gemini("gemini-2.0-flash"),
+    model: getModel("gemini-2.0-flash", "google"),
+  };
+  const { extractedDataArray, warning } = await extractData({
+    extractOptions: generationOptions,
+    url: doc.metadata.sourceURL || doc.metadata.url || "",
   });
   await fs.writeFile(
-    `logs/batchExtract-${crypto.randomUUID()}.json`,
-    JSON.stringify(completion, null, 2),
+    `logs/extractedDataArray-${crypto.randomUUID()}.json`,
+    JSON.stringify(extractedDataArray, null, 2),
   );
+  // TODO: fix this
   return {
-    extract: completion.extract,
-    numTokens: completion.numTokens,
-    totalUsage: completion.totalUsage,
+    extract: extractedDataArray,
+    numTokens: 0,
+    totalUsage: {
+      promptTokens: 0,
+      completionTokens: 0,
+      totalTokens: 0,
+      model: "gemini-2.0-flash",
+    },
+    warning: warning,
     sources: [doc.metadata.url || doc.metadata.sourceURL || ""],
   };
 }
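Note on the new return shape: `extract` now carries the array produced by `extractData`, and token usage is hard-coded to zero until the `// TODO: fix this` above is resolved. Below is a minimal sketch of how a caller might normalize and merge these results, mirroring the `Array.isArray(...)` flattening used further down in extraction-service; `mergeBatchExtracts` is a hypothetical illustration, not code from this commit.

```ts
// Sketch only. BatchExtractResult mirrors the return type shown in this diff;
// mergeBatchExtracts is a hypothetical helper, not part of the repository.
type BatchExtractResult = {
  extract: any; // array of extracted data
  numTokens: number;
  totalUsage: {
    promptTokens: number;
    completionTokens: number;
    totalTokens: number;
    model: string;
  };
  warning?: string;
  sources: string[];
};

function mergeBatchExtracts(results: BatchExtractResult[]): any[] {
  // Normalize each extract to an array, then flatten into a single list,
  // the same pattern extraction-service applies to valid chunk results.
  return results
    .map((r) => (Array.isArray(r.extract) ? r.extract : [r.extract]))
    .flat();
}
```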

View File

@@ -1,9 +1,10 @@
 import { logger } from "../../../lib/logger";
-import { generateCompletions } from "../../../scraper/scrapeURL/transformers/llmExtract";
+import { generateCompletions, GenerateCompletionsOptions } from "../../../scraper/scrapeURL/transformers/llmExtract";
 import { buildDocument } from "../build-document";
 import { Document, TokenUsage } from "../../../controllers/v1/types";
 import { getModel } from "../../../lib/generic-ai";
 import fs from "fs/promises";
+import { extractData } from "../../../scraper/scrapeURL/lib/extractSmartScrape";

 export async function singleAnswerCompletion({
   singleAnswerDocs,
@@ -11,35 +12,71 @@ export async function singleAnswerCompletion({
   links,
   prompt,
   systemPrompt,
+  urls,
 }: {
   singleAnswerDocs: Document[];
   rSchema: any;
   links: string[];
   prompt: string;
   systemPrompt: string;
+  urls: string[];
 }): Promise<{
   extract: any;
   tokenUsage: TokenUsage;
   sources: string[];
 }> {
-  const completion = await generateCompletions({
+  const generationOptions: GenerateCompletionsOptions = {
     logger: logger.child({ module: "extract", method: "generateCompletions" }),
     options: {
       mode: "llm",
       systemPrompt:
         (systemPrompt ? `${systemPrompt}\n` : "") +
         "Always prioritize using the provided content to answer the question. Do not make up an answer. Do not hallucinate. In case you can't find the information and the string is required, instead of 'N/A' or 'Not speficied', return an empty string: '', if it's not a string and you can't find the information, return null. Be concise and follow the schema always if provided.",
-      prompt: "Today is: " + new Date().toISOString() + "\n" + prompt,
+      prompt: "Today is: " + new Date().toISOString() + ".\n" + prompt,
       schema: rSchema,
     },
-    markdown: singleAnswerDocs.map((x) => buildDocument(x)).join("\n"),
+    markdown: singleAnswerDocs.map((x, i) => `[ID: ${i}]` + buildDocument(x)).join("\n"),
     isExtractEndpoint: true,
     model: getModel("gemini-2.0-flash", "google"),
+  };
+  const { extractedDataArray, warning } = await extractData({
+    extractOptions: generationOptions,
+    urls,
   });
-  await fs.writeFile(
-    `logs/singleAnswer-${crypto.randomUUID()}.json`,
-    JSON.stringify(completion, null, 2),
-  );
+  const completion = {
+    extract: extractedDataArray,
+    tokenUsage: {
+      promptTokens: 0,
+      completionTokens: 0,
+      totalTokens: 0,
+      model: "gemini-2.0-flash",
+    },
+    sources: singleAnswerDocs.map(
+      (doc) => doc.metadata.url || doc.metadata.sourceURL || "",
+    ),
+  };
+  // const completion = await generateCompletions({
+  //   logger: logger.child({ module: "extract", method: "generateCompletions" }),
+  //   options: {
+  //     mode: "llm",
+  //     systemPrompt:
+  //       (systemPrompt ? `${systemPrompt}\n` : "") +
+  //       "Always prioritize using the provided content to answer the question. Do not make up an answer. Do not hallucinate. In case you can't find the information and the string is required, instead of 'N/A' or 'Not speficied', return an empty string: '', if it's not a string and you can't find the information, return null. Be concise and follow the schema always if provided.",
+  //     prompt: "Today is: " + new Date().toISOString() + "\n" + prompt,
+  //     schema: rSchema,
+  //   },
+  //   markdown: singleAnswerDocs.map((x) => buildDocument(x)).join("\n"),
+  //   isExtractEndpoint: true,
+  //   model: getModel("gemini-2.0-flash", "google"),
+  // });
+  // await fs.writeFile(
+  //   `logs/singleAnswer-${crypto.randomUUID()}.json`,
+  //   JSON.stringify(completion, null, 2),
+  // );
   return {
     extract: completion.extract,
     tokenUsage: completion.totalUsage,
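The `[ID: ${i}]` prefix added to the joined markdown above tags each document with its index, presumably so extracted answers can be attributed to a specific source document. A standalone sketch of that tagging follows, with a simplified `Doc` type and a stubbed `buildDocument` standing in for the repository's Document type and build-document helper.

```ts
// Sketch only: Doc and buildDocument are simplified stand-ins for the repo's types.
type Doc = { markdown?: string; metadata: { url?: string; sourceURL?: string } };

const buildDocument = (doc: Doc): string => doc.markdown ?? "";

function buildTaggedMarkdown(singleAnswerDocs: Doc[]): string {
  // Prefix each document with an index tag before joining, as in the diff above.
  return singleAnswerDocs
    .map((x, i) => `[ID: ${i}]` + buildDocument(x))
    .join("\n");
}
```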

View File

@@ -455,10 +455,10 @@ export async function performExtraction(
       );

       // Race between timeout and completion
-      const multiEntityCompletion = (await Promise.race([
-        completionPromise,
-        timeoutPromise,
-      ])) as Awaited<ReturnType<typeof generateCompletions>>;
+      const multiEntityCompletion = await completionPromise as Awaited<ReturnType<typeof batchExtractPromise>>;
+
+      // TODO: merge multiEntityCompletion.extract to fit the multiEntitySchema

       // Track multi-entity extraction tokens
       if (multiEntityCompletion) {
@@ -513,14 +513,16 @@ export async function performExtraction(
           return null;
         }
       });

       // Wait for current chunk to complete before processing next chunk
       const chunkResults = await Promise.all(chunkPromises);
       const validResults = chunkResults.filter(
         (result): result is { extract: any; url: string } => result !== null,
       );
       extractionResults.push(...validResults);
-      multiEntityCompletions.push(...validResults.map((r) => r.extract));
+      // Merge all extracts from valid results into a single array
+      const extractArrays = validResults.map(r => Array.isArray(r.extract) ? r.extract : [r.extract]);
+      const mergedExtracts = extractArrays.flat();
+      multiEntityCompletions.push(...mergedExtracts);
       logger.debug("All multi-entity completion chunks finished.", {
         completionCount: multiEntityCompletions.length,
       });
@@ -682,6 +684,7 @@ export async function performExtraction(
        tokenUsage: singleAnswerTokenUsage,
        sources: singleAnswerSources,
      } = await singleAnswerCompletion({
+       url: request.urls?.[0] || "",
        singleAnswerDocs,
        rSchema,
        links,
@@ -690,6 +693,13 @@ export async function performExtraction(
      });
      logger.debug("Done generating singleAnswer completions.");

+     singleAnswerResult = transformArrayToObject(
+       rSchema,
+       completionResult,
+     );
+     singleAnswerResult = deduplicateObjectsArray(singleAnswerResult);
+
      // Track single answer extraction tokens and sources
      if (completionResult) {
        tokenUsage.push(singleAnswerTokenUsage);
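The single-answer result is now post-processed with `transformArrayToObject` and `deduplicateObjectsArray`, whose implementations are not shown in this diff. As a rough illustration of the deduplication step only, a JSON-keyed filter like the following would collapse structurally identical extractions; the repository's real helper may work differently.

```ts
// Sketch only: a plausible deduplication helper, not the repo's implementation.
function deduplicateObjectsArraySketch<T>(items: T[]): T[] {
  const seen = new Set<string>();
  return items.filter((item) => {
    // Key each object by its JSON form; assumes consistent key ordering across items.
    const key = JSON.stringify(item);
    if (seen.has(key)) return false;
    seen.add(key);
    return true;
  });
}
```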