diff --git a/graphrag/general/extractor.py b/graphrag/general/extractor.py index d211cf381..a3c5f5853 100644 --- a/graphrag/general/extractor.py +++ b/graphrag/general/extractor.py @@ -97,7 +97,7 @@ class Extractor: ): results = [] - max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 50)) + max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 10)) with ThreadPoolExecutor(max_workers=max_workers) as exe: threads = [] for i, (cid, ck) in enumerate(chunks): @@ -123,12 +123,21 @@ class Extractor: maybe_edges[tuple(sorted(k))].extend(v) logging.info("Inserting entities into storage...") all_entities_data = [] - for en_nm, ents in maybe_nodes.items(): - all_entities_data.append(self._merge_nodes(en_nm, ents)) + with ThreadPoolExecutor(max_workers=max_workers) as exe: + threads = [] + for en_nm, ents in maybe_nodes.items(): + threads.append( + exe.submit(self._merge_nodes, en_nm, ents)) + for t in threads: + n = t.result() + if not isinstance(n, Exception): + all_entities_data.append(n) + elif callback: + callback(msg="Knowledge graph nodes merging error: {}".format(str(n))) logging.info("Inserting relationships into storage...") all_relationships_data = [] - for (src,tgt), rels in maybe_edges.items(): + for (src, tgt), rels in maybe_edges.items(): all_relationships_data.append(self._merge_edges(src, tgt, rels)) if not len(all_entities_data) and not len(all_relationships_data): @@ -167,17 +176,20 @@ class Extractor: sorted(set([dp["description"] for dp in entities] + already_description)) ) already_source_ids = flat_uniq_list(entities, "source_id") - description = self._handle_entity_relation_summary( - entity_name, description - ) - node_data = dict( - entity_type=entity_type, - description=description, - source_id=already_source_ids, - ) - node_data["entity_name"] = entity_name - self._set_entity_(entity_name, node_data) - return node_data + try: + description = self._handle_entity_relation_summary( + entity_name, description + ) + node_data = dict( + entity_type=entity_type, + description=description, + source_id=already_source_ids, + ) + node_data["entity_name"] = entity_name + self._set_entity_(entity_name, node_data) + return node_data + except Exception as e: + return e def _merge_edges( self,