mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-23 06:30:00 +08:00

### What problem does this PR solve? Fix bugs in agent api and update api document ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
230 lines
8.7 KiB
Python
230 lines
8.7 KiB
Python
#
|
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
from uuid import uuid4
|
|
from api.db import StatusEnum
|
|
from api.db.db_models import Conversation, DB
|
|
from api.db.services.api_service import API4ConversationService
|
|
from api.db.services.common_service import CommonService
|
|
from api.db.services.dialog_service import DialogService, chat
|
|
from api.utils import get_uuid
|
|
import json
|
|
|
|
|
|
class ConversationService(CommonService):
    """Service-layer helpers for querying ``Conversation`` rows."""

    model = Conversation

    @classmethod
    @DB.connection_context()
    def get_list(cls, dialog_id, page_number, items_per_page, orderby, desc, id, name):
        """Return one page of a dialog's conversations as a list of dicts.

        Args:
            dialog_id: Only conversations whose ``dialog_id`` equals this
                value are selected.
            page_number: Page index handed to ``paginate``.
            items_per_page: Page size handed to ``paginate``.
            orderby: Name of the model attribute to sort on (resolved
                through ``model.getter_by``).
            desc: Truthy for descending order, falsy for ascending.
            id: When truthy, additionally require an exact id match.
            name: When truthy, additionally require an exact name match.

        Returns:
            list[dict]: The selected rows, one dict per conversation.
        """
        query = cls.model.select().where(cls.model.dialog_id == dialog_id)

        # Optional exact-match filters; falsy values mean "do not filter".
        if id:
            query = query.where(cls.model.id == id)
        if name:
            query = query.where(cls.model.name == name)

        sort_field = cls.model.getter_by(orderby)
        ordering = sort_field.desc() if desc else sort_field.asc()

        paged = query.order_by(ordering).paginate(page_number, items_per_page)
        return list(paged.dicts())
|
|
|
|
|
|
def structure_answer(conv, ans, message_id, session_id):
    """Normalise a chat answer in place and mirror it into the conversation.

    The chunks found under ``ans["reference"]["chunks"]`` are rewritten to a
    fixed set of output keys, ``ans`` is tagged with the message and session
    ids, and - when ``conv`` is truthy - the last assistant message and the
    last reference entry of ``conv`` are updated.

    Args:
        conv: Conversation object exposing ``message`` and ``reference``
            attributes, or a falsy value to skip the conversation update.
        ans: Answer dict; must contain ``"reference"`` and, when ``conv`` is
            truthy, ``"answer"``. It is mutated and returned.
        message_id: Id stored on the answer and on the assistant message.
        session_id: Id stored on the answer as ``"session_id"``.

    Returns:
        dict: The same ``ans`` object that was passed in.
    """
    reference = ans["reference"]
    if not isinstance(reference, dict):
        # A non-dict reference is discarded: the answer gets its own empty
        # dict, while a separate local dict receives the (empty) chunk list.
        ans["reference"] = {}
        reference = {}

    def pick(source, primary, fallback):
        # Prefer ``primary``; fall back to ``fallback`` only when the
        # primary key is absent from ``source``.
        return source.get(primary, source.get(fallback))

    # (output key, preferred source key, alternative source key)
    field_aliases = (
        ("id", "chunk_id", "id"),
        ("content", "content", "content_with_weight"),
        ("document_id", "doc_id", "document_id"),
        ("document_name", "docnm_kwd", "document_name"),
        ("dataset_id", "kb_id", "dataset_id"),
        ("image_id", "image_id", "img_id"),
        ("positions", "positions", "position_int"),
    )

    normalized = []
    for raw in reference.get("chunks", []):
        normalized.append({out: pick(raw, first, second) for out, first, second in field_aliases})
    reference["chunks"] = normalized

    ans["session_id"] = session_id
    ans["id"] = message_id

    if not conv:
        return ans

    if not conv.message:
        conv.message = []

    assistant_turn = {"role": "assistant", "content": ans["answer"], "id": message_id}
    if conv.message and conv.message[-1].get("role", "") == "assistant":
        # Overwrite the previous assistant turn rather than adding another.
        conv.message[-1] = assistant_turn
    else:
        conv.message.append(assistant_turn)

    if conv.reference:
        conv.reference[-1] = reference
    return ans
|
|
|
|
|
|
def completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
    """Answer ``question`` inside a chat session, yielding the result(s).

    This is a generator: none of the checks below run until the caller
    starts iterating.

    Without a ``session_id`` a new conversation is saved whose first
    assistant message is the dialog's prologue. The prologue is yielded as
    one ``data:`` event followed by a closing ``data`` event carrying
    ``True``; the question itself is not processed in that case.

    With a ``session_id`` the question is appended to the stored history,
    the filtered history is passed to ``chat`` and the updated conversation
    is written back through ``ConversationService.update_by_id``.

    Args:
        tenant_id: Tenant that has to own the chat.
        chat_id: Id of the dialog to talk to.
        question: Text of the user's question.
        name: Name for a newly created session; must be truthy.
        session_id: Id of an existing conversation, or falsy to create one.
        stream: Truthy to yield ``data:`` event strings, falsy to yield a
            single structured answer.
        **kwargs: Forwarded unchanged to ``chat``.

    Yields:
        str: ``"data:<json>\\n\\n"`` events (new session or ``stream``).
        dict | None: In non-stream mode, the first structured answer, or
        ``None`` when ``chat`` produced nothing.

    Raises:
        AssertionError: If ``name`` is empty or no valid dialog matches
            ``chat_id`` and ``tenant_id``.
        LookupError: If ``session_id`` does not belong to ``chat_id``.
    """
    # Explicit raises instead of ``assert`` so the validation is not
    # stripped when Python runs with -O; the exception type is unchanged.
    if not name:
        raise AssertionError("`name` can not be empty.")
    dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
    if not dia:
        raise AssertionError("You do not own the chat.")

    if not session_id:
        session_id = get_uuid()
        conv = {
            "id": session_id,
            "dialog_id": chat_id,
            "name": name,
            "message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}],
        }
        ConversationService.save(**conv)
        # NOTE(review): event strings are yielded here even when ``stream``
        # is false - confirm that non-streaming callers expect this.
        opening = {
            "code": 0,
            "message": "",
            "data": {
                "answer": conv["message"][0]["content"],
                "reference": {},
                "audio_binary": None,
                "id": None,
                "session_id": session_id,
            },
        }
        yield "data:" + json.dumps(opening, ensure_ascii=False) + "\n\n"
        yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
        return

    conv = ConversationService.query(id=session_id, dialog_id=chat_id)
    if not conv:
        raise LookupError("Session does not exist")
    conv = conv[0]

    # A stored conversation may have no history yet; normalise it so the
    # append below cannot fail on ``None``.
    if not conv.message:
        conv.message = []

    question = {
        "content": question,
        "role": "user",
        "id": str(uuid4()),
    }
    conv.message.append(question)

    # History sent to the model: no system messages and no assistant
    # messages ahead of the first retained one (e.g. the prologue).
    msg = []
    for m in conv.message:
        if m["role"] == "system":
            continue
        if m["role"] == "assistant" and not msg:
            continue
        msg.append(m)
    message_id = msg[-1].get("id")

    # ``dia`` was a query result list above; rebind it to the single dialog.
    _, dia = DialogService.get_by_id(conv.dialog_id)

    if not conv.reference:
        conv.reference = []
    # Placeholder assistant turn and reference slot that
    # ``structure_answer`` overwrites for every produced answer.
    conv.message.append({"role": "assistant", "content": "", "id": message_id})
    conv.reference.append({"chunks": [], "doc_aggs": []})

    if stream:
        try:
            for ans in chat(dia, msg, True, **kwargs):
                ans = structure_answer(conv, ans, message_id, session_id)
                yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
            ConversationService.update_by_id(conv.id, conv.to_dict())
        except Exception as exc:
            # Failures are reported to the client as an event instead of
            # propagating out of the generator.
            yield "data:" + json.dumps({"code": 500, "message": str(exc),
                                        "data": {"answer": "**ERROR**: " + str(exc), "reference": []}},
                                       ensure_ascii=False) + "\n\n"
        yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"
    else:
        answer = None
        # Only the first answer produced by ``chat`` is used.
        for ans in chat(dia, msg, False, **kwargs):
            answer = structure_answer(conv, ans, message_id, session_id)
            ConversationService.update_by_id(conv.id, conv.to_dict())
            break
        yield answer
|
|
|
|
|
|
def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
    """Answer ``question`` for an API/iframe session, yielding the result(s).

    This is a generator: nothing runs until the caller starts iterating.

    Without a ``session_id`` a new session is saved through
    ``API4ConversationService`` with the dialog's prologue as its first
    assistant message. The prologue is yielded as one ``data:`` event
    followed by a closing ``data`` event carrying ``True``; the question
    itself is not processed in that case.

    With a ``session_id`` the question is appended to the stored history,
    the filtered history is passed to ``chat`` and the conversation is
    persisted through ``API4ConversationService.append_message``.

    Args:
        dialog_id: Id of the dialog to talk to.
        question: Text of the user's question.
        session_id: Id of an existing session, or falsy to create one.
        stream: Truthy to yield ``data:`` event strings, falsy to yield a
            single structured answer.
        **kwargs: Forwarded unchanged to ``chat``; ``user_id`` (default
            ``""``) is also stored on a newly created session.

    Yields:
        str: ``"data:<json>\\n\\n"`` events (new session or ``stream``).
        dict | None: In non-stream mode, the first structured answer, or
        ``None`` when ``chat`` produced nothing.

    Raises:
        AssertionError: If the dialog or the session cannot be found.
    """
    # Explicit raises instead of ``assert`` so the validation is not
    # stripped when Python runs with -O; the exception type is unchanged.
    found, dia = DialogService.get_by_id(dialog_id)
    if not found:
        raise AssertionError("Dialog not found")

    if not session_id:
        session_id = get_uuid()
        conv = {
            "id": session_id,
            "dialog_id": dialog_id,
            "user_id": kwargs.get("user_id", ""),
            # ``.get`` so a dialog without a configured prologue does not
            # raise KeyError (same handling as ``completion``).
            "message": [{"role": "assistant", "content": dia.prompt_config.get("prologue")}],
        }
        API4ConversationService.save(**conv)
        opening = {
            "code": 0,
            "message": "",
            "data": {
                "answer": conv["message"][0]["content"],
                "reference": {},
                "audio_binary": None,
                "id": None,
                "session_id": session_id,
            },
        }
        yield "data:" + json.dumps(opening, ensure_ascii=False) + "\n\n"
        yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
        return

    found, conv = API4ConversationService.get_by_id(session_id)
    if not found:
        raise AssertionError("Session not found!")

    if not conv.message:
        conv.message = []
    messages = conv.message
    question = {
        "role": "user",
        "content": question,
        "id": str(uuid4()),
    }
    messages.append(question)

    # History sent to the model: no system messages and no assistant
    # messages ahead of the first retained one (e.g. the prologue).
    msg = []
    for m in messages:
        if m["role"] == "system":
            continue
        if m["role"] == "assistant" and not msg:
            continue
        msg.append(m)

    # The last retained message is always the question appended above,
    # which already carries a freshly generated id.
    message_id = msg[-1]["id"]

    if not conv.reference:
        conv.reference = []
    # Reference slot that ``structure_answer`` overwrites per answer.
    conv.reference.append({"chunks": [], "doc_aggs": []})

    if stream:
        try:
            for ans in chat(dia, msg, True, **kwargs):
                ans = structure_answer(conv, ans, message_id, session_id)
                yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
                                           ensure_ascii=False) + "\n\n"
            API4ConversationService.append_message(conv.id, conv.to_dict())
        except Exception as exc:
            # Failures are reported to the client as an event instead of
            # propagating out of the generator.
            yield "data:" + json.dumps({"code": 500, "message": str(exc),
                                        "data": {"answer": "**ERROR**: " + str(exc), "reference": []}},
                                       ensure_ascii=False) + "\n\n"
        yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
    else:
        answer = None
        # Only the first answer produced by ``chat`` is used.
        for ans in chat(dia, msg, False, **kwargs):
            answer = structure_answer(conv, ans, message_id, session_id)
            API4ConversationService.append_message(conv.id, conv.to_dict())
            break
        yield answer
|
|
|