From db8f83104f2f49e2a18ac4665985f2179694933f Mon Sep 17 00:00:00 2001
From: Kevin Hu
Date: Thu, 8 Aug 2024 13:56:30 +0800
Subject: [PATCH] less text, better extraction (#1869)

### What problem does this PR solve?

#1861

### Type of change

- [x] Refactoring
---
 graphrag/index.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/graphrag/index.py b/graphrag/index.py
index 92649aa1b..a8e43eea7 100644
--- a/graphrag/index.py
+++ b/graphrag/index.py
@@ -75,10 +75,11 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
     llm_bdl = LLMBundle(tenant_id, LLMType.CHAT, tenant.llm_id)
     ext = GraphExtractor(llm_bdl)
     left_token_count = llm_bdl.max_length - ext.prompt_token_count - 1024
-    left_token_count = max(llm_bdl.max_length * 0.8, left_token_count)
+    left_token_count = max(llm_bdl.max_length * 0.6, left_token_count)

     assert left_token_count > 0, f"The LLM context length({llm_bdl.max_length}) is smaller than prompt({ext.prompt_token_count})"

+    BATCH_SIZE=1
     texts, graphs = [], []
     cnt = 0
     threads = []
@@ -86,15 +87,15 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
     for i in range(len(chunks)):
         tkn_cnt = num_tokens_from_string(chunks[i])
         if cnt+tkn_cnt >= left_token_count and texts:
-            for b in range(0, len(texts), 16):
-                threads.append(exe.submit(ext, ["\n".join(texts[b:b+16])], {"entity_types": entity_types}, callback))
+            for b in range(0, len(texts), BATCH_SIZE):
+                threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
             texts = []
             cnt = 0
         texts.append(chunks[i])
         cnt += tkn_cnt

     if texts:
-        for b in range(0, len(texts), 16):
-            threads.append(exe.submit(ext, ["\n".join(texts[b:b+16])], {"entity_types": entity_types}, callback))
+        for b in range(0, len(texts), BATCH_SIZE):
+            threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
     callback(0.5, "Extracting entities.")
     graphs = []
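
For reference, a minimal standalone sketch of the batching logic as it stands after this patch. Only `BATCH_SIZE`, the token-budget flush, and the `exe.submit(...)` batching mirror the patched code; `num_tokens_from_string` and the extractor below are stand-in stubs, `build_batches` is a hypothetical helper introduced for illustration, and `left_token_count=64` is an arbitrary demo value, not the production budget:

```python
# Sketch of the post-patch batching, assuming stub tokenizer/extractor.
from concurrent.futures import ThreadPoolExecutor

BATCH_SIZE = 1  # after this patch: one accumulated text per extraction call

def num_tokens_from_string(s: str) -> int:
    # Stand-in tokenizer; the real code uses a proper token counter.
    return max(1, len(s) // 4)

def extract(texts, params):
    # Stand-in for the GraphExtractor call; returns a fake "graph".
    return {"source": texts, "entity_types": params["entity_types"]}

def build_batches(chunks, left_token_count):
    """Accumulate chunks until the token budget is hit, then flush the
    accumulated texts in slices of BATCH_SIZE (mirrors the patched loop)."""
    batches, texts, cnt = [], [], 0
    for chunk in chunks:
        tkn_cnt = num_tokens_from_string(chunk)
        if cnt + tkn_cnt >= left_token_count and texts:
            for b in range(0, len(texts), BATCH_SIZE):
                batches.append(texts[b:b + BATCH_SIZE])
            texts, cnt = [], 0
        texts.append(chunk)
        cnt += tkn_cnt
    for b in range(0, len(texts), BATCH_SIZE):  # flush the remainder
        batches.append(texts[b:b + BATCH_SIZE])
    return batches

if __name__ == "__main__":
    chunks = ["chunk one ...", "chunk two ...", "chunk three ..."]
    with ThreadPoolExecutor(max_workers=4) as exe:
        futures = [exe.submit(extract, ["\n".join(batch)], {"entity_types": ["person"]})
                   for batch in build_batches(chunks, left_token_count=64)]
        graphs = [f.result() for f in futures]
    print(len(graphs), "extraction calls")
```

With `BATCH_SIZE=1`, each extraction call sees a single accumulated text rather than 16 joined together, which appears to be the "less text, better extraction" trade in the subject line: more LLM calls, but a smaller prompt per call.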