mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-07-24 22:24:27 +08:00
add raptor (#899)
### What problem does this PR solve? #882 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
3bbdf3b770
commit
6f99bbbb08
@ -60,7 +60,8 @@ def status():
|
||||
st = timer()
|
||||
try:
|
||||
qinfo = REDIS_CONN.health(SVR_QUEUE_NAME)
|
||||
res["redis"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.), "pending": qinfo["pending"]}
|
||||
res["redis"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.),
|
||||
"pending": qinfo.get("pending", 0)}
|
||||
except Exception as e:
|
||||
res["redis"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
|
||||
|
||||
|
@ -18,8 +18,10 @@ from datetime import datetime
|
||||
from elasticsearch_dsl import Q
|
||||
from peewee import fn
|
||||
|
||||
from api.db.db_utils import bulk_insert_into_db
|
||||
from api.settings import stat_logger
|
||||
from api.utils import current_timestamp, get_format_time
|
||||
from api.utils import current_timestamp, get_format_time, get_uuid
|
||||
from rag.settings import SVR_QUEUE_NAME
|
||||
from rag.utils.es_conn import ELASTICSEARCH
|
||||
from rag.utils.minio_conn import MINIO
|
||||
from rag.nlp import search
|
||||
@ -30,6 +32,7 @@ from api.db.db_models import Document
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db import StatusEnum
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
class DocumentService(CommonService):
|
||||
@ -110,7 +113,7 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_unfinished_docs(cls):
|
||||
fields = [cls.model.id, cls.model.process_begin_at]
|
||||
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg]
|
||||
docs = cls.model.select(*fields) \
|
||||
.where(
|
||||
cls.model.status == StatusEnum.VALID.value,
|
||||
@ -260,7 +263,12 @@ class DocumentService(CommonService):
|
||||
prg = -1
|
||||
status = TaskStatus.FAIL.value
|
||||
elif finished:
|
||||
status = TaskStatus.DONE.value
|
||||
if d["parser_config"].get("raptor") and d["progress_msg"].lower().find(" raptor")<0:
|
||||
queue_raptor_tasks(d)
|
||||
prg *= 0.98
|
||||
msg.append("------ RAPTOR -------")
|
||||
else:
|
||||
status = TaskStatus.DONE.value
|
||||
|
||||
msg = "\n".join(msg)
|
||||
info = {
|
||||
@ -282,3 +290,19 @@ class DocumentService(CommonService):
|
||||
return len(cls.model.select(cls.model.id).where(
|
||||
cls.model.kb_id == kb_id).dicts())
|
||||
|
||||
|
||||
def queue_raptor_tasks(doc):
|
||||
def new_task():
|
||||
nonlocal doc
|
||||
return {
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc["id"],
|
||||
"from_page": 0,
|
||||
"to_page": -1,
|
||||
"progress_msg": "Start to do RAPTOR (Recursive Abstractive Processing For Tree-Organized Retrieval)."
|
||||
}
|
||||
|
||||
task = new_task()
|
||||
bulk_insert_into_db(Task, [task], True)
|
||||
task["type"] = "raptor"
|
||||
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
|
@ -155,6 +155,10 @@ class LLMBundle(object):
|
||||
tenant_id, llm_type, llm_name, lang=lang)
|
||||
assert self.mdl, "Can't find mole for {}/{}/{}".format(
|
||||
tenant_id, llm_type, llm_name)
|
||||
self.max_length = 512
|
||||
for lm in LLMService.query(llm_name=llm_name):
|
||||
self.max_length = lm.max_tokens
|
||||
break
|
||||
|
||||
def encode(self, texts: list, batch_size=32):
|
||||
emd, used_tokens = self.mdl.encode(texts, batch_size)
|
||||
|
@ -53,6 +53,7 @@ class TaskService(CommonService):
|
||||
Knowledgebase.embd_id,
|
||||
Tenant.img2txt_id,
|
||||
Tenant.asr_id,
|
||||
Tenant.llm_id,
|
||||
cls.model.update_time]
|
||||
docs = cls.model.select(*fields) \
|
||||
.join(Document, on=(cls.model.doc_id == Document.id)) \
|
||||
@ -159,4 +160,4 @@ def queue_tasks(doc, bucket, name):
|
||||
DocumentService.begin2parse(doc["id"])
|
||||
|
||||
for t in tsks:
|
||||
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=t), "Can't access Redis. Please check the Redis' status."
|
||||
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=t), "Can't access Redis. Please check the Redis' status."
|
||||
|
@ -57,8 +57,7 @@ class Base(ABC):
|
||||
stream=True,
|
||||
**gen_conf)
|
||||
for resp in response:
|
||||
if len(resp.choices) == 0:continue
|
||||
if not resp.choices[0].delta.content:continue
|
||||
if not resp.choices or not resp.choices[0].delta.content:continue
|
||||
ans += resp.choices[0].delta.content
|
||||
total_tokens += 1
|
||||
if resp.choices[0].finish_reason == "length":
|
||||
@ -379,7 +378,7 @@ class VolcEngineChat(Base):
|
||||
ans += resp.choices[0].message.content
|
||||
yield ans
|
||||
if resp.choices[0].finish_reason == "stop":
|
||||
return resp.usage.total_tokens
|
||||
yield resp.usage.total_tokens
|
||||
|
||||
except Exception as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
|
114
rag/raptor.py
Normal file
114
rag/raptor.py
Normal file
@ -0,0 +1,114 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import re
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor, ALL_COMPLETED, wait
|
||||
from threading import Lock
|
||||
from typing import Tuple
|
||||
import umap
|
||||
import numpy as np
|
||||
from sklearn.mixture import GaussianMixture
|
||||
|
||||
from rag.utils import num_tokens_from_string, truncate
|
||||
|
||||
|
||||
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
def __init__(self, max_cluster, llm_model, embd_model, prompt, max_token=256, threshold=0.1):
|
||||
self._max_cluster = max_cluster
|
||||
self._llm_model = llm_model
|
||||
self._embd_model = embd_model
|
||||
self._threshold = threshold
|
||||
self._prompt = prompt
|
||||
self._max_token = max_token
|
||||
|
||||
def _get_optimal_clusters(self, embeddings: np.ndarray, random_state:int):
|
||||
max_clusters = min(self._max_cluster, len(embeddings))
|
||||
n_clusters = np.arange(1, max_clusters)
|
||||
bics = []
|
||||
for n in n_clusters:
|
||||
gm = GaussianMixture(n_components=n, random_state=random_state)
|
||||
gm.fit(embeddings)
|
||||
bics.append(gm.bic(embeddings))
|
||||
optimal_clusters = n_clusters[np.argmin(bics)]
|
||||
return optimal_clusters
|
||||
|
||||
def __call__(self, chunks: Tuple[str, np.ndarray], random_state, callback=None):
|
||||
layers = [(0, len(chunks))]
|
||||
start, end = 0, len(chunks)
|
||||
if len(chunks) <= 1: return
|
||||
|
||||
def summarize(ck_idx, lock):
|
||||
nonlocal chunks
|
||||
try:
|
||||
texts = [chunks[i][0] for i in ck_idx]
|
||||
len_per_chunk = int((self._llm_model.max_length - self._max_token)/len(texts))
|
||||
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
|
||||
cnt = self._llm_model.chat("You're a helpful assistant.",
|
||||
[{"role": "user", "content": self._prompt.format(cluster_content=cluster_content)}],
|
||||
{"temperature": 0.3, "max_tokens": self._max_token}
|
||||
)
|
||||
cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "", cnt)
|
||||
print("SUM:", cnt)
|
||||
embds, _ = self._embd_model.encode([cnt])
|
||||
with lock:
|
||||
chunks.append((cnt, embds[0]))
|
||||
except Exception as e:
|
||||
print(e, flush=True)
|
||||
traceback.print_stack(e)
|
||||
return e
|
||||
|
||||
labels = []
|
||||
while end - start > 1:
|
||||
embeddings = [embd for _, embd in chunks[start: end]]
|
||||
if len(embeddings) == 2:
|
||||
summarize([start, start+1], Lock())
|
||||
if callback:
|
||||
callback(msg="Cluster one layer: {} -> {}".format(end-start, len(chunks)-end))
|
||||
labels.extend([0,0])
|
||||
layers.append((end, len(chunks)))
|
||||
start = end
|
||||
end = len(chunks)
|
||||
continue
|
||||
|
||||
n_neighbors = int((len(embeddings) - 1) ** 0.8)
|
||||
reduced_embeddings = umap.UMAP(
|
||||
n_neighbors=max(2, n_neighbors), n_components=min(12, len(embeddings)-2), metric="cosine"
|
||||
).fit_transform(embeddings)
|
||||
n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state)
|
||||
if n_clusters == 1:
|
||||
lbls = [0 for _ in range(len(reduced_embeddings))]
|
||||
else:
|
||||
gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
|
||||
gm.fit(reduced_embeddings)
|
||||
probs = gm.predict_proba(reduced_embeddings)
|
||||
lbls = [np.where(prob > self._threshold)[0] for prob in probs]
|
||||
lock = Lock()
|
||||
with ThreadPoolExecutor(max_workers=12) as executor:
|
||||
threads = []
|
||||
for c in range(n_clusters):
|
||||
ck_idx = [i+start for i in range(len(lbls)) if lbls[i] == c]
|
||||
threads.append(executor.submit(summarize, ck_idx, lock))
|
||||
wait(threads, return_when=ALL_COMPLETED)
|
||||
print([t.result() for t in threads])
|
||||
|
||||
assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
|
||||
labels.extend(lbls)
|
||||
layers.append((end, len(chunks)))
|
||||
if callback:
|
||||
callback(msg="Cluster one layer: {} -> {}".format(end-start, len(chunks)-end))
|
||||
start = end
|
||||
end = len(chunks)
|
||||
|
@ -26,20 +26,22 @@ import traceback
|
||||
from functools import partial
|
||||
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.settings import retrievaler
|
||||
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
|
||||
from rag.utils.minio_conn import MINIO
|
||||
from api.db.db_models import close_connection
|
||||
from rag.settings import database_logger, SVR_QUEUE_NAME
|
||||
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
|
||||
from multiprocessing import Pool
|
||||
import numpy as np
|
||||
from elasticsearch_dsl import Q
|
||||
from elasticsearch_dsl import Q, Search
|
||||
from multiprocessing.context import TimeoutError
|
||||
from api.db.services.task_service import TaskService
|
||||
from rag.utils.es_conn import ELASTICSEARCH
|
||||
from timeit import default_timer as timer
|
||||
from rag.utils import rmSpace, findMaxTm
|
||||
from rag.utils import rmSpace, findMaxTm, num_tokens_from_string
|
||||
|
||||
from rag.nlp import search
|
||||
from rag.nlp import search, rag_tokenizer
|
||||
from io import BytesIO
|
||||
import pandas as pd
|
||||
|
||||
@ -114,6 +116,8 @@ def collect():
|
||||
tasks = TaskService.get_tasks(msg["id"])
|
||||
assert tasks, "{} empty task!".format(msg["id"])
|
||||
tasks = pd.DataFrame(tasks)
|
||||
if msg.get("type", "") == "raptor":
|
||||
tasks["task_type"] = "raptor"
|
||||
return tasks
|
||||
|
||||
|
||||
@ -245,6 +249,47 @@ def embedding(docs, mdl, parser_config={}, callback=None):
|
||||
return tk_count
|
||||
|
||||
|
||||
def run_raptor(row, chat_mdl, embd_mdl, callback=None):
|
||||
vts, _ = embd_mdl.encode(["ok"])
|
||||
vctr_nm = "q_%d_vec"%len(vts[0])
|
||||
chunks = []
|
||||
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], fields=["content_with_weight", vctr_nm]):
|
||||
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
|
||||
|
||||
raptor = Raptor(
|
||||
row["parser_config"]["raptor"].get("max_cluster", 64),
|
||||
chat_mdl,
|
||||
embd_mdl,
|
||||
row["parser_config"]["raptor"]["prompt"],
|
||||
row["parser_config"]["raptor"]["max_token"],
|
||||
row["parser_config"]["raptor"]["threshold"]
|
||||
)
|
||||
original_length = len(chunks)
|
||||
raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
|
||||
doc = {
|
||||
"doc_id": row["doc_id"],
|
||||
"kb_id": [str(row["kb_id"])],
|
||||
"docnm_kwd": row["name"],
|
||||
"title_tks": rag_tokenizer.tokenize(row["name"])
|
||||
}
|
||||
res = []
|
||||
tk_count = 0
|
||||
for content, vctr in chunks[original_length:]:
|
||||
d = copy.deepcopy(doc)
|
||||
md5 = hashlib.md5()
|
||||
md5.update((content + str(d["doc_id"])).encode("utf-8"))
|
||||
d["_id"] = md5.hexdigest()
|
||||
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
||||
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
||||
d[vctr_nm] = vctr.tolist()
|
||||
d["content_with_weight"] = content
|
||||
d["content_ltks"] = rag_tokenizer.tokenize(content)
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
res.append(d)
|
||||
tk_count += num_tokens_from_string(content)
|
||||
return res, tk_count
|
||||
|
||||
|
||||
def main():
|
||||
rows = collect()
|
||||
if len(rows) == 0:
|
||||
@ -259,35 +304,45 @@ def main():
|
||||
cron_logger.error(str(e))
|
||||
continue
|
||||
|
||||
st = timer()
|
||||
cks = build(r)
|
||||
cron_logger.info("Build chunks({}): {}".format(r["name"], timer() - st))
|
||||
if cks is None:
|
||||
continue
|
||||
if not cks:
|
||||
callback(1., "No chunk! Done!")
|
||||
continue
|
||||
# TODO: exception handler
|
||||
## set_progress(r["did"], -1, "ERROR: ")
|
||||
callback(
|
||||
msg="Finished slicing files(%d). Start to embedding the content." %
|
||||
len(cks))
|
||||
st = timer()
|
||||
try:
|
||||
tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
|
||||
except Exception as e:
|
||||
callback(-1, "Embedding error:{}".format(str(e)))
|
||||
cron_logger.error(str(e))
|
||||
tk_count = 0
|
||||
cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
||||
if r.get("task_type", "") == "raptor":
|
||||
try:
|
||||
chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
|
||||
cks, tk_count = run_raptor(r, chat_mdl, embd_mdl, callback)
|
||||
except Exception as e:
|
||||
callback(-1, msg=str(e))
|
||||
cron_logger.error(str(e))
|
||||
continue
|
||||
else:
|
||||
st = timer()
|
||||
cks = build(r)
|
||||
cron_logger.info("Build chunks({}): {}".format(r["name"], timer() - st))
|
||||
if cks is None:
|
||||
continue
|
||||
if not cks:
|
||||
callback(1., "No chunk! Done!")
|
||||
continue
|
||||
# TODO: exception handler
|
||||
## set_progress(r["did"], -1, "ERROR: ")
|
||||
callback(
|
||||
msg="Finished slicing files(%d). Start to embedding the content." %
|
||||
len(cks))
|
||||
st = timer()
|
||||
try:
|
||||
tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
|
||||
except Exception as e:
|
||||
callback(-1, "Embedding error:{}".format(str(e)))
|
||||
cron_logger.error(str(e))
|
||||
tk_count = 0
|
||||
cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
||||
callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
|
||||
|
||||
callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
|
||||
init_kb(r)
|
||||
chunk_count = len(set([c["_id"] for c in cks]))
|
||||
st = timer()
|
||||
es_r = ""
|
||||
for b in range(0, len(cks), 32):
|
||||
es_r = ELASTICSEARCH.bulk(cks[b:b + 32], search.index_name(r["tenant_id"]))
|
||||
es_bulk_size = 16
|
||||
for b in range(0, len(cks), es_bulk_size):
|
||||
es_r = ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]))
|
||||
if b % 128 == 0:
|
||||
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
|
||||
|
||||
|
@ -97,15 +97,17 @@ class RedisDB:
|
||||
return False
|
||||
|
||||
def queue_product(self, queue, message, exp=settings.SVR_QUEUE_RETENTION) -> bool:
|
||||
try:
|
||||
payload = {"message": json.dumps(message)}
|
||||
pipeline = self.REDIS.pipeline()
|
||||
pipeline.xadd(queue, payload)
|
||||
pipeline.expire(queue, exp)
|
||||
pipeline.execute()
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.warning("[EXCEPTION]producer" + str(queue) + "||" + str(e))
|
||||
for _ in range(3):
|
||||
try:
|
||||
payload = {"message": json.dumps(message)}
|
||||
pipeline = self.REDIS.pipeline()
|
||||
pipeline.xadd(queue, payload)
|
||||
pipeline.expire(queue, exp)
|
||||
pipeline.execute()
|
||||
return True
|
||||
except Exception as e:
|
||||
print(e)
|
||||
logging.warning("[EXCEPTION]producer" + str(queue) + "||" + str(e))
|
||||
return False
|
||||
|
||||
def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> Payload:
|
||||
|
Loading…
x
Reference in New Issue
Block a user