fix raptor bugs (#928)

### What problem does this PR solve?

#922 
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
KevinHuSh 2024-05-27 11:01:20 +08:00 committed by GitHub
parent 55fb96131e
commit 46454362d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 82 additions and 3 deletions

View File

@ -488,3 +488,77 @@ def document_rm():
return get_json_result(data=False, retmsg=errors, retcode=RetCode.SERVER_ERROR)
return get_json_result(data=True)
@manager.route('/completion_aibotk', methods=['POST'])
@validate_request("Authorization", "conversation_id", "word")
def completion_faq():
    """Non-streaming chat completion endpoint for the aibotk integration.

    Expects a JSON body with:
      - "Authorization": an API token to authenticate the caller,
      - "conversation_id": id of an existing API conversation,
      - "word": the user's message text.

    Returns a dict of the form {"code": 200, "msg": "success", "data": [...]}
    where data[0] is {"type": 1, "content": <answer text>} and any following
    entries are {"type": 3, "url": <base64 image>} for images cited by the
    answer's ##N$$ reference markers.  On failure, returns the project's
    standard error payloads (auth error, not-found, or server error).
    """
    import base64

    req = request.json
    token = req["Authorization"]
    objs = APIToken.query(token=token)
    if not objs:
        # Fix: the original message contained a stray trailing double quote
        # ('Token is not valid!"').
        return get_json_result(
            data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR)

    e, conv = API4ConversationService.get_by_id(req["conversation_id"])
    if not e:
        return get_data_error_result(retmsg="Conversation not found!")
    # Default to quoting sources unless the caller opted out.
    req.setdefault("quote", True)

    msg = [{"role": "user", "content": req["word"]}]

    try:
        conv.message.append(msg[-1])
        e, dia = DialogService.get_by_id(conv.dialog_id)
        if not e:
            return get_data_error_result(retmsg="Dialog not found!")
        # chat(**req) must not receive the conversation id as a kwarg.
        del req["conversation_id"]

        if not conv.reference:
            conv.reference = []
        # Placeholder assistant turn; filled in once an answer arrives.
        conv.message.append({"role": "assistant", "content": ""})
        conv.reference.append({"chunks": [], "doc_aggs": []})

        def fillin_conv(ans):
            # Keep the stored conversation in sync with the latest answer.
            nonlocal conv
            if not conv.reference:
                conv.reference.append(ans["reference"])
            else:
                conv.reference[-1] = ans["reference"]
            conv.message[-1] = {"role": "assistant", "content": ans["answer"]}

        data = [{"type": 1, "content": ""}]

        for ans in chat(dia, msg, stream=False, **req):
            # Strip ##N$$ citation markers from the user-visible answer text.
            data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
            fillin_conv(ans)
            API4ConversationService.append_message(conv.id, conv.to_dict())

            # A marker looks like "##3$$"; index 2 is the chunk digit.
            chunk_idxs = [int(m[2]) for m in re.findall(r'##\d\$\$', ans["answer"])]
            # Attach at most one cited image to the response.
            for chunk_idx in chunk_idxs[:1]:
                if ans["reference"]["chunks"][chunk_idx]["img_id"]:
                    try:
                        bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
                        img_bytes = MINIO.get(bkt, nm)
                        # Fix: build a fresh dict per image instead of mutating a
                        # shared template dict — appending the same dict object
                        # repeatedly would make every entry alias the last URL.
                        data.append({
                            "type": 3,
                            "url": base64.b64encode(img_bytes).decode('utf-8'),
                        })
                    except Exception as e:
                        return server_error_response(e)
            # Non-streaming: only the first (complete) answer is used.
            break

        return {"code": 200, "msg": "success", "data": data}
    except Exception as e:
        return server_error_response(e)

View File

@ -229,6 +229,9 @@ def create():
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
DocumentService.increment_chunk_num(
doc.id, doc.kb_id, c, 1, 0)
return get_json_result(data={"chunk_id": chunck_id})
except Exception as e:
return server_error_response(e)

View File

@ -263,7 +263,7 @@ class DocumentService(CommonService):
prg = -1
status = TaskStatus.FAIL.value
elif finished:
if d["parser_config"].get("raptor") and d["progress_msg"].lower().find(" raptor")<0:
if d["parser_config"].get("raptor", {}).get("use_raptor") and d["progress_msg"].lower().find(" raptor")<0:
queue_raptor_tasks(d)
prg *= 0.98
msg.append("------ RAPTOR -------")

View File

@ -1,6 +1,5 @@
import copy
import re
import numpy as np
import cv2
from shapely.geometry import Polygon

View File

@ -359,7 +359,6 @@ class VolcEngineChat(Base):
if system:
history.insert(0, {"role": "system", "content": system})
ans = ""
tk_count = 0
try:
req = {
"parameters": {
@ -380,6 +379,7 @@ class VolcEngineChat(Base):
if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
yield tk_count

View File

@ -95,6 +95,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
gm.fit(reduced_embeddings)
probs = gm.predict_proba(reduced_embeddings)
lbls = [np.where(prob > self._threshold)[0] for prob in probs]
lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
lock = Lock()
with ThreadPoolExecutor(max_workers=12) as executor:
threads = []

View File

@ -134,4 +134,5 @@ yarl==1.9.4
zhipuai==2.0.1
BCEmbedding
loguru==0.7.2
umap-learn
fasttext==0.9.2

View File

@ -123,3 +123,4 @@ loguru==0.7.2
ollama==0.1.8
redis==5.0.4
fasttext==0.9.2
umap-learn