diff --git a/rag/raptor.py b/rag/raptor.py index 8c2547adf..886ef1a9c 100644 --- a/rag/raptor.py +++ b/rag/raptor.py @@ -14,6 +14,7 @@ # limitations under the License. # import logging +import os import re from concurrent.futures import ThreadPoolExecutor, ALL_COMPLETED, wait from threading import Lock @@ -122,7 +123,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: lbls = [np.where(prob > self._threshold)[0] for prob in probs] lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls] lock = Lock() - with ThreadPoolExecutor(max_workers=12) as executor: + with ThreadPoolExecutor(max_workers=int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 10))) as executor: threads = [] for c in range(n_clusters): ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]