From 0891a393d7646892fb1d17ee88b6a8ab5e705a69 Mon Sep 17 00:00:00 2001
From: Kevin Hu
Date: Tue, 26 Nov 2024 16:31:07 +0800
Subject: [PATCH] Let ThreadPool exit gracefully. (#3653)

### What problem does this PR solve?

#3646

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
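The fix wraps each bare `ThreadPoolExecutor` in a `with` block. As a minimal
sketch of the pattern (a standalone illustration, not code from this repo;
the `square` worker is hypothetical):

```python
from concurrent.futures import ThreadPoolExecutor

def square(x: int) -> int:
    # Hypothetical stand-in for the real workers (ext, _process_document).
    return x * x

# A bare executor is never shut down explicitly, so its non-daemon worker
# threads can delay or block a clean process exit:
#   exe = ThreadPoolExecutor(max_workers=4)
#   futures = [exe.submit(square, i) for i in range(8)]

# As a context manager, __exit__ calls exe.shutdown(wait=True), which joins
# every worker thread before execution continues past the block.
with ThreadPoolExecutor(max_workers=4) as exe:
    futures = [exe.submit(square, i) for i in range(8)]
    results = [f.result() for f in futures]

print(results)  # [0, 1, 4, 9, 16, 25, 36, 49]
```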
---
 agent/component/crawler.py     |  4 ----
 graphrag/index.py              | 34 +++++++++++++++++-----------------
 graphrag/mind_map_extractor.py | 34 +++++++++++++++++-----------------
 rag/llm/chat_model.py          |  2 +-
 rag/svr/task_executor.py       |  2 ++
 5 files changed, 37 insertions(+), 39 deletions(-)

diff --git a/agent/component/crawler.py b/agent/component/crawler.py
index 151942b65..b500d7852 100644
--- a/agent/component/crawler.py
+++ b/agent/component/crawler.py
@@ -65,7 +65,3 @@ class Crawler(ComponentBase, ABC):
         elif self._param.extract_type == 'content':
             result.extracted_content
         return result.markdown
-
-
-
-
\ No newline at end of file
diff --git a/graphrag/index.py b/graphrag/index.py
index b0129a985..89e332cd0 100644
--- a/graphrag/index.py
+++ b/graphrag/index.py
@@ -64,27 +64,27 @@ def build_knowledge_graph_chunks(tenant_id: str, chunks: list[str], callback, en
     BATCH_SIZE=4
     texts, graphs = [], []
     cnt = 0
-    threads = []
     max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 50))
-    exe = ThreadPoolExecutor(max_workers=max_workers)
-    for i in range(len(chunks)):
-        tkn_cnt = num_tokens_from_string(chunks[i])
-        if cnt+tkn_cnt >= left_token_count and texts:
+    with ThreadPoolExecutor(max_workers=max_workers) as exe:
+        threads = []
+        for i in range(len(chunks)):
+            tkn_cnt = num_tokens_from_string(chunks[i])
+            if cnt+tkn_cnt >= left_token_count and texts:
+                for b in range(0, len(texts), BATCH_SIZE):
+                    threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
+                texts = []
+                cnt = 0
+            texts.append(chunks[i])
+            cnt += tkn_cnt
+        if texts:
             for b in range(0, len(texts), BATCH_SIZE):
                 threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
-            texts = []
-            cnt = 0
-        texts.append(chunks[i])
-        cnt += tkn_cnt
-    if texts:
-        for b in range(0, len(texts), BATCH_SIZE):
-            threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))

-    callback(0.5, "Extracting entities.")
-    graphs = []
-    for i, _ in enumerate(threads):
-        graphs.append(_.result().output)
-        callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")
+        callback(0.5, "Extracting entities.")
+        graphs = []
+        for i, _ in enumerate(threads):
+            graphs.append(_.result().output)
+            callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")

     graph = reduce(graph_merge, graphs) if graphs else nx.Graph()
     er = EntityResolution(llm_bdl)
diff --git a/graphrag/mind_map_extractor.py b/graphrag/mind_map_extractor.py
index a472f8946..74d396dbf 100644
--- a/graphrag/mind_map_extractor.py
+++ b/graphrag/mind_map_extractor.py
@@ -88,26 +88,26 @@ class MindMapExtractor:

         prompt_variables = {}

         try:
-            max_workers = int(os.environ.get('MINDMAP_EXTRACTOR_MAX_WORKERS', 12))
-            exe = ThreadPoolExecutor(max_workers=max_workers)
-            threads = []
-            token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
-            texts = []
             res = []
-            cnt = 0
-            for i in range(len(sections)):
-                section_cnt = num_tokens_from_string(sections[i])
-                if cnt + section_cnt >= token_count and texts:
+            max_workers = int(os.environ.get('MINDMAP_EXTRACTOR_MAX_WORKERS', 12))
+            with ThreadPoolExecutor(max_workers=max_workers) as exe:
+                threads = []
+                token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
+                texts = []
+                cnt = 0
+                for i in range(len(sections)):
+                    section_cnt = num_tokens_from_string(sections[i])
+                    if cnt + section_cnt >= token_count and texts:
+                        threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
+                        texts = []
+                        cnt = 0
+                    texts.append(sections[i])
+                    cnt += section_cnt
+                if texts:
                     threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
-                    texts = []
-                    cnt = 0
-                texts.append(sections[i])
-                cnt += section_cnt
-            if texts:
-                threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
-            for i, _ in enumerate(threads):
-                res.append(_.result())
+                for i, _ in enumerate(threads):
+                    res.append(_.result())

             if not res:
                 return MindMapResult(output={"id": "root", "children": []})
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 084bbc55e..9dea59a72 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -366,7 +366,7 @@ class OllamaChat(Base):
                 keep_alive=-1
             )
             ans = response["message"]["content"].strip()
-            return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
+            return ans, response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
         except Exception as e:
             return "**ERROR**: " + str(e), 0

diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 2b3314632..9829e68cb 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -492,6 +492,7 @@ def report_status():
             logging.exception("report_status got exception")
         time.sleep(30)

+
 def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot, snapshot_id: int, dump_full: bool):
     msg = ""
     if dump_full:
@@ -508,6 +509,7 @@ def analyze_heap(snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapsho
     msg += '\n'.join(stat.traceback.format())
     logging.info(msg)

+
 def main():
     settings.init_settings()
     background_thread = threading.Thread(target=report_status)
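A note on the `rag/llm/chat_model.py` hunk: it guards token accounting
against Ollama responses that omit the `eval_count` field. A minimal sketch
of the failure mode, using a hypothetical response payload:

```python
# Hypothetical Ollama-style payload with "eval_count" absent; token counts
# are not guaranteed to appear on every response.
response = {"message": {"content": "Hi!"}, "prompt_eval_count": 12}

# Old: direct indexing raises KeyError("eval_count").
# New: .get() with a default of 0 degrades gracefully.
tokens = response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
print(tokens)  # 12
```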