From 6f99bbbb08e91d867e85ce3876c6accd271986c4 Mon Sep 17 00:00:00 2001 From: KevinHuSh Date: Thu, 23 May 2024 14:31:16 +0800 Subject: [PATCH] add raptor (#899) ### What problem does this PR solve? #882 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/apps/system_app.py | 3 +- api/db/services/document_service.py | 30 +++++++- api/db/services/llm_service.py | 4 + api/db/services/task_service.py | 3 +- rag/llm/chat_model.py | 5 +- rag/raptor.py | 114 ++++++++++++++++++++++++++++ rag/svr/task_executor.py | 109 +++++++++++++++++++------- rag/utils/redis_conn.py | 20 ++--- 8 files changed, 244 insertions(+), 44 deletions(-) create mode 100644 rag/raptor.py diff --git a/api/apps/system_app.py b/api/apps/system_app.py index 933d1a744..be276aca1 100644 --- a/api/apps/system_app.py +++ b/api/apps/system_app.py @@ -60,7 +60,8 @@ def status(): st = timer() try: qinfo = REDIS_CONN.health(SVR_QUEUE_NAME) - res["redis"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.), "pending": qinfo["pending"]} + res["redis"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.), + "pending": qinfo.get("pending", 0)} except Exception as e: res["redis"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)} diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index ac569563f..672a24e47 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -18,8 +18,10 @@ from datetime import datetime from elasticsearch_dsl import Q from peewee import fn +from api.db.db_utils import bulk_insert_into_db from api.settings import stat_logger -from api.utils import current_timestamp, get_format_time +from api.utils import current_timestamp, get_format_time, get_uuid +from rag.settings import SVR_QUEUE_NAME from rag.utils.es_conn import ELASTICSEARCH from rag.utils.minio_conn import MINIO from rag.nlp import search @@ -30,6 +32,7 @@ from api.db.db_models import Document from api.db.services.common_service import CommonService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db import StatusEnum +from rag.utils.redis_conn import REDIS_CONN class DocumentService(CommonService): @@ -110,7 +113,7 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() def get_unfinished_docs(cls): - fields = [cls.model.id, cls.model.process_begin_at] + fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg] docs = cls.model.select(*fields) \ .where( cls.model.status == StatusEnum.VALID.value, @@ -260,7 +263,12 @@ class DocumentService(CommonService): prg = -1 status = TaskStatus.FAIL.value elif finished: - status = TaskStatus.DONE.value + if d["parser_config"].get("raptor") and d["progress_msg"].lower().find(" raptor")<0: + queue_raptor_tasks(d) + prg *= 0.98 + msg.append("------ RAPTOR -------") + else: + status = TaskStatus.DONE.value msg = "\n".join(msg) info = { @@ -282,3 +290,19 @@ class DocumentService(CommonService): return len(cls.model.select(cls.model.id).where( cls.model.kb_id == kb_id).dicts()) + +def queue_raptor_tasks(doc): + def new_task(): + nonlocal doc + return { + "id": get_uuid(), + "doc_id": doc["id"], + "from_page": 0, + "to_page": -1, + "progress_msg": "Start to do RAPTOR (Recursive Abstractive Processing For Tree-Organized Retrieval)." 
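+            # NOTE: from_page 0 with to_page -1 appears to cover the whole
+            # document; the "type": "raptor" flag is attached to the queued
+            # message only (below), not persisted on this Task row.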
+        }
+
+    task = new_task()
+    bulk_insert_into_db(Task, [task], True)
+    task["type"] = "raptor"
+    assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
\ No newline at end of file
diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index 4776544fc..0dd3b8374 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -155,6 +155,10 @@ class LLMBundle(object):
             tenant_id, llm_type, llm_name, lang=lang)
         assert self.mdl, "Can't find model for {}/{}/{}".format(
             tenant_id, llm_type, llm_name)
+        self.max_length = 512
+        for lm in LLMService.query(llm_name=llm_name):
+            self.max_length = lm.max_tokens
+            break

     def encode(self, texts: list, batch_size=32):
         emd, used_tokens = self.mdl.encode(texts, batch_size)
diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py
index 6dcad8265..0bdfd08b2 100644
--- a/api/db/services/task_service.py
+++ b/api/db/services/task_service.py
@@ -53,6 +53,7 @@ class TaskService(CommonService):
             Knowledgebase.embd_id,
             Tenant.img2txt_id,
             Tenant.asr_id,
+            Tenant.llm_id,
             cls.model.update_time]
         docs = cls.model.select(*fields) \
             .join(Document, on=(cls.model.doc_id == Document.id)) \
@@ -159,4 +160,4 @@ def queue_tasks(doc, bucket, name):
     DocumentService.begin2parse(doc["id"])

     for t in tsks:
-        assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=t), "Can't access Redis. Please check the Redis' status."
\ No newline at end of file
+        assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=t), "Can't access Redis. Please check the Redis' status."
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index e9eb470c2..9bf4e96af 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -57,8 +57,7 @@ class Base(ABC):
                 stream=True,
                 **gen_conf)
             for resp in response:
-                if len(resp.choices) == 0:continue
-                if not resp.choices[0].delta.content:continue
+                if not resp.choices or not resp.choices[0].delta.content:continue
                 ans += resp.choices[0].delta.content
                 total_tokens += 1
                 if resp.choices[0].finish_reason == "length":
@@ -379,7 +378,7 @@ class VolcEngineChat(Base):
                     ans += resp.choices[0].message.content
                     yield ans
                     if resp.choices[0].finish_reason == "stop":
-                        return resp.usage.total_tokens
+                        yield resp.usage.total_tokens
                 except Exception as e:
                     yield ans + "\n**ERROR**: " + str(e)
diff --git a/rag/raptor.py b/rag/raptor.py
new file mode 100644
index 000000000..cf6c850c3
--- /dev/null
+++ b/rag/raptor.py
@@ -0,0 +1,114 @@
+#
+# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
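+#
+# RAPTOR in brief: chunk embeddings are reduced with UMAP, soft-clustered
+# with a Gaussian Mixture Model (cluster count chosen by BIC), and each
+# cluster is summarized by the chat model; summaries are appended as new
+# chunks and re-clustered, layer by layer, until a layer holds one chunk.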
+#
+import re
+import traceback
+from concurrent.futures import ThreadPoolExecutor, ALL_COMPLETED, wait
+from threading import Lock
+from typing import List, Tuple
+import umap
+import numpy as np
+from sklearn.mixture import GaussianMixture
+
+from rag.utils import num_tokens_from_string, truncate
+
+
+class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
+    def __init__(self, max_cluster, llm_model, embd_model, prompt, max_token=256, threshold=0.1):
+        self._max_cluster = max_cluster
+        self._llm_model = llm_model
+        self._embd_model = embd_model
+        self._threshold = threshold
+        self._prompt = prompt
+        self._max_token = max_token
+
+    def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int):
+        max_clusters = min(self._max_cluster, len(embeddings))
+        n_clusters = np.arange(1, max_clusters)
+        bics = []
+        for n in n_clusters:
+            gm = GaussianMixture(n_components=n, random_state=random_state)
+            gm.fit(embeddings)
+            bics.append(gm.bic(embeddings))
+        optimal_clusters = n_clusters[np.argmin(bics)]
+        return optimal_clusters
+
+    def __call__(self, chunks: List[Tuple[str, np.ndarray]], random_state, callback=None):
+        layers = [(0, len(chunks))]
+        start, end = 0, len(chunks)
+        if len(chunks) <= 1: return
+
+        def summarize(ck_idx, lock):
+            nonlocal chunks
+            try:
+                texts = [chunks[i][0] for i in ck_idx]
+                len_per_chunk = int((self._llm_model.max_length - self._max_token)/len(texts))
+                cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
+                cnt = self._llm_model.chat("You're a helpful assistant.",
+                                           [{"role": "user", "content": self._prompt.format(cluster_content=cluster_content)}],
+                                           {"temperature": 0.3, "max_tokens": self._max_token}
+                                           )
+                # strip the provider's truncation notices (Chinese or English)
+                cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "", cnt)
+                print("SUM:", cnt)
+                embds, _ = self._embd_model.encode([cnt])
+                with lock:
+                    chunks.append((cnt, embds[0]))
+            except Exception as e:
+                print(e, flush=True)
+                traceback.print_exc()
+                return e
+
+        labels = []
+        while end - start > 1:
+            embeddings = [embd for _, embd in chunks[start: end]]
+            if len(embeddings) == 2:
+                summarize([start, start+1], Lock())
+                if callback:
+                    callback(msg="Cluster one layer: {} -> {}".format(end-start, len(chunks)-end))
+                labels.extend([0, 0])
+                layers.append((end, len(chunks)))
+                start = end
+                end = len(chunks)
+                continue
+
+            n_neighbors = int((len(embeddings) - 1) ** 0.8)
+            reduced_embeddings = umap.UMAP(
+                n_neighbors=max(2, n_neighbors), n_components=min(12, len(embeddings)-2), metric="cosine"
+            ).fit_transform(embeddings)
+            n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state)
+            if n_clusters == 1:
+                lbls = [np.array([0]) for _ in range(len(reduced_embeddings))]
+            else:
+                gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
+                gm.fit(reduced_embeddings)
+                probs = gm.predict_proba(reduced_embeddings)
+                lbls = [np.where(prob > self._threshold)[0] for prob in probs]
+            lock = Lock()
+            with ThreadPoolExecutor(max_workers=12) as executor:
+                threads = []
+                for c in range(n_clusters):
+                    # soft clustering: a chunk may carry several labels, so
+                    # membership is tested with `in` rather than equality
+                    ck_idx = [i+start for i in range(len(lbls)) if c in lbls[i]]
+                    threads.append(executor.submit(summarize, ck_idx, lock))
+                wait(threads, return_when=ALL_COMPLETED)
+                print([t.result() for t in threads])
+
+            assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
+            labels.extend(lbls)
+            layers.append((end, len(chunks)))
+            if callback:
+                callback(msg="Cluster one layer: {} -> {}".format(end-start, len(chunks)-end))
+            start = end
+            end = len(chunks)
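
A minimal driver sketch for the class above, with stand-in models (`FakeChat` and `FakeEmbd` are hypothetical; the real pipeline passes `LLMBundle` instances, as `run_raptor` in the next file shows):

```python
import numpy as np
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor

class FakeChat:
    max_length = 8192  # Raptor budgets per-chunk truncation from this

    def chat(self, system, history, gen_conf):
        # Stand-in summary; the real model receives the formatted prompt.
        return "Summary: " + history[0]["content"][:64]

class FakeEmbd:
    def encode(self, texts):
        # Toy vectors keyed on the text (stable within a run).
        vecs = [np.random.RandomState(abs(hash(t)) % (2 ** 32)).rand(8) for t in texts]
        return np.array(vecs), sum(len(t) for t in texts)

embd = FakeEmbd()
texts = ["paragraph %d about topic %d" % (i, i % 3) for i in range(8)]
vecs, _ = embd.encode(texts)
chunks = [(t, v) for t, v in zip(texts, vecs)]

raptor = Raptor(max_cluster=64, llm_model=FakeChat(), embd_model=embd,
                prompt="Summarize the following:\n{cluster_content}", max_token=256)
raptor(chunks, random_state=0, callback=lambda msg: print(msg))
# `chunks` now holds the 8 originals plus the appended cluster summaries.
```
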
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 413805a8f..a405e5ca2 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -26,20 +26,22 @@ import traceback
 from functools import partial
 from api.db.services.file2document_service import File2DocumentService
+from api.settings import retrievaler
+from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
 from rag.utils.minio_conn import MINIO
 from api.db.db_models import close_connection
 from rag.settings import database_logger, SVR_QUEUE_NAME
 from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
 from multiprocessing import Pool
 import numpy as np
-from elasticsearch_dsl import Q
+from elasticsearch_dsl import Q, Search
 from multiprocessing.context import TimeoutError
 from api.db.services.task_service import TaskService
 from rag.utils.es_conn import ELASTICSEARCH
 from timeit import default_timer as timer
-from rag.utils import rmSpace, findMaxTm
+from rag.utils import rmSpace, findMaxTm, num_tokens_from_string

-from rag.nlp import search
+from rag.nlp import search, rag_tokenizer
 from io import BytesIO
 import pandas as pd
@@ -114,6 +116,8 @@ def collect():
     tasks = TaskService.get_tasks(msg["id"])
     assert tasks, "{} empty task!".format(msg["id"])
     tasks = pd.DataFrame(tasks)
+    if msg.get("type", "") == "raptor":
+        tasks["task_type"] = "raptor"
     return tasks
@@ -245,6 +249,47 @@ def embedding(docs, mdl, parser_config={}, callback=None):
     return tk_count


+def run_raptor(row, chat_mdl, embd_mdl, callback=None):
+    vts, _ = embd_mdl.encode(["ok"])
+    vctr_nm = "q_%d_vec" % len(vts[0])
+    chunks = []
+    for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], fields=["content_with_weight", vctr_nm]):
+        chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
+
+    raptor = Raptor(
+        row["parser_config"]["raptor"].get("max_cluster", 64),
+        chat_mdl,
+        embd_mdl,
+        row["parser_config"]["raptor"]["prompt"],
+        row["parser_config"]["raptor"]["max_token"],
+        row["parser_config"]["raptor"]["threshold"]
+    )
+    original_length = len(chunks)
+    raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
+    doc = {
+        "doc_id": row["doc_id"],
+        "kb_id": [str(row["kb_id"])],
+        "docnm_kwd": row["name"],
+        "title_tks": rag_tokenizer.tokenize(row["name"])
+    }
+    res = []
+    tk_count = 0
+    for content, vctr in chunks[original_length:]:
+        d = copy.deepcopy(doc)
+        md5 = hashlib.md5()
+        md5.update((content + str(d["doc_id"])).encode("utf-8"))
+        d["_id"] = md5.hexdigest()
+        d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
+        d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
+        d[vctr_nm] = vctr.tolist()
+        d["content_with_weight"] = content
+        d["content_ltks"] = rag_tokenizer.tokenize(content)
+        d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
+        res.append(d)
+        tk_count += num_tokens_from_string(content)
+    return res, tk_count
+
+
 def main():
     rows = collect()
     if len(rows) == 0:
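The summary chunks emitted by `run_raptor` get their ES ids from a plain MD5 over content plus doc id, and land in a vector field named after the embedding width (probed above by encoding `"ok"`). A small sketch with hypothetical values:

```python
import hashlib
import numpy as np

doc_id = "doc-123"                    # hypothetical id
content = "cluster summary text"      # hypothetical summary
vctr = np.random.rand(1024)           # width depends on the embedding model

d = {
    "_id": hashlib.md5((content + doc_id).encode("utf-8")).hexdigest(),
    "q_%d_vec" % len(vctr): vctr.tolist(),  # -> "q_1024_vec"
    "content_with_weight": content,
}
```
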
@@ -259,35 +304,45 @@
         cron_logger.error(str(e))
         continue

-        st = timer()
-        cks = build(r)
-        cron_logger.info("Build chunks({}): {}".format(r["name"], timer() - st))
-        if cks is None:
-            continue
-        if not cks:
-            callback(1., "No chunk! Done!")
-            continue
-        # TODO: exception handler
-        ## set_progress(r["did"], -1, "ERROR: ")
-        callback(
-            msg="Finished slicing files(%d). Start to embedding the content." %
-            len(cks))
-        st = timer()
-        try:
-            tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
-        except Exception as e:
-            callback(-1, "Embedding error:{}".format(str(e)))
-            cron_logger.error(str(e))
-            tk_count = 0
-        cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
+        if r.get("task_type", "") == "raptor":
+            try:
+                chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
+                cks, tk_count = run_raptor(r, chat_mdl, embd_mdl, callback)
+            except Exception as e:
+                callback(-1, msg=str(e))
+                cron_logger.error(str(e))
+                continue
+        else:
+            st = timer()
+            cks = build(r)
+            cron_logger.info("Build chunks({}): {}".format(r["name"], timer() - st))
+            if cks is None:
+                continue
+            if not cks:
+                callback(1., "No chunk! Done!")
+                continue
+            # TODO: exception handler
+            ## set_progress(r["did"], -1, "ERROR: ")
+            callback(
+                msg="Finished slicing files(%d). Start to embedding the content." %
+                len(cks))
+            st = timer()
+            try:
+                tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
+            except Exception as e:
+                callback(-1, "Embedding error:{}".format(str(e)))
+                cron_logger.error(str(e))
+                tk_count = 0
+            cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
+            callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))

-        callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
         init_kb(r)
         chunk_count = len(set([c["_id"] for c in cks]))
         st = timer()
         es_r = ""
-        for b in range(0, len(cks), 32):
-            es_r = ELASTICSEARCH.bulk(cks[b:b + 32], search.index_name(r["tenant_id"]))
+        es_bulk_size = 16
+        for b in range(0, len(cks), es_bulk_size):
+            es_r = ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]))
             if b % 128 == 0:
                 callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py
index 1fa0604b8..ef1b73eaa 100644
--- a/rag/utils/redis_conn.py
+++ b/rag/utils/redis_conn.py
@@ -97,15 +97,17 @@ class RedisDB:
             return False

     def queue_product(self, queue, message, exp=settings.SVR_QUEUE_RETENTION) -> bool:
-        try:
-            payload = {"message": json.dumps(message)}
-            pipeline = self.REDIS.pipeline()
-            pipeline.xadd(queue, payload)
-            pipeline.expire(queue, exp)
-            pipeline.execute()
-            return True
-        except Exception as e:
-            logging.warning("[EXCEPTION]producer" + str(queue) + "||" + str(e))
+        for _ in range(3):
+            try:
+                payload = {"message": json.dumps(message)}
+                pipeline = self.REDIS.pipeline()
+                pipeline.xadd(queue, payload)
+                pipeline.expire(queue, exp)
+                pipeline.execute()
+                return True
+            except Exception as e:
+                print(e)
+                logging.warning("[EXCEPTION]producer" + str(queue) + "||" + str(e))
         return False

     def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> Payload:
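
Taken together, the producer/consumer pairing this patch leans on looks roughly like this (a sketch; the consumer group and consumer names are hypothetical, while `queue_product` and `queue_consumer` are as defined above):

```python
from rag.settings import SVR_QUEUE_NAME
from rag.utils.redis_conn import REDIS_CONN

# Producer: queue_product now retries up to three times before reporting failure.
ok = REDIS_CONN.queue_product(SVR_QUEUE_NAME, message={"id": "task-uuid", "type": "raptor"})
assert ok, "Can't access Redis. Please check the Redis' status."

# Consumer side (as in task_executor.collect): read one pending message.
payload = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_group", "consumer_0")
```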