From 27cd765d6fe522c789f1833437593a125a74d6d9 Mon Sep 17 00:00:00 2001
From: Kevin Hu
Date: Fri, 29 Nov 2024 11:55:41 +0800
Subject: [PATCH] Fix raptor issue (#3737)

### What problem does this PR solve?

#3732

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
---
 rag/raptor.py            | 30 +++++++++++++++++-------------
 rag/svr/task_executor.py |  2 +-
 2 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/rag/raptor.py b/rag/raptor.py
index 5974e371d..51f1ad117 100644
--- a/rag/raptor.py
+++ b/rag/raptor.py
@@ -33,7 +33,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         self._prompt = prompt
         self._max_token = max_token
 
-    def _get_optimal_clusters(self, embeddings: np.ndarray, random_state:int):
+    def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int):
         max_clusters = min(self._max_cluster, len(embeddings))
         n_clusters = np.arange(1, max_clusters)
         bics = []
@@ -44,7 +44,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         optimal_clusters = n_clusters[np.argmin(bics)]
         return optimal_clusters
 
-    def __call__(self, chunks: tuple[str, np.ndarray], random_state, callback=None):
+    def __call__(self, chunks, random_state, callback=None):
         layers = [(0, len(chunks))]
         start, end = 0, len(chunks)
         if len(chunks) <= 1: return
@@ -54,13 +54,15 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
             nonlocal chunks
             try:
                 texts = [chunks[i][0] for i in ck_idx]
-                len_per_chunk = int((self._llm_model.max_length - self._max_token)/len(texts))
+                len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
                 cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
                 cnt = self._llm_model.chat("You're a helpful assistant.",
-                                           [{"role": "user", "content": self._prompt.format(cluster_content=cluster_content)}],
-                                           {"temperature": 0.3, "max_tokens": self._max_token}
-                                           )
-                cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "", cnt)
+                                           [{"role": "user",
+                                             "content": self._prompt.format(cluster_content=cluster_content)}],
+                                           {"temperature": 0.3, "max_tokens": self._max_token}
+                                           )
+                cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "",
+                             cnt)
                 logging.debug(f"SUM: {cnt}")
                 embds, _ = self._embd_model.encode([cnt])
                 with lock:
@@ -74,10 +76,10 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         while end - start > 1:
             embeddings = [embd for _, embd in chunks[start: end]]
             if len(embeddings) == 2:
-                summarize([start, start+1], Lock())
+                summarize([start, start + 1], Lock())
                 if callback:
-                    callback(msg="Cluster one layer: {} -> {}".format(end-start, len(chunks)-end))
-                labels.extend([0,0])
+                    callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
+                labels.extend([0, 0])
                 layers.append((end, len(chunks)))
                 start = end
                 end = len(chunks)
@@ -85,7 +87,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
 
             n_neighbors = int((len(embeddings) - 1) ** 0.8)
             reduced_embeddings = umap.UMAP(
-                n_neighbors=max(2, n_neighbors), n_components=min(12, len(embeddings)-2), metric="cosine"
+                n_neighbors=max(2, n_neighbors), n_components=min(12, len(embeddings) - 2), metric="cosine"
             ).fit_transform(embeddings)
             n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state)
             if n_clusters == 1:
@@ -100,7 +102,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
             with ThreadPoolExecutor(max_workers=12) as executor:
                 threads = []
                 for c in range(n_clusters):
-                    ck_idx = [i+start for i in range(len(lbls)) if lbls[i] == c]
+                    ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
                     threads.append(executor.submit(summarize, ck_idx, lock))
                 wait(threads, return_when=ALL_COMPLETED)
                 logging.debug(str([t.result() for t in threads]))
@@ -109,7 +111,9 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
             labels.extend(lbls)
             layers.append((end, len(chunks)))
             if callback:
-                callback(msg="Cluster one layer: {} -> {}".format(end-start, len(chunks)-end))
+                callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
             start = end
             end = len(chunks)
+        return chunks
+
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index d7be6fa81..cc69bdaa6 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -344,7 +344,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
         row["parser_config"]["raptor"]["threshold"]
     )
     original_length = len(chunks)
-    raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
+    chunks = raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
     doc = {
         "doc_id": row["doc_id"],
         "kb_id": [str(row["kb_id"])],
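
### Why the call site must rebind `chunks`

`summarize` declares `nonlocal chunks` and rebinds that name inside `__call__`'s scope, so the summary chunks it produced were never visible to `run_raptor`, which passed `chunks` in and discarded the return value. The patch therefore ends `__call__` with `return chunks` and rebinds at the call site. A minimal sketch of the pitfall (standalone toy code, not RAGFlow's; `grow` and `append_summary` are illustrative names):

```python
def grow(chunks):
    # Stand-in for __call__: a nested helper rebinds `chunks` under
    # `nonlocal`, which binds a new object only in grow()'s scope.
    def append_summary():
        nonlocal chunks
        chunks = chunks + [("summary of A+B", [0.0])]  # rebind, not in-place mutation

    append_summary()
    return chunks  # without this return, the appended summaries are lost


chunks = [("chunk A", [0.1]), ("chunk B", [0.2])]
grow(chunks)           # pre-fix call pattern: result discarded
assert len(chunks) == 2
chunks = grow(chunks)  # post-fix pattern: chunks = raptor(chunks, seed, callback)
assert len(chunks) == 3
```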
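One caveat visible in the hunks above: `__call__` still exits early via `if len(chunks) <= 1: return`, which yields `None`, so the rebinding call site assumes at least two chunks reach `raptor`.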