SDK for session (#2312)

### What problem does this PR solve?

SDK for session
#1102 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Feiue <10215101452@stu.ecun.edu.cn>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
LiuHua 2024-09-09 17:18:08 +08:00 committed by GitHub
parent ceae4df889
commit 336a639164
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 325 additions and 35 deletions

View File

@ -16,9 +16,10 @@
from flask import request
from api.db import StatusEnum
from api.db.db_models import TenantLLM
from api.db.services.dialog_service import DialogService
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid
@ -30,7 +31,6 @@ from api.utils.api_utils import get_json_result
@token_required
def save(tenant_id):
req = request.json
id = req.get("id")
# dataset
if req.get("knowledgebases") == []:
return get_data_error_result(retmsg="knowledgebases can not be empty list")
@ -41,8 +41,8 @@ def save(tenant_id):
return get_data_error_result(retmsg="knowledgebase needs id")
if not KnowledgebaseService.query(id=kb["id"], tenant_id=tenant_id):
return get_data_error_result(retmsg="you do not own the knowledgebase")
if not DocumentService.query(kb_id=kb["id"]):
return get_data_error_result(retmsg="There is a invalid knowledgebase")
# if not DocumentService.query(kb_id=kb["id"]):
# return get_data_error_result(retmsg="There is a invalid knowledgebase")
kb_list.append(kb["id"])
req["kb_ids"] = kb_list
# llm
@ -72,10 +72,10 @@ def save(tenant_id):
req[key] = prompt.pop(key)
req["prompt_config"] = req.pop("prompt")
# create
if not id:
if "id" not in req:
# dataset
if not kb_list:
return get_data_error_result(retmsg="knowledgebase is required!")
return get_data_error_result(retmsg="knowledgebases are required!")
# init
req["id"] = get_uuid()
req["description"] = req.get("description", "A helpful Assistant")
@ -83,7 +83,11 @@ def save(tenant_id):
req["top_n"] = req.get("top_n", 6)
req["top_k"] = req.get("top_k", 1024)
req["rerank_id"] = req.get("rerank_id", "")
req["llm_id"] = req.get("llm_id", tenant.llm_id)
if req.get("llm_id"):
if not TenantLLMService.query(llm_name=req["llm_id"]):
return get_data_error_result(retmsg="the model_name does not exist.")
else:
req["llm_id"] = tenant.llm_id
if not req.get("name"):
return get_data_error_result(retmsg="name is required.")
if DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
@ -149,14 +153,20 @@ def save(tenant_id):
if not DialogService.query(tenant_id=tenant_id, id=req["id"], status=StatusEnum.VALID.value):
return get_json_result(data=False, retmsg='You do not own the assistant', retcode=RetCode.OPERATING_ERROR)
# prompt
if not req["id"]:
return get_data_error_result(retmsg="id can not be empty")
e, res = DialogService.get_by_id(req["id"])
res = res.to_json()
if "llm_id" in req:
if not TenantLLMService.query(llm_name=req["llm_id"]):
return get_data_error_result(retmsg="the model_name does not exist.")
if "name" in req:
if not req.get("name"):
return get_data_error_result(retmsg="name is not empty.")
if req["name"].lower() != res["name"].lower() \
and len(DialogService.query(name=req["name"], tenant_id=tenant_id,status=StatusEnum.VALID.value)) > 0:
return get_data_error_result(retmsg="Duplicated knowledgebase name in updating dataset.")
and len(
DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0:
return get_data_error_result(retmsg="Duplicated assistant name in updating dataset.")
if "prompt_config" in req:
res["prompt_config"].update(req["prompt_config"])
for p in res["prompt_config"]["parameters"]:
@ -213,7 +223,8 @@ def get(tenant_id):
name = req["name"]
ass = DialogService.query(name=name, tenant_id=tenant_id, status=StatusEnum.VALID.value)
if not ass:
return get_json_result(data=False, retmsg='You do not own the dataset.',retcode=RetCode.OPERATING_ERROR)
return get_json_result(data=False, retmsg='You do not own the assistant.',
retcode=RetCode.OPERATING_ERROR)
res = ass[0].to_json()
else:
return get_data_error_result(retmsg="At least one of `id` or `name` must be provided.")

View File

@ -88,6 +88,9 @@ def save(tenant_id):
data=False, retmsg='You do not own the dataset.',
retcode=RetCode.OPERATING_ERROR)
if not req["id"]:
return get_data_error_result(
retmsg="id can not be empty.")
e, kb = KnowledgebaseService.get_by_id(req["id"])
if "chunk_count" in req:
@ -108,6 +111,7 @@ def save(tenant_id):
retmsg="If chunk count is not 0, parse method is not changable.")
req['parser_id'] = req.pop('parse_method')
if "name" in req:
req["name"] = req["name"].strip()
if req["name"].lower() != kb.name.lower() \
and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id,
status=StatusEnum.VALID.value)) > 0:

168
api/apps/sdk/session.py Normal file
View File

@ -0,0 +1,168 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from copy import deepcopy
from uuid import uuid4
from flask import request, Response
from api.db import StatusEnum
from api.db.services.dialog_service import DialogService, ConversationService, chat
from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result
from api.utils.api_utils import get_json_result, token_required
@manager.route('/save', methods=['POST'])
@token_required
def set_conversation(tenant_id):
    """Create a session, or update one when the request body carries an ``id`` key.

    The JSON body may use the SDK-facing names ``messages`` and
    ``assistant_id``; they are renamed to ``message`` and ``dialog_id``
    before anything is validated or stored. A ``reference`` found inside a
    message is lifted out of that message into the top-level ``reference``
    field (the last message that has one wins).

    Args:
        tenant_id: tenant identifier supplied by ``token_required``.

    Returns:
        Update path: ``get_json_result(data=True)`` on success.
        Create path: the stored session with ``messages`` / ``assistant_id``
        keys and the session reference copied onto every message.
        Any failed validation returns ``get_data_error_result``.
    """
    body = request.json
    session_id = body.get("id")

    # Map SDK-facing field names onto the storage-facing ones.
    if "messages" in body:
        body["message"] = body.pop("messages")
        for entry in body["message"] or []:
            if "reference" in entry:
                body["reference"] = entry.pop("reference")
    if "assistant_id" in body:
        body["dialog_id"] = body.pop("assistant_id")

    # ---- update an existing session ------------------------------------
    if "id" in body:
        body.pop("id")
        existing = ConversationService.query(id=session_id)
        if not existing:
            return get_data_error_result(retmsg="Session does not exist")
        owner_check = DialogService.query(id=existing[0].dialog_id, tenant_id=tenant_id,
                                          status=StatusEnum.VALID.value)
        if not owner_check:
            return get_data_error_result(retmsg="You do not own the session")

        if "dialog_id" in body:
            if not body["dialog_id"]:
                return get_data_error_result(retmsg="assistant_id can not be empty.")
            # Re-assigning the session: the target assistant must belong to the tenant too.
            if not DialogService.query(tenant_id=tenant_id, id=body["dialog_id"],
                                       status=StatusEnum.VALID.value):
                return get_data_error_result(retmsg="You do not own the assistant")

        # A field that is present must not be empty.
        for field, complaint in (("name", "name can not be empty."),
                                 ("message", "messages can not be empty")):
            if field in body and not body.get(field):
                return get_data_error_result(retmsg=complaint)

        if not ConversationService.update_by_id(session_id, body):
            return get_data_error_result(retmsg="Session updates error")
        return get_json_result(data=True)

    # ---- create a new session ------------------------------------------
    if not body.get("dialog_id"):
        return get_data_error_result(retmsg="assistant_id is required.")
    assistants = DialogService.query(tenant_id=tenant_id, id=body["dialog_id"],
                                     status=StatusEnum.VALID.value)
    if not assistants:
        return get_data_error_result(retmsg="You do not own the assistant")

    new_session = dict(
        id=get_uuid(),
        dialog_id=body["dialog_id"],
        name=body.get("name", "New session"),
        # Default conversation opener is the assistant's configured prologue.
        message=body.get("message",
                         [{"role": "assistant", "content": assistants[0].prompt_config["prologue"]}]),
        reference=body.get("reference", []),
    )
    for field, complaint in (("name", "name can not be empty."),
                             ("message", "messages can not be empty")):
        if not new_session.get(field):
            return get_data_error_result(retmsg=complaint)

    ConversationService.save(**new_session)
    found, stored = ConversationService.get_by_id(new_session["id"])
    if not found:
        return get_data_error_result(retmsg="Fail to new session!")

    # Translate the stored record back to the SDK-facing shape.
    result = stored.to_dict()
    result["messages"] = result.pop("message")
    result["assistant_id"] = result.pop("dialog_id")
    session_reference = result.get("reference")
    for entry in result["messages"]:
        entry["reference"] = session_reference
    del result["reference"]
    return get_json_result(data=result)
@manager.route('/completion', methods=['POST'])
@token_required
def completion(tenant_id):
    """Answer a question inside an existing session, streaming by default.

    JSON body fields read by this handler:
        id        -- id of the session to continue
        question  -- the new user question
        messages  -- prior message list; the question is appended to it
        stream    -- optional, defaults to True; a false value yields a
                     single JSON reply instead of server-sent events

    ``id`` and ``messages`` are removed from the body and whatever is left
    is forwarded to ``chat`` as keyword arguments.

    Example body:
        {"id": "9aaaca4c11d311efa461fa163e197198",
         "question": "Is there one in Shanghai?",
         "messages": [{"role": "assistant", "content": "..."}]}

    Args:
        tenant_id: tenant identifier supplied by ``token_required``.

    Returns:
        A ``text/event-stream`` response when streaming, otherwise
        ``get_json_result`` wrapping the first answer produced by ``chat``.
        Validation failures return ``get_data_error_result``.
    """
    req = request.json
    msg = []
    question = {
        "content": req.get("question"),
        "role": "user",
        "id": str(uuid4())
    }
    req["messages"].append(question)
    for m in req["messages"]:
        # System messages, and assistant messages that precede the first
        # kept message, are not passed on to the model.
        if m["role"] == "system":
            continue
        if m["role"] == "assistant" and not msg:
            continue
        m["id"] = m.get("id", str(uuid4()))
        msg.append(m)
    message_id = msg[-1].get("id")

    sessions = ConversationService.query(id=req["id"])
    # Check for an empty result BEFORE indexing: indexing first raised
    # IndexError for unknown ids and made this error branch unreachable.
    if not sessions:
        return get_data_error_result(retmsg="Session does not exist")
    conv = sessions[0]
    if not DialogService.query(id=conv.dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
        return get_data_error_result(retmsg="You do not own the session")

    conv.message = deepcopy(req["messages"])
    e, dia = DialogService.get_by_id(conv.dialog_id)
    if not e:
        return get_data_error_result(retmsg="Dialog not found!")
    # What remains in req after these deletions is passed to chat(**req).
    del req["id"]
    del req["messages"]

    if not conv.reference:
        conv.reference = []
    # Placeholder assistant turn and reference slot; fillin_conv overwrites
    # both as answers arrive.
    conv.message.append({"role": "assistant", "content": "", "id": message_id})
    conv.reference.append({"chunks": [], "doc_aggs": []})

    def fillin_conv(ans):
        """Copy the latest answer and its reference into the session record."""
        nonlocal conv, message_id
        if not conv.reference:
            conv.reference.append(ans["reference"])
        else:
            conv.reference[-1] = ans["reference"]
        conv.message[-1] = {"role": "assistant", "content": ans["answer"],
                            "id": message_id, "prompt": ans.get("prompt", "")}
        ans["id"] = message_id

    def stream():
        """Yield one SSE event per answer, then a final ``data: true`` marker."""
        nonlocal dia, msg, req, conv
        try:
            for ans in chat(dia, msg, **req):
                fillin_conv(ans)
                yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
            ConversationService.update_by_id(conv.id, conv.to_dict())
        except Exception as e:
            # Report the failure to the client as an event rather than
            # breaking the stream mid-response.
            yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
                                        "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
                                       ensure_ascii=False) + "\n\n"
        yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"

    if req.get("stream", True):
        resp = Response(stream(), mimetype="text/event-stream")
        resp.headers.add_header("Cache-control", "no-cache")
        resp.headers.add_header("Connection", "keep-alive")
        resp.headers.add_header("X-Accel-Buffering", "no")
        resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
        return resp
    else:
        answer = None
        # Only the first item produced by chat() is used in this mode.
        for ans in chat(dia, msg, **req):
            answer = ans
            fillin_conv(ans)
            ConversationService.update_by_id(conv.id, conv.to_dict())
            break
        return get_json_result(data=answer)

View File

@ -1,4 +1,7 @@
from typing import List
from .base import Base
from .session import Session, Message
class Assistant(Base):
@ -54,3 +57,15 @@ class Assistant(Base):
res = res.json()
if res.get("retmsg") == "success": return True
raise Exception(res["retmsg"])
def create_session(self, name: str = "New session", messages: "List[Message]" = None) -> "Session":
    """Create a new session for this assistant via ``/session/save``.

    Args:
        name: display name of the session.
        messages: initial message list. When omitted (``None``) a single
            default assistant greeting is sent.

    Returns:
        A ``Session`` built from the ``data`` field of the server reply.

    Raises:
        Exception: carrying the server's ``retmsg`` when the reply's
            ``retmsg`` is not ``"success"``.
    """
    if messages is None:
        # Build the default per call. A list literal in the signature is
        # created once and shared by every call, so any mutation of it (or
        # of its nested dict/list) would leak into later sessions.
        messages = [{"role": "assistant", "reference": [],
                     "content": "您好我是您的助手小樱长得可爱又善良can I help you?"}]
    res = self.post("/session/save", {"name": name, "messages": messages, "assistant_id": self.id})
    res = res.json()
    if res.get("retmsg") == "success":
        return Session(self.rag, res['data'])
    raise Exception(res["retmsg"])
def get_prologue(self):
    """Return the opening message configured on this assistant's prompt."""
    prompt_settings = self.prompt
    return prompt_settings.opener

View File

@ -0,0 +1,64 @@
import json
from .base import Base
class Session(Base):
    """Client-side handle for one chat session of an assistant."""

    def __init__(self, rag, res_dict):
        # Defaults are assigned first and res_dict is then handed to
        # Base.__init__ (presumably overriding them -- confirm in Base).
        self.id = None
        self.name = "New session"
        self.messages = [{"role": "assistant", "content": "Hi! I am your assistantcan I help you?"}]
        self.assistant_id = None
        super().__init__(rag, res_dict)

    def chat(self, question: str, stream: bool = False):
        """Send ``question`` to ``/session/completion`` and collect the answers.

        Args:
            question: the user's question.
            stream: forwarded to the server. When true the reply is a
                sequence of ``data:``-prefixed event lines; otherwise it is
                a single JSON envelope.

        Returns:
            A list of ``Message`` objects, one per answer payload found in
            the reply. Events whose ``data`` is not a dict (such as the
            final ``true`` marker of a stream) are skipped.
        """
        res = self.post("/session/completion",
                        {"id": self.id, "question": question, "stream": stream, "messages": self.messages})
        body = res.text
        events = [json.loads(line[5:]) for line in body.splitlines() if line.startswith("data:")]
        if not events and body.strip():
            # Non-streaming replies are one plain JSON envelope with no
            # "data:" prefix. Without this fallback the default
            # stream=False call always returned an empty list.
            try:
                events = [json.loads(body)]
            except ValueError:
                events = []

        message_list = []
        for event in events:
            data = event.get("data") if isinstance(event, dict) else None
            if not isinstance(data, dict):
                continue
            message_list.append(Message(self.rag, {
                "content": data["answer"],
                "role": "assistant",
                "reference": data.get("reference", []),
            }))
        return message_list

    def save(self):
        """Push this session's name, messages and assistant id to the server.

        Returns:
            True when the server replies with ``retmsg == "success"``.

        Raises:
            Exception: carrying the server's ``retmsg`` otherwise.
        """
        res = self.post("/session/save",
                        {"id": self.id, "dialog_id": self.assistant_id, "name": self.name, "message": self.messages})
        res = res.json()
        if res.get("retmsg") == "success":
            return True
        raise Exception(res.get("retmsg"))
class Message(Base):
    """One chat message: its content, role, reference and prompt."""

    def __init__(self, rag, res_dict):
        # Built inside __init__ so every instance gets its own fresh list.
        initial_values = {
            "content": "您好我是您的助手小樱长得可爱又善良can I help you?",
            "reference": [],
            "role": "assistant",
            "prompt": None,
        }
        for attribute, value in initial_values.items():
            setattr(self, attribute, value)
        # Defaults are in place before res_dict is handed to Base.__init__.
        super().__init__(rag, res_dict)
class Chunk(Base):
    """A retrieved chunk; every field starts as ``None`` until Base fills it."""

    # Attribute names in the order they are initialised.
    _FIELD_NAMES = (
        "id",
        "content",
        "document_id",
        "document_name",
        "knowledgebase_id",
        "image_id",
        "similarity",
        "vector_similarity",
        "term_similarity",
        "positions",
    )

    def __init__(self, rag, res_dict):
        for field_name in self._FIELD_NAMES:
            setattr(self, field_name, None)
        # Defaults are in place before res_dict is handed to Base.__init__.
        super().__init__(rag, res_dict)

View File

@ -17,7 +17,6 @@ from typing import List
import requests
from .modules.chat_assistant import Assistant
from .modules.dataset import DataSet
@ -88,7 +87,7 @@ class RAGFlow:
datasets.append(dataset.to_json())
if llm is None:
llm = Assistant.LLM(self, {"model_name": "deepseek-chat",
llm = Assistant.LLM(self, {"model_name": None,
"temperature": 0.1,
"top_p": 0.3,
"presence_penalty": 0.4,

View File

@ -10,10 +10,10 @@ class TestAssistant(TestSdk):
Test creating an assistant with success
"""
rag = RAGFlow(API_KEY, HOST_ADDRESS)
kb = rag.get_dataset(name="God")
assistant = rag.create_assistant("God",knowledgebases=[kb])
kb = rag.create_dataset(name="test_create_assistant")
assistant = rag.create_assistant("test_create", knowledgebases=[kb])
if isinstance(assistant, Assistant):
assert assistant.name == "God", "Name does not match."
assert assistant.name == "test_create", "Name does not match."
else:
assert False, f"Failed to create assistant, error: {assistant}"
@ -22,11 +22,11 @@ class TestAssistant(TestSdk):
Test updating an assistant with success.
"""
rag = RAGFlow(API_KEY, HOST_ADDRESS)
kb = rag.get_dataset(name="God")
assistant = rag.create_assistant("ABC",knowledgebases=[kb])
kb = rag.create_dataset(name="test_update_assistant")
assistant = rag.create_assistant("test_update", knowledgebases=[kb])
if isinstance(assistant, Assistant):
assert assistant.name == "ABC", "Name does not match."
assistant.name = 'DEF'
assert assistant.name == "test_update", "Name does not match."
assistant.name = 'new_assistant'
res = assistant.save()
assert res is True, f"Failed to update assistant, error: {res}"
else:
@ -37,10 +37,10 @@ class TestAssistant(TestSdk):
Test deleting an assistant with success
"""
rag = RAGFlow(API_KEY, HOST_ADDRESS)
kb = rag.get_dataset(name="God")
assistant = rag.create_assistant("MA",knowledgebases=[kb])
kb = rag.create_dataset(name="test_delete_assistant")
assistant = rag.create_assistant("test_delete", knowledgebases=[kb])
if isinstance(assistant, Assistant):
assert assistant.name == "MA", "Name does not match."
assert assistant.name == "test_delete", "Name does not match."
res = assistant.delete()
assert res is True, f"Failed to delete assistant, error: {res}"
else:
@ -61,6 +61,8 @@ class TestAssistant(TestSdk):
Test getting an assistant's detail with success
"""
rag = RAGFlow(API_KEY, HOST_ADDRESS)
assistant = rag.get_assistant(name="God")
kb = rag.create_dataset(name="test_get_assistant")
rag.create_assistant("test_get_assistant", knowledgebases=[kb])
assistant = rag.get_assistant(name="test_get_assistant")
assert isinstance(assistant, Assistant), f"Failed to get assistant, error: {assistant}."
assert assistant.name == "God", "Name does not match"
assert assistant.name == "test_get_assistant", "Name does not match"

View File

@ -0,0 +1,27 @@
from ragflow import RAGFlow
from common import API_KEY, HOST_ADDRESS
class TestChatSession:
    """Integration tests for creating a session and chatting through the SDK."""

    @staticmethod
    def _open_session(label):
        """Create a dataset and an assistant both named ``label``, then a session.

        Returns the ``(assistant, session)`` pair.
        """
        client = RAGFlow(API_KEY, HOST_ADDRESS)
        dataset = client.create_dataset(name=label)
        helper = client.create_assistant(name=label, knowledgebases=[dataset])
        return helper, helper.create_session()

    def test_create_session(self):
        """A freshly created assistant can open a session."""
        assistant, session = self._open_session("test_create_session")
        assert assistant is not None, "Failed to get the assistant."
        assert session is not None, "Failed to create a session."

    def test_create_chat_with_success(self):
        """A streamed chat on a new session yields a non-empty final answer."""
        assistant, session = self._open_session("test_create_chat")
        assert session is not None, "Failed to create a session."

        opener = assistant.get_prologue()
        assert isinstance(opener, str), "Prologue is not a string."
        assert len(opener) > 0, "Prologue is empty."

        replies = session.chat("What is AI", stream=True)
        final_text = replies[-1].content
        assert len(final_text) > 0, "Assistant did not return any response."