Fix raptor issue (#3737)

### What problem does this PR solve?

Fixes #3732.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Kevin Hu 2024-11-29 11:55:41 +08:00 committed by GitHub
parent a0c0a957b4
commit 27cd765d6f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 14 deletions

View File

@ -44,7 +44,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
optimal_clusters = n_clusters[np.argmin(bics)] optimal_clusters = n_clusters[np.argmin(bics)]
return optimal_clusters return optimal_clusters
def __call__(self, chunks: tuple[str, np.ndarray], random_state, callback=None): def __call__(self, chunks, random_state, callback=None):
layers = [(0, len(chunks))] layers = [(0, len(chunks))]
start, end = 0, len(chunks) start, end = 0, len(chunks)
if len(chunks) <= 1: return if len(chunks) <= 1: return
@ -57,10 +57,12 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts)) len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts]) cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
cnt = self._llm_model.chat("You're a helpful assistant.", cnt = self._llm_model.chat("You're a helpful assistant.",
[{"role": "user", "content": self._prompt.format(cluster_content=cluster_content)}], [{"role": "user",
"content": self._prompt.format(cluster_content=cluster_content)}],
{"temperature": 0.3, "max_tokens": self._max_token} {"temperature": 0.3, "max_tokens": self._max_token}
) )
cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "", cnt) cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "",
cnt)
logging.debug(f"SUM: {cnt}") logging.debug(f"SUM: {cnt}")
embds, _ = self._embd_model.encode([cnt]) embds, _ = self._embd_model.encode([cnt])
with lock: with lock:
@ -113,3 +115,5 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
start = end start = end
end = len(chunks) end = len(chunks)
return chunks

View File

@ -344,7 +344,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
row["parser_config"]["raptor"]["threshold"] row["parser_config"]["raptor"]["threshold"]
) )
original_length = len(chunks) original_length = len(chunks)
raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback) chunks = raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
doc = { doc = {
"doc_id": row["doc_id"], "doc_id": row["doc_id"],
"kb_id": [str(row["kb_id"])], "kb_id": [str(row["kb_id"])],