mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-12 10:49:03 +08:00
Fix open AI compatible rerank issue. (#3866)
### What problem does this PR solve? #3700 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
parent
84afb4259c
commit
78601ee1bd
@ -35,7 +35,7 @@ from api.db.services.llm_service import LLMBundle
|
|||||||
|
|
||||||
@manager.route('/chats/<chat_id>/sessions', methods=['POST'])
|
@manager.route('/chats/<chat_id>/sessions', methods=['POST'])
|
||||||
@token_required
|
@token_required
|
||||||
def create(tenant_id,chat_id):
|
def create(tenant_id, chat_id):
|
||||||
req = request.json
|
req = request.json
|
||||||
req["dialog_id"] = chat_id
|
req["dialog_id"] = chat_id
|
||||||
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
|
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
|
||||||
@ -79,7 +79,7 @@ def create_agent_session(tenant_id, agent_id):
|
|||||||
"user_id": tenant_id,
|
"user_id": tenant_id,
|
||||||
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
||||||
"source": "agent",
|
"source": "agent",
|
||||||
"dsl":json.loads(cvs.dsl)
|
"dsl": json.loads(cvs.dsl)
|
||||||
}
|
}
|
||||||
API4ConversationService.save(**conv)
|
API4ConversationService.save(**conv)
|
||||||
conv["agent_id"] = conv.pop("dialog_id")
|
conv["agent_id"] = conv.pop("dialog_id")
|
||||||
@ -88,11 +88,11 @@ def create_agent_session(tenant_id, agent_id):
|
|||||||
|
|
||||||
@manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT'])
|
@manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT'])
|
||||||
@token_required
|
@token_required
|
||||||
def update(tenant_id,chat_id,session_id):
|
def update(tenant_id, chat_id, session_id):
|
||||||
req = request.json
|
req = request.json
|
||||||
req["dialog_id"] = chat_id
|
req["dialog_id"] = chat_id
|
||||||
conv_id = session_id
|
conv_id = session_id
|
||||||
conv = ConversationService.query(id=conv_id,dialog_id=chat_id)
|
conv = ConversationService.query(id=conv_id, dialog_id=chat_id)
|
||||||
if not conv:
|
if not conv:
|
||||||
return get_error_data_result(message="Session does not exist")
|
return get_error_data_result(message="Session does not exist")
|
||||||
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
||||||
@ -111,7 +111,7 @@ def update(tenant_id,chat_id,session_id):
|
|||||||
@manager.route('/chats/<chat_id>/completions', methods=['POST'])
|
@manager.route('/chats/<chat_id>/completions', methods=['POST'])
|
||||||
@token_required
|
@token_required
|
||||||
def completion(tenant_id, chat_id):
|
def completion(tenant_id, chat_id):
|
||||||
dia= DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
||||||
if not dia:
|
if not dia:
|
||||||
return get_error_data_result(message="You do not own the chat")
|
return get_error_data_result(message="You do not own the chat")
|
||||||
req = request.json
|
req = request.json
|
||||||
@ -126,12 +126,12 @@ def completion(tenant_id, chat_id):
|
|||||||
return get_error_data_result(message="`name` can not be empty.")
|
return get_error_data_result(message="`name` can not be empty.")
|
||||||
ConversationService.save(**conv)
|
ConversationService.save(**conv)
|
||||||
e, conv = ConversationService.get_by_id(conv["id"])
|
e, conv = ConversationService.get_by_id(conv["id"])
|
||||||
session_id=conv.id
|
session_id = conv.id
|
||||||
else:
|
else:
|
||||||
session_id = req.get("session_id")
|
session_id = req.get("session_id")
|
||||||
if not req.get("question"):
|
if not req.get("question"):
|
||||||
return get_error_data_result(message="Please input your question.")
|
return get_error_data_result(message="Please input your question.")
|
||||||
conv = ConversationService.query(id=session_id,dialog_id=chat_id)
|
conv = ConversationService.query(id=session_id, dialog_id=chat_id)
|
||||||
if not conv:
|
if not conv:
|
||||||
return get_error_data_result(message="Session does not exist")
|
return get_error_data_result(message="Session does not exist")
|
||||||
conv = conv[0]
|
conv = conv[0]
|
||||||
@ -183,7 +183,7 @@ def completion(tenant_id, chat_id):
|
|||||||
chunk_list.append(new_chunk)
|
chunk_list.append(new_chunk)
|
||||||
reference["chunks"] = chunk_list
|
reference["chunks"] = chunk_list
|
||||||
ans["id"] = message_id
|
ans["id"] = message_id
|
||||||
ans["session_id"]=session_id
|
ans["session_id"] = session_id
|
||||||
|
|
||||||
def stream():
|
def stream():
|
||||||
nonlocal dia, msg, req, conv
|
nonlocal dia, msg, req, conv
|
||||||
@ -194,7 +194,7 @@ def completion(tenant_id, chat_id):
|
|||||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
||||||
"data": {"answer": "**ERROR**: " + str(e),"reference": []}},
|
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||||
ensure_ascii=False) + "\n\n"
|
ensure_ascii=False) + "\n\n"
|
||||||
yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"
|
||||||
|
|
||||||
@ -375,10 +375,9 @@ def agent_completion(tenant_id, agent_id):
|
|||||||
return get_result(data=result)
|
return get_result(data=result)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/chats/<chat_id>/sessions', methods=['GET'])
|
@manager.route('/chats/<chat_id>/sessions', methods=['GET'])
|
||||||
@token_required
|
@token_required
|
||||||
def list_session(chat_id,tenant_id):
|
def list_session(tenant_id, chat_id):
|
||||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
||||||
return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
|
return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
|
||||||
id = request.args.get("id")
|
id = request.args.get("id")
|
||||||
@ -390,7 +389,7 @@ def list_session(chat_id,tenant_id):
|
|||||||
desc = False
|
desc = False
|
||||||
else:
|
else:
|
||||||
desc = True
|
desc = True
|
||||||
convs = ConversationService.get_list(chat_id,page_number,items_per_page,orderby,desc,id,name)
|
convs = ConversationService.get_list(chat_id, page_number, items_per_page, orderby, desc, id, name)
|
||||||
if not convs:
|
if not convs:
|
||||||
return get_result(data=[])
|
return get_result(data=[])
|
||||||
for conv in convs:
|
for conv in convs:
|
||||||
@ -429,13 +428,14 @@ def list_session(chat_id,tenant_id):
|
|||||||
del conv["reference"]
|
del conv["reference"]
|
||||||
return get_result(data=convs)
|
return get_result(data=convs)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/agents/<agent_id>/sessions', methods=['GET'])
|
@manager.route('/agents/<agent_id>/sessions', methods=['GET'])
|
||||||
@token_required
|
@token_required
|
||||||
def list_agent_session(agent_id,tenant_id):
|
def list_agent_session(tenant_id, agent_id):
|
||||||
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
|
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
|
||||||
return get_error_data_result(message=f"You don't own the agent {agent_id}.")
|
return get_error_data_result(message=f"You don't own the agent {agent_id}.")
|
||||||
id = request.args.get("id")
|
id = request.args.get("id")
|
||||||
if not API4ConversationService.query(id=id,user_id=tenant_id):
|
if not API4ConversationService.query(id=id, user_id=tenant_id):
|
||||||
return get_error_data_result(f"You don't own the session {id}")
|
return get_error_data_result(f"You don't own the session {id}")
|
||||||
page_number = int(request.args.get("page", 1))
|
page_number = int(request.args.get("page", 1))
|
||||||
items_per_page = int(request.args.get("page_size", 30))
|
items_per_page = int(request.args.get("page_size", 30))
|
||||||
@ -444,7 +444,7 @@ def list_agent_session(agent_id,tenant_id):
|
|||||||
desc = False
|
desc = False
|
||||||
else:
|
else:
|
||||||
desc = True
|
desc = True
|
||||||
convs = API4ConversationService.get_list(agent_id,tenant_id,page_number,items_per_page,orderby,desc,id)
|
convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id)
|
||||||
if not convs:
|
if not convs:
|
||||||
return get_result(data=[])
|
return get_result(data=[])
|
||||||
for conv in convs:
|
for conv in convs:
|
||||||
@ -486,7 +486,7 @@ def list_agent_session(agent_id,tenant_id):
|
|||||||
|
|
||||||
@manager.route('/chats/<chat_id>/sessions', methods=["DELETE"])
|
@manager.route('/chats/<chat_id>/sessions', methods=["DELETE"])
|
||||||
@token_required
|
@token_required
|
||||||
def delete(tenant_id,chat_id):
|
def delete(tenant_id, chat_id):
|
||||||
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
||||||
return get_error_data_result(message="You don't own the chat")
|
return get_error_data_result(message="You don't own the chat")
|
||||||
req = request.json
|
req = request.json
|
||||||
@ -494,21 +494,22 @@ def delete(tenant_id,chat_id):
|
|||||||
if not req:
|
if not req:
|
||||||
ids = None
|
ids = None
|
||||||
else:
|
else:
|
||||||
ids=req.get("ids")
|
ids = req.get("ids")
|
||||||
|
|
||||||
if not ids:
|
if not ids:
|
||||||
conv_list = []
|
conv_list = []
|
||||||
for conv in convs:
|
for conv in convs:
|
||||||
conv_list.append(conv.id)
|
conv_list.append(conv.id)
|
||||||
else:
|
else:
|
||||||
conv_list=ids
|
conv_list = ids
|
||||||
for id in conv_list:
|
for id in conv_list:
|
||||||
conv = ConversationService.query(id=id,dialog_id=chat_id)
|
conv = ConversationService.query(id=id, dialog_id=chat_id)
|
||||||
if not conv:
|
if not conv:
|
||||||
return get_error_data_result(message="The chat doesn't own the session")
|
return get_error_data_result(message="The chat doesn't own the session")
|
||||||
ConversationService.delete_by_id(id)
|
ConversationService.delete_by_id(id)
|
||||||
return get_result()
|
return get_result()
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/sessions/ask', methods=['POST'])
|
@manager.route('/sessions/ask', methods=['POST'])
|
||||||
@token_required
|
@token_required
|
||||||
def ask_about(tenant_id):
|
def ask_about(tenant_id):
|
||||||
@ -517,17 +518,18 @@ def ask_about(tenant_id):
|
|||||||
return get_error_data_result("`question` is required.")
|
return get_error_data_result("`question` is required.")
|
||||||
if not req.get("dataset_ids"):
|
if not req.get("dataset_ids"):
|
||||||
return get_error_data_result("`dataset_ids` is required.")
|
return get_error_data_result("`dataset_ids` is required.")
|
||||||
if not isinstance(req.get("dataset_ids"),list):
|
if not isinstance(req.get("dataset_ids"), list):
|
||||||
return get_error_data_result("`dataset_ids` should be a list.")
|
return get_error_data_result("`dataset_ids` should be a list.")
|
||||||
req["kb_ids"]=req.pop("dataset_ids")
|
req["kb_ids"] = req.pop("dataset_ids")
|
||||||
for kb_id in req["kb_ids"]:
|
for kb_id in req["kb_ids"]:
|
||||||
if not KnowledgebaseService.accessible(kb_id,tenant_id):
|
if not KnowledgebaseService.accessible(kb_id, tenant_id):
|
||||||
return get_error_data_result(f"You don't own the dataset {kb_id}.")
|
return get_error_data_result(f"You don't own the dataset {kb_id}.")
|
||||||
kbs = KnowledgebaseService.query(id=kb_id)
|
kbs = KnowledgebaseService.query(id=kb_id)
|
||||||
kb = kbs[0]
|
kb = kbs[0]
|
||||||
if kb.chunk_num == 0:
|
if kb.chunk_num == 0:
|
||||||
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
||||||
uid = tenant_id
|
uid = tenant_id
|
||||||
|
|
||||||
def stream():
|
def stream():
|
||||||
nonlocal req, uid
|
nonlocal req, uid
|
||||||
try:
|
try:
|
||||||
|
@ -286,7 +286,7 @@ class OpenAI_APIRerank(Base):
|
|||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {key}"
|
"Authorization": f"Bearer {key}"
|
||||||
}
|
}
|
||||||
self.model_name = model_name
|
self.model_name = model_name.split("___")[0]
|
||||||
|
|
||||||
def similarity(self, query: str, texts: list):
|
def similarity(self, query: str, texts: list):
|
||||||
# noway to config Ragflow , use fix setting
|
# noway to config Ragflow , use fix setting
|
||||||
|
Loading…
x
Reference in New Issue
Block a user