Fix bugs in API (#3103)

### What problem does this PR solve?

Fix bugs in API


- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
This commit is contained in:
liuhua 2024-10-30 16:15:42 +08:00 committed by GitHub
parent 86b546f657
commit 18dfa2900c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 124 additions and 57 deletions

View File

@ -14,7 +14,7 @@
# limitations under the License.
#
from flask import request
from api.settings import RetCode
from api.db import StatusEnum
from api.db.services.dialog_service import DialogService
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -40,6 +40,10 @@ def create(tenant_id):
kb=kbs[0]
if kb.chunk_num == 0:
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
kbs = KnowledgebaseService.get_by_ids(ids)
embd_count = list(set(kb.embd_id for kb in kbs))
if embd_count != 1:
return get_result(retmsg='Datasets use different embedding models."',retcode=RetCode.AUTHENTICATION_ERROR)
req["kb_ids"] = ids
# llm
llm = req.get("llm")
@ -149,6 +153,8 @@ def update(tenant_id,chat_id):
return get_error_data_result(retmsg='You do not own the chat')
req =request.json
ids = req.get("dataset_ids")
if "show_quotation" in req:
req["do_refer"]=req.pop("show_quotation")
if "dataset_ids" in req:
if not ids:
return get_error_data_result("`datasets` can't be empty")
@ -160,6 +166,12 @@ def update(tenant_id,chat_id):
kb = kbs[0]
if kb.chunk_num == 0:
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
kbs = KnowledgebaseService.get_by_ids(ids)
embd_count=list(set(kb.embd_id for kb in kbs))
if embd_count != 1 :
return get_result(
retmsg='Datasets use different embedding models."',
retcode=RetCode.AUTHENTICATION_ERROR)
req["kb_ids"] = ids
llm = req.get("llm")
if llm:
@ -225,10 +237,18 @@ def update(tenant_id,chat_id):
@token_required
def delete(tenant_id):
req = request.json
ids = req.get("ids")
if not req:
ids=None
else:
ids=req.get("ids")
if not ids:
return get_error_data_result(retmsg="`ids` are required")
for id in ids:
id_list = []
dias=DialogService.query(tenant_id=tenant_id,status=StatusEnum.VALID.value)
for dia in dias:
id_list.append(dia.id)
else:
id_list=ids
for id in id_list:
if not DialogService.query(tenant_id=tenant_id, id=id, status=StatusEnum.VALID.value):
return get_error_data_result(retmsg=f"You don't own the chat {id}")
temp_dict = {"status": StatusEnum.INVALID.value}
@ -260,7 +280,8 @@ def list_chat(tenant_id):
"quote": "show_quote",
"system": "prompt",
"rerank_id": "rerank_model",
"vector_similarity_weight": "keywords_similarity_weight"}
"vector_similarity_weight": "keywords_similarity_weight",
"do_refer":"show_quotation"}
key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"]
for res in chats:
for key, value in res["prompt_config"].items():

View File

@ -21,7 +21,7 @@ from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.llm_service import TenantLLMService,LLMService
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid
@ -68,9 +68,12 @@ def create(tenant_id):
"BAAI/bge-small-zh-v1.5","jinaai/jina-embeddings-v2-base-en","jinaai/jina-embeddings-v2-small-en",
"nomic-ai/nomic-embed-text-v1.5","sentence-transformers/all-MiniLM-L6-v2","text-embedding-v2",
"text-embedding-v3","maidalun1020/bce-embedding-base_v1"]
if not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model"))\
and req.get("embedding_model") not in valid_embedding_models:
embd_model=LLMService.query(llm_name=req["embedding_model"],model_type="embedding")
if not embd_model:
return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
if embd_model:
if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model")):
return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
key_mapping = {
"chunk_num": "chunk_count",
"doc_num": "document_count",
@ -92,25 +95,32 @@ def create(tenant_id):
@token_required
def delete(tenant_id):
req = request.json
ids = req.get("ids")
if not req:
ids=None
else:
ids=req.get("ids")
if not ids:
return get_error_data_result(
retmsg="ids are required")
for id in ids:
id_list = []
kbs=KnowledgebaseService.query(tenant_id=tenant_id)
for kb in kbs:
id_list.append(kb.id)
else:
id_list=ids
for id in id_list:
kbs = KnowledgebaseService.query(id=id, tenant_id=tenant_id)
if not kbs:
return get_error_data_result(retmsg=f"You don't own the dataset {id}")
for doc in DocumentService.query(kb_id=id):
if not DocumentService.remove_document(doc, tenant_id):
for doc in DocumentService.query(kb_id=id):
if not DocumentService.remove_document(doc, tenant_id):
return get_error_data_result(
retmsg="Remove document error.(Database error)")
f2d = File2DocumentService.get_by_document_id(doc.id)
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc.id)
if not KnowledgebaseService.delete_by_id(id):
return get_error_data_result(
retmsg="Remove document error.(Database error)")
f2d = File2DocumentService.get_by_document_id(doc.id)
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc.id)
if not KnowledgebaseService.delete_by_id(id):
return get_error_data_result(
retmsg="Delete dataset error.(Database error)")
return get_result(retcode=RetCode.SUCCESS)
retmsg="Delete dataset error.(Database error)")
return get_result(retcode=RetCode.SUCCESS)
@manager.route('/datasets/<dataset_id>', methods=['PUT'])
@token_required
@ -139,8 +149,9 @@ def update(tenant_id,dataset_id):
retmsg="Can't change `tenant_id`.")
e, kb = KnowledgebaseService.get_by_id(dataset_id)
if "parser_config" in req:
print(kb.parser_config,flush=True)
req["parser_config"]=kb.parser_config.update(req["parser_config"])
temp_dict=kb.parser_config
temp_dict.update(req["parser_config"])
req["parser_config"] = temp_dict
if "chunk_count" in req:
if req["chunk_count"] != kb.chunk_num:
return get_error_data_result(
@ -157,7 +168,8 @@ def update(tenant_id,dataset_id):
retmsg="If `chunk_count` is not 0, `chunk_method` is not changeable.")
req['parser_id'] = req.pop('chunk_method')
if req['parser_id'] != kb.parser_id:
req["parser_config"] = get_parser_config(chunk_method, parser_config)
if not req.get("parser_config"):
req["parser_config"] = get_parser_config(chunk_method, parser_config)
if "embedding_model" in req:
if kb.chunk_num != 0 and req['embedding_model'] != kb.embd_id:
return get_error_data_result(
@ -168,9 +180,12 @@ def update(tenant_id,dataset_id):
"BAAI/bge-small-zh-v1.5","jinaai/jina-embeddings-v2-base-en","jinaai/jina-embeddings-v2-small-en",
"nomic-ai/nomic-embed-text-v1.5","sentence-transformers/all-MiniLM-L6-v2","text-embedding-v2",
"text-embedding-v3","maidalun1020/bce-embedding-base_v1"]
if not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model"))\
and req.get("embedding_model") not in valid_embedding_models:
embd_model=LLMService.query(llm_name=req["embedding_model"],model_type="embedding")
if not embd_model:
return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
if embd_model:
if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model")):
return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
req['embd_id'] = req.pop('embedding_model')
if "name" in req:
req["name"] = req["name"].strip()

View File

@ -46,6 +46,9 @@ from rag.utils.es_conn import ELASTICSEARCH
from rag.utils.storage_factory import STORAGE_IMPL
import os
MAXIMUM_OF_UPLOADING_FILES = 256
@manager.route('/datasets/<dataset_id>/documents', methods=['POST'])
@token_required
@ -58,11 +61,21 @@ def upload(dataset_id, tenant_id):
if file_obj.filename == '':
return get_result(
retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
# total size
total_size = 0
for file_obj in file_objs:
file_obj.seek(0, os.SEEK_END)
total_size += file_obj.tell()
file_obj.seek(0)
MAX_TOTAL_FILE_SIZE=10*1024*1024
if total_size > MAX_TOTAL_FILE_SIZE:
return get_result(
retmsg=f'Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)',
retcode=RetCode.ARGUMENT_ERROR)
e, kb = KnowledgebaseService.get_by_id(dataset_id)
if not e:
raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
err, files = FileService.upload_document(kb, file_objs, tenant_id)
err, files= FileService.upload_document(kb, file_objs, tenant_id)
if err:
return get_result(
retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)
@ -140,6 +153,7 @@ def update_doc(tenant_id, dataset_id, document_id):
if not e:
return get_error_data_result(retmsg="Document not found!")
req["parser_config"] = get_parser_config(req["chunk_method"], req.get("parser_config"))
DocumentService.update_parser_config(doc.id, req["parser_config"])
if doc.token_num > 0:
e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1,
doc.process_duation * -1)
@ -210,10 +224,10 @@ def list_docs(dataset_id, tenant_id):
}
renamed_doc = {}
for key, value in doc.items():
if key =="run":
renamed_doc["run"]=run_mapping.get(str(value))
new_key = key_mapping.get(key, key)
renamed_doc[new_key] = value
if key =="run":
renamed_doc["run"]=run_mapping.get(value)
renamed_doc_list.append(renamed_doc)
return get_result(data={"total": tol, "docs": renamed_doc_list})
@ -280,14 +294,11 @@ def parse(tenant_id,dataset_id):
doc = DocumentService.query(id=id,kb_id=dataset_id)
if not doc:
return get_error_data_result(retmsg=f"You don't own the document {id}.")
if doc[0].progress != 0.0:
return get_error_data_result("Can't stop parsing document with progress at 0 or 100")
info = {"run": "1", "progress": 0}
info["progress_msg"] = ""
info["chunk_num"] = 0
info["token_num"] = 0
DocumentService.update_by_id(id, info)
# if str(req["run"]) == TaskStatus.CANCEL.value:
ELASTICSEARCH.deleteByQuery(
Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
TaskService.filter_delete([Task.doc_id == id])
@ -312,10 +323,8 @@ def stop_parsing(tenant_id,dataset_id):
return get_error_data_result(retmsg=f"You don't own the document {id}.")
if doc[0].progress == 100.0 or doc[0].progress == 0.0:
return get_error_data_result("Can't stop parsing document with progress at 0 or 100")
info = {"run": "2", "progress": 0}
info = {"run": "2", "progress": 0,"chunk_num":0}
DocumentService.update_by_id(id, info)
# if str(req["run"]) == TaskStatus.CANCEL.value:
tenant_id = DocumentService.get_tenant_id(id)
ELASTICSEARCH.deleteByQuery(
Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
return get_result()
@ -355,10 +364,10 @@ def list_chunks(tenant_id,dataset_id,document_id):
doc=doc.to_dict()
renamed_doc = {}
for key, value in doc.items():
if key == "run":
renamed_doc["run"] = run_mapping.get(str(value))
new_key = key_mapping.get(key, key)
renamed_doc[new_key] = value
if key == "run":
renamed_doc["run"] = run_mapping.get(str(value))
res = {"total": sres.total, "chunks": [], "doc": renamed_doc}
origin_chunks = []
sign = 0
@ -398,12 +407,17 @@ def list_chunks(tenant_id,dataset_id,document_id):
"content_with_weight": "content",
"doc_id": "document_id",
"important_kwd": "important_keywords",
"img_id": "image_id"
"img_id": "image_id",
"available_int":"available"
}
renamed_chunk = {}
for key, value in chunk.items():
new_key = key_mapping.get(key, key)
renamed_chunk[new_key] = value
if renamed_chunk["available"] == "0":
renamed_chunk["available"] = False
if renamed_chunk["available"] == "1":
renamed_chunk["available"] = True
res["chunks"].append(renamed_chunk)
return get_result(data=res)
@ -441,7 +455,7 @@ def add_chunk(tenant_id,dataset_id,document_id):
embd_id = DocumentService.get_embd_id(document_id)
embd_mdl = TenantLLMService.model_instance(
tenant_id, LLMType.EMBEDDING.value, embd_id)
print(embd_mdl,flush=True)
v, c = embd_mdl.encode([doc.name, req["content"]])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
@ -459,7 +473,7 @@ def add_chunk(tenant_id,dataset_id,document_id):
"kb_id": "dataset_id",
"create_timestamp_flt": "create_timestamp",
"create_time": "create_time",
"document_keyword": "document",
"document_keyword": "document"
}
renamed_chunk = {}
for key, value in d.items():
@ -480,12 +494,18 @@ def rm_chunk(tenant_id,dataset_id,document_id):
return get_error_data_result(retmsg=f"You don't own the document {document_id}.")
doc = doc[0]
req = request.json
if not req.get("chunk_ids"):
return get_error_data_result("`chunk_ids` is required")
query = {
"doc_ids": [doc.id], "page": 1, "size": 1024, "question": "", "sort": True}
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
for chunk_id in req.get("chunk_ids"):
if not req:
chunk_ids=None
else:
chunk_ids=req.get("chunk_ids")
if not chunk_ids:
chunk_list=sres.ids
else:
chunk_list=chunk_ids
for chunk_id in chunk_list:
if chunk_id not in sres.ids:
return get_error_data_result(f"Chunk {chunk_id} not found")
if not ELASTICSEARCH.deleteByQuery(

View File

@ -100,7 +100,7 @@ def completion(tenant_id,chat_id):
return get_error_data_result(retmsg="Session does not exist")
conv = conv[0]
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_error_data_result(retmsg="You do not own the session")
return get_error_data_result(retmsg="You do not own the chat")
msg = []
question = {
"content": req.get("question"),
@ -168,9 +168,6 @@ def list(chat_id,tenant_id):
return get_error_data_result(retmsg=f"You don't own the assistant {chat_id}.")
id = request.args.get("id")
name = request.args.get("name")
session = ConversationService.query(id=id,name=name,dialog_id=chat_id)
if not session:
return get_error_data_result(retmsg="The session doesn't exist")
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 1024))
orderby = request.args.get("orderby", "create_time")
@ -183,6 +180,10 @@ def list(chat_id,tenant_id):
return get_result(data=[])
for conv in convs:
conv['messages'] = conv.pop("message")
infos = conv["messages"]
for info in infos:
if "prompt" in info:
info.pop("prompt")
conv["chat"] = conv.pop("dialog_id")
if conv["reference"]:
messages = conv["messages"]
@ -218,10 +219,20 @@ def list(chat_id,tenant_id):
def delete(tenant_id,chat_id):
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_error_data_result(retmsg="You don't own the chat")
ids = request.json.get("ids")
req = request.json
convs = ConversationService.query(dialog_id=chat_id)
if not req:
ids = None
else:
ids=req.get("ids")
if not ids:
return get_error_data_result(retmsg="`ids` is required in deleting operation")
for id in ids:
conv_list = []
for conv in convs:
conv_list.append(conv.id)
else:
conv_list=ids
for id in conv_list:
conv = ConversationService.query(id=id,dialog_id=chat_id)
if not conv:
return get_error_data_result(retmsg="The chat doesn't own the session")

View File

@ -344,7 +344,7 @@ def get_parser_config(chunk_method,parser_config):
return parser_config
if not chunk_method:
chunk_method = "naive"
key_mapping={"naive":{"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False,"layout_recognize": True, "raptor": {"user_raptor": False}},
key_mapping={"naive":{"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False,"layout_recognize": True, "raptor": {"use_raptor": False}},
"qa":{"raptor":{"use_raptor":False}},
"resume":None,
"manual":{"raptor":{"use_raptor":False}},

View File

@ -68,7 +68,7 @@ class Chat(Base):
return result_list
raise Exception(res["message"])
def delete_sessions(self,ids):
def delete_sessions(self,ids:List[str]=None):
res = self.rm(f"/chats/{self.id}/sessions", {"ids": ids})
res = res.json()
if res.get("code") != 0:

View File

@ -64,7 +64,7 @@ class RAGFlow:
return DataSet(self, res["data"])
raise Exception(res["message"])
def delete_datasets(self, ids: List[str]):
def delete_datasets(self, ids: List[str] = None):
res = self.delete("/datasets",{"ids": ids})
res=res.json()
if res.get("code") != 0:
@ -135,9 +135,9 @@ class RAGFlow:
return Chat(self, res["data"])
raise Exception(res["message"])
def delete_chats(self,ids: List[str] = None,names: List[str] = None ) -> bool:
def delete_chats(self,ids: List[str] = None) -> bool:
res = self.delete('/chats',
{"ids":ids, "names":names})
{"ids":ids})
res = res.json()
if res.get("code") != 0:
raise Exception(res["message"])