SDK for session (#2354)

### What problem does this PR solve?

Includes SDK for creating, updating sessions, getting sessions, listing
sessions, and dialogues
#1102 
### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
This commit is contained in:
liuhua 2024-09-11 12:03:55 +08:00 committed by GitHub
parent 7fad48f42c
commit 1fc14ff6d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 298 additions and 131 deletions

View File

@ -14,13 +14,13 @@
# limitations under the License. # limitations under the License.
# #
import json import json
from copy import deepcopy
from uuid import uuid4 from uuid import uuid4
from flask import request, Response from flask import request, Response
from api.db import StatusEnum from api.db import StatusEnum
from api.db.services.dialog_service import DialogService, ConversationService, chat from api.db.services.dialog_service import DialogService, ConversationService, chat
from api.settings import RetCode
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result from api.utils.api_utils import get_data_error_result
from api.utils.api_utils import get_json_result, token_required from api.utils.api_utils import get_json_result, token_required
@ -31,12 +31,6 @@ from api.utils.api_utils import get_json_result, token_required
def set_conversation(tenant_id): def set_conversation(tenant_id):
req = request.json req = request.json
conv_id = req.get("id") conv_id = req.get("id")
if "messages" in req:
req["message"] = req.pop("messages")
if req["message"]:
for message in req["message"]:
if "reference" in message:
req["reference"] = message.pop("reference")
if "assistant_id" in req: if "assistant_id" in req:
req["dialog_id"] = req.pop("assistant_id") req["dialog_id"] = req.pop("assistant_id")
if "id" in req: if "id" in req:
@ -52,10 +46,12 @@ def set_conversation(tenant_id):
return get_data_error_result(retmsg="You do not own the assistant") return get_data_error_result(retmsg="You do not own the assistant")
if "dialog_id" in req and not req.get("dialog_id"): if "dialog_id" in req and not req.get("dialog_id"):
return get_data_error_result(retmsg="assistant_id can not be empty.") return get_data_error_result(retmsg="assistant_id can not be empty.")
if "message" in req:
return get_data_error_result(retmsg="message can not be change")
if "reference" in req:
return get_data_error_result(retmsg="reference can not be change")
if "name" in req and not req.get("name"): if "name" in req and not req.get("name"):
return get_data_error_result(retmsg="name can not be empty.") return get_data_error_result(retmsg="name can not be empty.")
if "message" in req and not req.get("message"):
return get_data_error_result(retmsg="messages can not be empty")
if not ConversationService.update_by_id(conv_id, req): if not ConversationService.update_by_id(conv_id, req):
return get_data_error_result(retmsg="Session updates error") return get_data_error_result(retmsg="Session updates error")
return get_json_result(data=True) return get_json_result(data=True)
@ -69,22 +65,17 @@ def set_conversation(tenant_id):
"id": get_uuid(), "id": get_uuid(),
"dialog_id": req["dialog_id"], "dialog_id": req["dialog_id"],
"name": req.get("name", "New session"), "name": req.get("name", "New session"),
"message": req.get("message", [{"role": "assistant", "content": dia[0].prompt_config["prologue"]}]), "message": [{"role": "assistant", "content": "Hi! I am your assistantcan I help you?"}]
"reference": req.get("reference", [])
} }
if not conv.get("name"): if not conv.get("name"):
return get_data_error_result(retmsg="name can not be empty.") return get_data_error_result(retmsg="name can not be empty.")
if not conv.get("message"):
return get_data_error_result(retmsg="messages can not be empty")
ConversationService.save(**conv) ConversationService.save(**conv)
e, conv = ConversationService.get_by_id(conv["id"]) e, conv = ConversationService.get_by_id(conv["id"])
if not e: if not e:
return get_data_error_result(retmsg="Fail to new session!") return get_data_error_result(retmsg="Fail to new session!")
conv = conv.to_dict() conv = conv.to_dict()
conv["messages"] = conv.pop("message") conv['messages'] = conv.pop("message")
conv["assistant_id"] = conv.pop("dialog_id") conv["assistant_id"] = conv.pop("dialog_id")
for message in conv["messages"]:
message["reference"] = conv.get("reference")
del conv["reference"] del conv["reference"]
return get_json_result(data=conv) return get_json_result(data=conv)
@ -96,31 +87,28 @@ def completion(tenant_id):
# req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [ # req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
# {"role": "user", "content": "上海有吗?"} # {"role": "user", "content": "上海有吗?"}
# ]} # ]}
if "id" not in req:
return get_data_error_result(retmsg="id is required")
conv = ConversationService.query(id=req["id"])
if not conv:
return get_data_error_result(retmsg="Session does not exist")
conv = conv[0]
if not DialogService.query(id=conv.dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_data_error_result(retmsg="You do not own the session")
msg = [] msg = []
question = { question = {
"content": req.get("question"), "content": req.get("question"),
"role": "user", "role": "user",
"id": str(uuid4()) "id": str(uuid4())
} }
req["messages"].append(question) conv.message.append(question)
for m in req["messages"]: for m in conv.message:
if m["role"] == "system": continue if m["role"] == "system": continue
if m["role"] == "assistant" and not msg: continue if m["role"] == "assistant" and not msg: continue
m["id"] = m.get("id", str(uuid4()))
msg.append(m) msg.append(m)
message_id = msg[-1].get("id") message_id = msg[-1].get("id")
conv = ConversationService.query(id=req["id"])
conv = conv[0]
if not conv:
return get_data_error_result(retmsg="Session does not exist")
if not DialogService.query(id=conv.dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_data_error_result(retmsg="You do not own the session")
conv.message = deepcopy(req["messages"])
e, dia = DialogService.get_by_id(conv.dialog_id) e, dia = DialogService.get_by_id(conv.dialog_id)
if not e:
return get_data_error_result(retmsg="Dialog not found!")
del req["id"] del req["id"]
del req["messages"]
if not conv.reference: if not conv.reference:
conv.reference = [] conv.reference = []
@ -166,3 +154,110 @@ def completion(tenant_id):
ConversationService.update_by_id(conv.id, conv.to_dict()) ConversationService.update_by_id(conv.id, conv.to_dict())
break break
return get_json_result(data=answer) return get_json_result(data=answer)
@manager.route('/get', methods=['GET'])
@token_required
def get(tenant_id):
req = request.args
if "id" not in req:
return get_data_error_result(retmsg="id is required")
conv_id = req["id"]
conv = ConversationService.query(id=conv_id)
if not conv:
return get_data_error_result(retmsg="Session does not exist")
if not DialogService.query(id=conv[0].dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_data_error_result(retmsg="You do not own the session")
conv = conv[0].to_dict()
conv['messages'] = conv.pop("message")
conv["assistant_id"] = conv.pop("dialog_id")
if conv["reference"]:
messages = conv["messages"]
message_num = 0
chunk_num = 0
while message_num < len(messages):
if message_num != 0 and messages[message_num]["role"] != "user":
chunk_list = []
if "chunks" in conv["reference"][chunk_num]:
chunks = conv["reference"][chunk_num]["chunks"]
for chunk in chunks:
new_chunk = {
"id": chunk["chunk_id"],
"content": chunk["content_with_weight"],
"document_id": chunk["doc_id"],
"document_name": chunk["docnm_kwd"],
"knowledgebase_id": chunk["kb_id"],
"image_id": chunk["img_id"],
"similarity": chunk["similarity"],
"vector_similarity": chunk["vector_similarity"],
"term_similarity": chunk["term_similarity"],
"positions": chunk["positions"],
}
chunk_list.append(new_chunk)
chunk_num += 1
messages[message_num]["reference"] = chunk_list
message_num += 1
del conv["reference"]
return get_json_result(data=conv)
@manager.route('/list', methods=["GET"])
@token_required
def list(tenant_id):
assistant_id = request.args["assistant_id"]
if not DialogService.query(tenant_id=tenant_id, id=assistant_id, status=StatusEnum.VALID.value):
return get_json_result(
data=False, retmsg=f'Only owner of the assistant is authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
convs = ConversationService.query(
dialog_id=assistant_id,
order_by=ConversationService.model.create_time,
reverse=True)
convs = [d.to_dict() for d in convs]
for conv in convs:
conv['messages'] = conv.pop("message")
conv["assistant_id"] = conv.pop("dialog_id")
if conv["reference"]:
messages = conv["messages"]
message_num = 0
chunk_num = 0
while message_num < len(messages):
if message_num != 0 and messages[message_num]["role"] != "user":
chunk_list = []
if "chunks" in conv["reference"][chunk_num]:
chunks = conv["reference"][chunk_num]["chunks"]
for chunk in chunks:
new_chunk = {
"id": chunk["chunk_id"],
"content": chunk["content_with_weight"],
"document_id": chunk["doc_id"],
"document_name": chunk["docnm_kwd"],
"knowledgebase_id": chunk["kb_id"],
"image_id": chunk["img_id"],
"similarity": chunk["similarity"],
"vector_similarity": chunk["vector_similarity"],
"term_similarity": chunk["term_similarity"],
"positions": chunk["positions"],
}
chunk_list.append(new_chunk)
chunk_num += 1
messages[message_num]["reference"] = chunk_list
message_num += 1
del conv["reference"]
return get_json_result(data=convs)
@manager.route('/delete', methods=["DELETE"])
@token_required
def delete(tenant_id):
id = request.args.get("id")
if not id:
return get_data_error_result(retmsg="`id` is required in deleting operation")
conv = ConversationService.query(id=id)
if not conv:
return get_data_error_result(retmsg="Session doesn't exist")
conv = conv[0]
if not DialogService.query(id=conv.dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_data_error_result(retmsg="You don't own the session")
ConversationService.delete_by_id(id)
return get_json_result(data=True)

View File

@ -4,4 +4,5 @@ __version__ = importlib.metadata.version("ragflow")
from .ragflow import RAGFlow from .ragflow import RAGFlow
from .modules.dataset import DataSet from .modules.dataset import DataSet
from .modules.chat_assistant import Assistant from .modules.assistant import Assistant
from .modules.session import Session

View File

@ -1,71 +1,86 @@
from typing import List from typing import List
from .base import Base from .base import Base
from .session import Session, Message from .session import Session
class Assistant(Base): class Assistant(Base):
def __init__(self, rag, res_dict): def __init__(self, rag, res_dict):
self.id = "" self.id = ""
self.name = "assistant" self.name = "assistant"
self.avatar = "path/to/avatar" self.avatar = "path/to/avatar"
self.knowledgebases = ["kb1"] self.knowledgebases = ["kb1"]
self.llm = Assistant.LLM(rag, {}) self.llm = Assistant.LLM(rag, {})
self.prompt = Assistant.Prompt(rag, {}) self.prompt = Assistant.Prompt(rag, {})
super().__init__(rag, res_dict) super().__init__(rag, res_dict)
class LLM(Base): class LLM(Base):
def __init__(self, rag, res_dict): def __init__(self, rag, res_dict):
self.model_name = "deepseek-chat" self.model_name = "deepseek-chat"
self.temperature = 0.1 self.temperature = 0.1
self.top_p = 0.3 self.top_p = 0.3
self.presence_penalty = 0.4 self.presence_penalty = 0.4
self.frequency_penalty = 0.7 self.frequency_penalty = 0.7
self.max_tokens = 512 self.max_tokens = 512
super().__init__(rag, res_dict) super().__init__(rag, res_dict)
class Prompt(Base): class Prompt(Base):
def __init__(self, rag, res_dict): def __init__(self, rag, res_dict):
self.similarity_threshold = 0.2 self.similarity_threshold = 0.2
self.keywords_similarity_weight = 0.7 self.keywords_similarity_weight = 0.7
self.top_n = 8 self.top_n = 8
self.variables = [{"key": "knowledge", "optional": True}] self.variables = [{"key": "knowledge", "optional": True}]
self.rerank_model = None self.rerank_model = None
self.empty_response = None self.empty_response = None
self.opener = "Hi! I'm your assistant, what can I do for you?" self.opener = "Hi! I'm your assistant, what can I do for you?"
self.show_quote = True self.show_quote = True
self.prompt = ( self.prompt = (
"You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. " "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. "
"Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, " "Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, "
"your answer must include the sentence 'The answer you are looking for is not found in the knowledge base!' " "your answer must include the sentence 'The answer you are looking for is not found in the knowledge base!' "
"Answers need to consider chat history.\nHere is the knowledge base:\n{knowledge}\nThe above is the knowledge base." "Answers need to consider chat history.\nHere is the knowledge base:\n{knowledge}\nThe above is the knowledge base."
) )
super().__init__(rag, res_dict) super().__init__(rag, res_dict)
def save(self) -> bool: def save(self) -> bool:
res = self.post('/assistant/save', res = self.post('/assistant/save',
{"id": self.id, "name": self.name, "avatar": self.avatar, "knowledgebases": self.knowledgebases, {"id": self.id, "name": self.name, "avatar": self.avatar, "knowledgebases": self.knowledgebases,
"llm": self.llm.to_json(), "prompt": self.prompt.to_json() "llm": self.llm.to_json(), "prompt": self.prompt.to_json()
}) })
res = res.json() res = res.json()
if res.get("retmsg") == "success": return True if res.get("retmsg") == "success": return True
raise Exception(res["retmsg"]) raise Exception(res["retmsg"])
def delete(self) -> bool: def delete(self) -> bool:
res = self.rm('/assistant/delete', res = self.rm('/assistant/delete',
{"id": self.id}) {"id": self.id})
res = res.json() res = res.json()
if res.get("retmsg") == "success": return True if res.get("retmsg") == "success": return True
raise Exception(res["retmsg"]) raise Exception(res["retmsg"])
def create_session(self, name: str = "New session", messages: List[Message] = [ def create_session(self, name: str = "New session") -> Session:
{"role": "assistant", "reference": [], res = self.post("/session/save", {"name": name, "assistant_id": self.id})
"content": "您好我是您的助手小樱长得可爱又善良can I help you?"}]) -> Session: res = res.json()
res = self.post("/session/save", {"name": name, "messages": messages, "assistant_id": self.id, }) if res.get("retmsg") == "success":
res = res.json() return Session(self.rag, res['data'])
if res.get("retmsg") == "success": raise Exception(res["retmsg"])
return Session(self.rag, res['data'])
raise Exception(res["retmsg"]) def list_session(self) -> List[Session]:
res = self.get('/session/list', {"assistant_id": self.id})
def get_prologue(self): res = res.json()
return self.prompt.opener if res.get("retmsg") == "success":
result_list = []
for data in res["data"]:
result_list.append(Session(self.rag, data))
return result_list
raise Exception(res["retmsg"])
def get_session(self, id) -> Session:
res = self.get("/session/get", {"id": id})
res = res.json()
if res.get("retmsg") == "success":
return Session(self.rag, res["data"])
raise Exception(res["retmsg"])
def get_prologue(self):
return self.prompt.opener

View File

@ -18,8 +18,8 @@ class Base(object):
pr[name] = value pr[name] = value
return pr return pr
def post(self, path, param): def post(self, path, param, stream=False):
res = self.rag.post(path, param) res = self.rag.post(path, param, stream=stream)
return res return res
def get(self, path, params): def get(self, path, params):

View File

@ -8,17 +8,17 @@ class Session(Base):
self.id = None self.id = None
self.name = "New session" self.name = "New session"
self.messages = [{"role": "assistant", "content": "Hi! I am your assistantcan I help you?"}] self.messages = [{"role": "assistant", "content": "Hi! I am your assistantcan I help you?"}]
self.assistant_id = None self.assistant_id = None
super().__init__(rag, res_dict) super().__init__(rag, res_dict)
def chat(self, question: str, stream: bool = False): def chat(self, question: str, stream: bool = False):
for message in self.messages:
if "reference" in message:
message.pop("reference")
res = self.post("/session/completion", res = self.post("/session/completion",
{"id": self.id, "question": question, "stream": stream, "messages": self.messages}) {"id": self.id, "question": question, "stream": stream}, stream=True)
res = res.text for line in res.iter_lines():
response_lines = res.splitlines() line = line.decode("utf-8")
message_list = []
for line in response_lines:
if line.startswith("data:"): if line.startswith("data:"):
json_data = json.loads(line[5:]) json_data = json.loads(line[5:])
if json_data["data"] != True: if json_data["data"] != True:
@ -26,26 +26,49 @@ class Session(Base):
reference = json_data["data"]["reference"] reference = json_data["data"]["reference"]
temp_dict = { temp_dict = {
"content": answer, "content": answer,
"role": "assistant", "role": "assistant"
"reference": reference
} }
if "chunks" in reference:
chunks = reference["chunks"]
chunk_list = []
for chunk in chunks:
new_chunk = {
"id": chunk["chunk_id"],
"content": chunk["content_with_weight"],
"document_id": chunk["doc_id"],
"document_name": chunk["docnm_kwd"],
"knowledgebase_id": chunk["kb_id"],
"image_id": chunk["img_id"],
"similarity": chunk["similarity"],
"vector_similarity": chunk["vector_similarity"],
"term_similarity": chunk["term_similarity"],
"positions": chunk["positions"],
}
chunk_list.append(new_chunk)
temp_dict["reference"] = chunk_list
message = Message(self.rag, temp_dict) message = Message(self.rag, temp_dict)
message_list.append(message) yield message
return message_list
def save(self): def save(self):
res = self.post("/session/save", res = self.post("/session/save",
{"id": self.id, "dialog_id": self.assistant_id, "name": self.name, "message": self.messages}) {"id": self.id, "assistant_id": self.assistant_id, "name": self.name})
res = res.json() res = res.json()
if res.get("retmsg") == "success": return True if res.get("retmsg") == "success": return True
raise Exception(res.get("retmsg")) raise Exception(res.get("retmsg"))
def delete(self):
res = self.rm("/session/delete", {"id": self.id})
res = res.json()
if res.get("retmsg") == "success": return True
raise Exception(res.get("retmsg"))
class Message(Base): class Message(Base):
def __init__(self, rag, res_dict): def __init__(self, rag, res_dict):
self.content = "您好我是您的助手小樱长得可爱又善良can I help you?" self.content = "Hi! I am your assistantcan I help you?"
self.reference = [] self.reference = None
self.role = "assistant" self.role = "assistant"
self.prompt=None self.prompt = None
super().__init__(rag, res_dict) super().__init__(rag, res_dict)

View File

@ -17,7 +17,7 @@ from typing import List
import requests import requests
from .modules.chat_assistant import Assistant from .modules.assistant import Assistant
from .modules.dataset import DataSet from .modules.dataset import DataSet
@ -30,8 +30,8 @@ class RAGFlow:
self.api_url = f"{base_url}/api/{version}" self.api_url = f"{base_url}/api/{version}"
self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)} self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)}
def post(self, path, param): def post(self, path, param, stream=False):
res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header) res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream)
return res return res
def get(self, path, params=None): def get(self, path, params=None):

View File

@ -1,27 +1,60 @@
from ragflow import RAGFlow from ragflow import RAGFlow,Session
from common import API_KEY, HOST_ADDRESS from common import API_KEY, HOST_ADDRESS
class TestChatSession: class TestSession:
def test_create_session(self): def test_create_session(self):
rag = RAGFlow(API_KEY, HOST_ADDRESS) rag = RAGFlow(API_KEY, HOST_ADDRESS)
kb = rag.create_dataset(name="test_create_session") kb = rag.create_dataset(name="test_create_session")
assistant = rag.create_assistant(name="test_create_session", knowledgebases=[kb]) assistant = rag.create_assistant(name="test_create_session", knowledgebases=[kb])
session = assistant.create_session() session = assistant.create_session()
assert assistant is not None, "Failed to get the assistant." assert isinstance(session,Session), "Failed to create a session."
assert session is not None, "Failed to create a session."
def test_create_chat_with_success(self): def test_create_chat_with_success(self):
rag = RAGFlow(API_KEY, HOST_ADDRESS) rag = RAGFlow(API_KEY, HOST_ADDRESS)
kb = rag.create_dataset(name="test_create_chat") kb = rag.create_dataset(name="test_create_chat")
assistant = rag.create_assistant(name="test_create_chat", knowledgebases=[kb]) assistant = rag.create_assistant(name="test_create_chat", knowledgebases=[kb])
session = assistant.create_session() session = assistant.create_session()
assert session is not None, "Failed to create a session."
prologue = assistant.get_prologue()
assert isinstance(prologue, str), "Prologue is not a string."
assert len(prologue) > 0, "Prologue is empty."
question = "What is AI" question = "What is AI"
ans = session.chat(question, stream=True) for ans in session.chat(question, stream=True):
response = ans[-1].content pass
assert len(response) > 0, "Assistant did not return any response." assert ans.content!="\n**ERROR**", "Please check this error."
def test_delete_session_with_success(self):
rag = RAGFlow(API_KEY, HOST_ADDRESS)
kb = rag.create_dataset(name="test_delete_session")
assistant = rag.create_assistant(name="test_delete_session",knowledgebases=[kb])
session=assistant.create_session()
res=session.delete()
assert res, "Failed to delete the dataset."
def test_update_session_with_success(self):
rag=RAGFlow(API_KEY,HOST_ADDRESS)
kb=rag.create_dataset(name="test_update_session")
assistant = rag.create_assistant(name="test_update_session",knowledgebases=[kb])
session=assistant.create_session(name="old session")
session.name="new session"
res=session.save()
assert res,"Failed to update the session"
def test_get_session_with_success(self):
rag=RAGFlow(API_KEY,HOST_ADDRESS)
kb=rag.create_dataset(name="test_get_session")
assistant = rag.create_assistant(name="test_get_session",knowledgebases=[kb])
session = assistant.create_session()
session_2= assistant.get_session(id=session.id)
assert session.to_json()==session_2.to_json(),"Failed to get the session"
def test_list_session_with_success(self):
rag=RAGFlow(API_KEY,HOST_ADDRESS)
kb=rag.create_dataset(name="test_list_session")
assistant=rag.create_assistant(name="test_list_session",knowledgebases=[kb])
assistant.create_session("test_1")
assistant.create_session("test_2")
sessions=assistant.list_session()
if isinstance(sessions,list):
for session in sessions:
assert isinstance(session,Session),"Non-Session elements exist in the list"
else :
assert False,"Failed to retrieve the session list."