diff --git a/apps/api/src/lib/extract/completions/batchExtract.ts b/apps/api/src/lib/extract/completions/batchExtract.ts index 72f1a32d..6db4e826 100644 --- a/apps/api/src/lib/extract/completions/batchExtract.ts +++ b/apps/api/src/lib/extract/completions/batchExtract.ts @@ -1,5 +1,5 @@ import { logger } from "../../../lib/logger"; -import { generateCompletions } from "../../../scraper/scrapeURL/transformers/llmExtract"; +import { generateCompletions, GenerateCompletionsOptions } from "../../../scraper/scrapeURL/transformers/llmExtract"; import { buildDocument } from "../build-document"; import { ExtractResponse, TokenUsage } from "../../../controllers/v1/types"; import { Document } from "../../../controllers/v1/types"; @@ -10,6 +10,7 @@ import { import { getModel } from "../../generic-ai"; import fs from "fs/promises"; +import { extractData } from "../../../scraper/scrapeURL/lib/extractSmartScrape"; /** * Batch extract information from a list of URLs using a multi-entity schema. * @param multiEntitySchema - The schema for the multi-entity extraction @@ -26,14 +27,13 @@ export async function batchExtractPromise( systemPrompt: string, doc: Document, ): Promise<{ - extract: any; + extract: any; // array of extracted data numTokens: number; totalUsage: TokenUsage; warning?: string; sources: string[]; }> { - const gemini = getModel("gemini-2.0-flash", "google"); - const completion = await generateCompletions({ + const generationOptions: GenerateCompletionsOptions = { logger: logger.child({ method: "extractService/generateCompletions", }), @@ -49,17 +49,30 @@ export async function batchExtractPromise( }, markdown: buildDocument(doc), isExtractEndpoint: true, - model: gemini("gemini-2.0-flash"), + model: getModel("gemini-2.0-flash", "google"), + }; + + const { extractedDataArray, warning } = await extractData({ + extractOptions: generationOptions, + url: doc.metadata.sourceURL || doc.metadata.url || "", }); + await fs.writeFile( - `logs/batchExtract-${crypto.randomUUID()}.json`, - JSON.stringify(completion, null, 2), + `logs/extractedDataArray-${crypto.randomUUID()}.json`, + JSON.stringify(extractedDataArray, null, 2), ); + // TODO: fix this return { - extract: completion.extract, - numTokens: completion.numTokens, - totalUsage: completion.totalUsage, + extract: extractedDataArray, + numTokens: 0, + totalUsage: { + promptTokens: 0, + completionTokens: 0, + totalTokens: 0, + model: "gemini-2.0-flash", + }, + warning: warning, sources: [doc.metadata.url || doc.metadata.sourceURL || ""], - }; + } } diff --git a/apps/api/src/lib/extract/completions/singleAnswer.ts b/apps/api/src/lib/extract/completions/singleAnswer.ts index f469195e..3205596f 100644 --- a/apps/api/src/lib/extract/completions/singleAnswer.ts +++ b/apps/api/src/lib/extract/completions/singleAnswer.ts @@ -1,9 +1,10 @@ import { logger } from "../../../lib/logger"; -import { generateCompletions } from "../../../scraper/scrapeURL/transformers/llmExtract"; +import { generateCompletions, GenerateCompletionsOptions } from "../../../scraper/scrapeURL/transformers/llmExtract"; import { buildDocument } from "../build-document"; import { Document, TokenUsage } from "../../../controllers/v1/types"; import { getModel } from "../../../lib/generic-ai"; import fs from "fs/promises"; +import { extractData } from "../../../scraper/scrapeURL/lib/extractSmartScrape"; export async function singleAnswerCompletion({ singleAnswerDocs, @@ -11,35 +12,71 @@ export async function singleAnswerCompletion({ links, prompt, systemPrompt, + urls, }: { singleAnswerDocs: Document[]; rSchema: any; links: string[]; prompt: string; systemPrompt: string; + urls: string[]; }): Promise<{ extract: any; tokenUsage: TokenUsage; sources: string[]; }> { - const completion = await generateCompletions({ + const generationOptions: GenerateCompletionsOptions = { logger: logger.child({ module: "extract", method: "generateCompletions" }), options: { mode: "llm", systemPrompt: (systemPrompt ? `${systemPrompt}\n` : "") + "Always prioritize using the provided content to answer the question. Do not make up an answer. Do not hallucinate. In case you can't find the information and the string is required, instead of 'N/A' or 'Not speficied', return an empty string: '', if it's not a string and you can't find the information, return null. Be concise and follow the schema always if provided.", - prompt: "Today is: " + new Date().toISOString() + "\n" + prompt, - schema: rSchema, - }, - markdown: singleAnswerDocs.map((x) => buildDocument(x)).join("\n"), - isExtractEndpoint: true, - model: getModel("gemini-2.0-flash", "google"), + prompt: "Today is: " + new Date().toISOString() + ".\n" + prompt, + schema: rSchema, + }, + markdown: singleAnswerDocs.map((x, i) => `[ID: ${i}]` + buildDocument(x)).join("\n"), + isExtractEndpoint: true, + model: getModel("gemini-2.0-flash", "google"), + }; + + const { extractedDataArray, warning } = await extractData({ + extractOptions: generationOptions, + urls, }); - await fs.writeFile( - `logs/singleAnswer-${crypto.randomUUID()}.json`, - JSON.stringify(completion, null, 2), - ); + + const completion = { + extract: extractedDataArray, + tokenUsage: { + promptTokens: 0, + completionTokens: 0, + totalTokens: 0, + model: "gemini-2.0-flash", + }, + sources: singleAnswerDocs.map( + (doc) => doc.metadata.url || doc.metadata.sourceURL || "", + ), + }; + + + // const completion = await generateCompletions({ + // logger: logger.child({ module: "extract", method: "generateCompletions" }), + // options: { + // mode: "llm", + // systemPrompt: + // (systemPrompt ? `${systemPrompt}\n` : "") + + // "Always prioritize using the provided content to answer the question. Do not make up an answer. Do not hallucinate. In case you can't find the information and the string is required, instead of 'N/A' or 'Not speficied', return an empty string: '', if it's not a string and you can't find the information, return null. Be concise and follow the schema always if provided.", + // prompt: "Today is: " + new Date().toISOString() + "\n" + prompt, + // schema: rSchema, + // }, + // markdown: singleAnswerDocs.map((x) => buildDocument(x)).join("\n"), + // isExtractEndpoint: true, + // model: getModel("gemini-2.0-flash", "google"), + // }); + // await fs.writeFile( + // `logs/singleAnswer-${crypto.randomUUID()}.json`, + // JSON.stringify(completion, null, 2), + // ); return { extract: completion.extract, tokenUsage: completion.totalUsage, diff --git a/apps/api/src/lib/extract/extraction-service.ts b/apps/api/src/lib/extract/extraction-service.ts index 36631a31..e9c33c9a 100644 --- a/apps/api/src/lib/extract/extraction-service.ts +++ b/apps/api/src/lib/extract/extraction-service.ts @@ -455,10 +455,10 @@ export async function performExtraction( ); // Race between timeout and completion - const multiEntityCompletion = (await Promise.race([ - completionPromise, - timeoutPromise, - ])) as Awaited>; + const multiEntityCompletion = await completionPromise as Awaited>; + + // TODO: merge multiEntityCompletion.extract to fit the multiEntitySchema + // Track multi-entity extraction tokens if (multiEntityCompletion) { @@ -513,14 +513,16 @@ export async function performExtraction( return null; } }); - // Wait for current chunk to complete before processing next chunk const chunkResults = await Promise.all(chunkPromises); const validResults = chunkResults.filter( (result): result is { extract: any; url: string } => result !== null, ); extractionResults.push(...validResults); - multiEntityCompletions.push(...validResults.map((r) => r.extract)); + // Merge all extracts from valid results into a single array + const extractArrays = validResults.map(r => Array.isArray(r.extract) ? r.extract : [r.extract]); + const mergedExtracts = extractArrays.flat(); + multiEntityCompletions.push(...mergedExtracts); logger.debug("All multi-entity completion chunks finished.", { completionCount: multiEntityCompletions.length, }); @@ -682,6 +684,7 @@ export async function performExtraction( tokenUsage: singleAnswerTokenUsage, sources: singleAnswerSources, } = await singleAnswerCompletion({ + url: request.urls?.[0] || "", singleAnswerDocs, rSchema, links, @@ -690,6 +693,13 @@ export async function performExtraction( }); logger.debug("Done generating singleAnswer completions."); + singleAnswerResult = transformArrayToObject( + rSchema, + completionResult, + ); + + singleAnswerResult = deduplicateObjectsArray(singleAnswerResult); + // Track single answer extraction tokens and sources if (completionResult) { tokenUsage.push(singleAnswerTokenUsage);