mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-12 06:09:01 +08:00
SDK for session (#2312)
### What problem does this PR solve? SDK for session #1102 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Feiue <10215101452@stu.ecun.edu.cn> Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
parent
ceae4df889
commit
336a639164
@ -16,9 +16,10 @@
|
|||||||
from flask import request
|
from flask import request
|
||||||
|
|
||||||
from api.db import StatusEnum
|
from api.db import StatusEnum
|
||||||
|
from api.db.db_models import TenantLLM
|
||||||
from api.db.services.dialog_service import DialogService
|
from api.db.services.dialog_service import DialogService
|
||||||
from api.db.services.document_service import DocumentService
|
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
|
from api.db.services.llm_service import LLMService, TenantLLMService
|
||||||
from api.db.services.user_service import TenantService
|
from api.db.services.user_service import TenantService
|
||||||
from api.settings import RetCode
|
from api.settings import RetCode
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
@ -30,7 +31,6 @@ from api.utils.api_utils import get_json_result
|
|||||||
@token_required
|
@token_required
|
||||||
def save(tenant_id):
|
def save(tenant_id):
|
||||||
req = request.json
|
req = request.json
|
||||||
id = req.get("id")
|
|
||||||
# dataset
|
# dataset
|
||||||
if req.get("knowledgebases") == []:
|
if req.get("knowledgebases") == []:
|
||||||
return get_data_error_result(retmsg="knowledgebases can not be empty list")
|
return get_data_error_result(retmsg="knowledgebases can not be empty list")
|
||||||
@ -41,8 +41,8 @@ def save(tenant_id):
|
|||||||
return get_data_error_result(retmsg="knowledgebase needs id")
|
return get_data_error_result(retmsg="knowledgebase needs id")
|
||||||
if not KnowledgebaseService.query(id=kb["id"], tenant_id=tenant_id):
|
if not KnowledgebaseService.query(id=kb["id"], tenant_id=tenant_id):
|
||||||
return get_data_error_result(retmsg="you do not own the knowledgebase")
|
return get_data_error_result(retmsg="you do not own the knowledgebase")
|
||||||
if not DocumentService.query(kb_id=kb["id"]):
|
# if not DocumentService.query(kb_id=kb["id"]):
|
||||||
return get_data_error_result(retmsg="There is a invalid knowledgebase")
|
# return get_data_error_result(retmsg="There is a invalid knowledgebase")
|
||||||
kb_list.append(kb["id"])
|
kb_list.append(kb["id"])
|
||||||
req["kb_ids"] = kb_list
|
req["kb_ids"] = kb_list
|
||||||
# llm
|
# llm
|
||||||
@ -72,10 +72,10 @@ def save(tenant_id):
|
|||||||
req[key] = prompt.pop(key)
|
req[key] = prompt.pop(key)
|
||||||
req["prompt_config"] = req.pop("prompt")
|
req["prompt_config"] = req.pop("prompt")
|
||||||
# create
|
# create
|
||||||
if not id:
|
if "id" not in req:
|
||||||
# dataset
|
# dataset
|
||||||
if not kb_list:
|
if not kb_list:
|
||||||
return get_data_error_result(retmsg="knowledgebase is required!")
|
return get_data_error_result(retmsg="knowledgebases are required!")
|
||||||
# init
|
# init
|
||||||
req["id"] = get_uuid()
|
req["id"] = get_uuid()
|
||||||
req["description"] = req.get("description", "A helpful Assistant")
|
req["description"] = req.get("description", "A helpful Assistant")
|
||||||
@ -83,7 +83,11 @@ def save(tenant_id):
|
|||||||
req["top_n"] = req.get("top_n", 6)
|
req["top_n"] = req.get("top_n", 6)
|
||||||
req["top_k"] = req.get("top_k", 1024)
|
req["top_k"] = req.get("top_k", 1024)
|
||||||
req["rerank_id"] = req.get("rerank_id", "")
|
req["rerank_id"] = req.get("rerank_id", "")
|
||||||
req["llm_id"] = req.get("llm_id", tenant.llm_id)
|
if req.get("llm_id"):
|
||||||
|
if not TenantLLMService.query(llm_name=req["llm_id"]):
|
||||||
|
return get_data_error_result(retmsg="the model_name does not exist.")
|
||||||
|
else:
|
||||||
|
req["llm_id"] = tenant.llm_id
|
||||||
if not req.get("name"):
|
if not req.get("name"):
|
||||||
return get_data_error_result(retmsg="name is required.")
|
return get_data_error_result(retmsg="name is required.")
|
||||||
if DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
if DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
||||||
@ -149,14 +153,20 @@ def save(tenant_id):
|
|||||||
if not DialogService.query(tenant_id=tenant_id, id=req["id"], status=StatusEnum.VALID.value):
|
if not DialogService.query(tenant_id=tenant_id, id=req["id"], status=StatusEnum.VALID.value):
|
||||||
return get_json_result(data=False, retmsg='You do not own the assistant', retcode=RetCode.OPERATING_ERROR)
|
return get_json_result(data=False, retmsg='You do not own the assistant', retcode=RetCode.OPERATING_ERROR)
|
||||||
# prompt
|
# prompt
|
||||||
|
if not req["id"]:
|
||||||
|
return get_data_error_result(retmsg="id can not be empty")
|
||||||
e, res = DialogService.get_by_id(req["id"])
|
e, res = DialogService.get_by_id(req["id"])
|
||||||
res = res.to_json()
|
res = res.to_json()
|
||||||
|
if "llm_id" in req:
|
||||||
|
if not TenantLLMService.query(llm_name=req["llm_id"]):
|
||||||
|
return get_data_error_result(retmsg="the model_name does not exist.")
|
||||||
if "name" in req:
|
if "name" in req:
|
||||||
if not req.get("name"):
|
if not req.get("name"):
|
||||||
return get_data_error_result(retmsg="name is not empty.")
|
return get_data_error_result(retmsg="name is not empty.")
|
||||||
if req["name"].lower() != res["name"].lower() \
|
if req["name"].lower() != res["name"].lower() \
|
||||||
and len(DialogService.query(name=req["name"], tenant_id=tenant_id,status=StatusEnum.VALID.value)) > 0:
|
and len(
|
||||||
return get_data_error_result(retmsg="Duplicated knowledgebase name in updating dataset.")
|
DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0:
|
||||||
|
return get_data_error_result(retmsg="Duplicated assistant name in updating dataset.")
|
||||||
if "prompt_config" in req:
|
if "prompt_config" in req:
|
||||||
res["prompt_config"].update(req["prompt_config"])
|
res["prompt_config"].update(req["prompt_config"])
|
||||||
for p in res["prompt_config"]["parameters"]:
|
for p in res["prompt_config"]["parameters"]:
|
||||||
@ -186,7 +196,7 @@ def delete(tenant_id):
|
|||||||
if "id" not in req:
|
if "id" not in req:
|
||||||
return get_data_error_result(retmsg="id is required")
|
return get_data_error_result(retmsg="id is required")
|
||||||
id = req['id']
|
id = req['id']
|
||||||
if not DialogService.query(tenant_id=tenant_id, id=id,status=StatusEnum.VALID.value):
|
if not DialogService.query(tenant_id=tenant_id, id=id, status=StatusEnum.VALID.value):
|
||||||
return get_json_result(data=False, retmsg='you do not own the assistant.', retcode=RetCode.OPERATING_ERROR)
|
return get_json_result(data=False, retmsg='you do not own the assistant.', retcode=RetCode.OPERATING_ERROR)
|
||||||
|
|
||||||
temp_dict = {"status": StatusEnum.INVALID.value}
|
temp_dict = {"status": StatusEnum.INVALID.value}
|
||||||
@ -200,21 +210,22 @@ def get(tenant_id):
|
|||||||
req = request.args
|
req = request.args
|
||||||
if "id" in req:
|
if "id" in req:
|
||||||
id = req["id"]
|
id = req["id"]
|
||||||
ass = DialogService.query(tenant_id=tenant_id, id=id,status=StatusEnum.VALID.value)
|
ass = DialogService.query(tenant_id=tenant_id, id=id, status=StatusEnum.VALID.value)
|
||||||
if not ass:
|
if not ass:
|
||||||
return get_json_result(data=False, retmsg='You do not own the assistant.', retcode=RetCode.OPERATING_ERROR)
|
return get_json_result(data=False, retmsg='You do not own the assistant.', retcode=RetCode.OPERATING_ERROR)
|
||||||
if "name" in req:
|
if "name" in req:
|
||||||
name = req["name"]
|
name = req["name"]
|
||||||
if ass[0].name != name:
|
if ass[0].name != name:
|
||||||
return get_json_result(data=False, retmsg='name does not match id.', retcode=RetCode.OPERATING_ERROR)
|
return get_json_result(data=False, retmsg='name does not match id.', retcode=RetCode.OPERATING_ERROR)
|
||||||
res=ass[0].to_json()
|
res = ass[0].to_json()
|
||||||
else:
|
else:
|
||||||
if "name" in req:
|
if "name" in req:
|
||||||
name = req["name"]
|
name = req["name"]
|
||||||
ass = DialogService.query(name=name, tenant_id=tenant_id,status=StatusEnum.VALID.value)
|
ass = DialogService.query(name=name, tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
||||||
if not ass:
|
if not ass:
|
||||||
return get_json_result(data=False, retmsg='You do not own the dataset.',retcode=RetCode.OPERATING_ERROR)
|
return get_json_result(data=False, retmsg='You do not own the assistant.',
|
||||||
res=ass[0].to_json()
|
retcode=RetCode.OPERATING_ERROR)
|
||||||
|
res = ass[0].to_json()
|
||||||
else:
|
else:
|
||||||
return get_data_error_result(retmsg="At least one of `id` or `name` must be provided.")
|
return get_data_error_result(retmsg="At least one of `id` or `name` must be provided.")
|
||||||
renamed_dict = {}
|
renamed_dict = {}
|
||||||
@ -258,7 +269,7 @@ def list_assistants(tenant_id):
|
|||||||
reverse=True,
|
reverse=True,
|
||||||
order_by=DialogService.model.create_time)
|
order_by=DialogService.model.create_time)
|
||||||
assts = [d.to_dict() for d in assts]
|
assts = [d.to_dict() for d in assts]
|
||||||
list_assts=[]
|
list_assts = []
|
||||||
renamed_dict = {}
|
renamed_dict = {}
|
||||||
key_mapping = {"parameters": "variables",
|
key_mapping = {"parameters": "variables",
|
||||||
"prologue": "opener",
|
"prologue": "opener",
|
||||||
|
@ -60,7 +60,7 @@ def save(tenant_id):
|
|||||||
req.update(mapped_keys)
|
req.update(mapped_keys)
|
||||||
if not KnowledgebaseService.save(**req):
|
if not KnowledgebaseService.save(**req):
|
||||||
return get_data_error_result(retmsg="Create dataset error.(Database error)")
|
return get_data_error_result(retmsg="Create dataset error.(Database error)")
|
||||||
renamed_data={}
|
renamed_data = {}
|
||||||
e, k = KnowledgebaseService.get_by_id(req["id"])
|
e, k = KnowledgebaseService.get_by_id(req["id"])
|
||||||
for key, value in k.to_dict().items():
|
for key, value in k.to_dict().items():
|
||||||
new_key = key_mapping.get(key, key)
|
new_key = key_mapping.get(key, key)
|
||||||
@ -88,6 +88,9 @@ def save(tenant_id):
|
|||||||
data=False, retmsg='You do not own the dataset.',
|
data=False, retmsg='You do not own the dataset.',
|
||||||
retcode=RetCode.OPERATING_ERROR)
|
retcode=RetCode.OPERATING_ERROR)
|
||||||
|
|
||||||
|
if not req["id"]:
|
||||||
|
return get_data_error_result(
|
||||||
|
retmsg="id can not be empty.")
|
||||||
e, kb = KnowledgebaseService.get_by_id(req["id"])
|
e, kb = KnowledgebaseService.get_by_id(req["id"])
|
||||||
|
|
||||||
if "chunk_count" in req:
|
if "chunk_count" in req:
|
||||||
@ -108,6 +111,7 @@ def save(tenant_id):
|
|||||||
retmsg="If chunk count is not 0, parse method is not changable.")
|
retmsg="If chunk count is not 0, parse method is not changable.")
|
||||||
req['parser_id'] = req.pop('parse_method')
|
req['parser_id'] = req.pop('parse_method')
|
||||||
if "name" in req:
|
if "name" in req:
|
||||||
|
req["name"] = req["name"].strip()
|
||||||
if req["name"].lower() != kb.name.lower() \
|
if req["name"].lower() != kb.name.lower() \
|
||||||
and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id,
|
and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id,
|
||||||
status=StatusEnum.VALID.value)) > 0:
|
status=StatusEnum.VALID.value)) > 0:
|
||||||
|
168
api/apps/sdk/session.py
Normal file
168
api/apps/sdk/session.py
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
import json
|
||||||
|
from copy import deepcopy
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from flask import request, Response
|
||||||
|
|
||||||
|
from api.db import StatusEnum
|
||||||
|
from api.db.services.dialog_service import DialogService, ConversationService, chat
|
||||||
|
from api.utils import get_uuid
|
||||||
|
from api.utils.api_utils import get_data_error_result
|
||||||
|
from api.utils.api_utils import get_json_result, token_required
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/save', methods=['POST'])
|
||||||
|
@token_required
|
||||||
|
def set_conversation(tenant_id):
|
||||||
|
req = request.json
|
||||||
|
conv_id = req.get("id")
|
||||||
|
if "messages" in req:
|
||||||
|
req["message"] = req.pop("messages")
|
||||||
|
if req["message"]:
|
||||||
|
for message in req["message"]:
|
||||||
|
if "reference" in message:
|
||||||
|
req["reference"] = message.pop("reference")
|
||||||
|
if "assistant_id" in req:
|
||||||
|
req["dialog_id"] = req.pop("assistant_id")
|
||||||
|
if "id" in req:
|
||||||
|
del req["id"]
|
||||||
|
conv = ConversationService.query(id=conv_id)
|
||||||
|
if not conv:
|
||||||
|
return get_data_error_result(retmsg="Session does not exist")
|
||||||
|
if not DialogService.query(id=conv[0].dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
||||||
|
return get_data_error_result(retmsg="You do not own the session")
|
||||||
|
if req.get("dialog_id"):
|
||||||
|
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
|
||||||
|
if not dia:
|
||||||
|
return get_data_error_result(retmsg="You do not own the assistant")
|
||||||
|
if "dialog_id" in req and not req.get("dialog_id"):
|
||||||
|
return get_data_error_result(retmsg="assistant_id can not be empty.")
|
||||||
|
if "name" in req and not req.get("name"):
|
||||||
|
return get_data_error_result(retmsg="name can not be empty.")
|
||||||
|
if "message" in req and not req.get("message"):
|
||||||
|
return get_data_error_result(retmsg="messages can not be empty")
|
||||||
|
if not ConversationService.update_by_id(conv_id, req):
|
||||||
|
return get_data_error_result(retmsg="Session updates error")
|
||||||
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
if not req.get("dialog_id"):
|
||||||
|
return get_data_error_result(retmsg="assistant_id is required.")
|
||||||
|
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
|
||||||
|
if not dia:
|
||||||
|
return get_data_error_result(retmsg="You do not own the assistant")
|
||||||
|
conv = {
|
||||||
|
"id": get_uuid(),
|
||||||
|
"dialog_id": req["dialog_id"],
|
||||||
|
"name": req.get("name", "New session"),
|
||||||
|
"message": req.get("message", [{"role": "assistant", "content": dia[0].prompt_config["prologue"]}]),
|
||||||
|
"reference": req.get("reference", [])
|
||||||
|
}
|
||||||
|
if not conv.get("name"):
|
||||||
|
return get_data_error_result(retmsg="name can not be empty.")
|
||||||
|
if not conv.get("message"):
|
||||||
|
return get_data_error_result(retmsg="messages can not be empty")
|
||||||
|
ConversationService.save(**conv)
|
||||||
|
e, conv = ConversationService.get_by_id(conv["id"])
|
||||||
|
if not e:
|
||||||
|
return get_data_error_result(retmsg="Fail to new session!")
|
||||||
|
conv = conv.to_dict()
|
||||||
|
conv["messages"] = conv.pop("message")
|
||||||
|
conv["assistant_id"] = conv.pop("dialog_id")
|
||||||
|
for message in conv["messages"]:
|
||||||
|
message["reference"] = conv.get("reference")
|
||||||
|
del conv["reference"]
|
||||||
|
return get_json_result(data=conv)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/completion', methods=['POST'])
|
||||||
|
@token_required
|
||||||
|
def completion(tenant_id):
|
||||||
|
req = request.json
|
||||||
|
# req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
|
||||||
|
# {"role": "user", "content": "上海有吗?"}
|
||||||
|
# ]}
|
||||||
|
msg = []
|
||||||
|
question = {
|
||||||
|
"content": req.get("question"),
|
||||||
|
"role": "user",
|
||||||
|
"id": str(uuid4())
|
||||||
|
}
|
||||||
|
req["messages"].append(question)
|
||||||
|
for m in req["messages"]:
|
||||||
|
if m["role"] == "system": continue
|
||||||
|
if m["role"] == "assistant" and not msg: continue
|
||||||
|
m["id"] = m.get("id", str(uuid4()))
|
||||||
|
msg.append(m)
|
||||||
|
message_id = msg[-1].get("id")
|
||||||
|
conv = ConversationService.query(id=req["id"])
|
||||||
|
conv = conv[0]
|
||||||
|
if not conv:
|
||||||
|
return get_data_error_result(retmsg="Session does not exist")
|
||||||
|
if not DialogService.query(id=conv.dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
||||||
|
return get_data_error_result(retmsg="You do not own the session")
|
||||||
|
conv.message = deepcopy(req["messages"])
|
||||||
|
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||||
|
if not e:
|
||||||
|
return get_data_error_result(retmsg="Dialog not found!")
|
||||||
|
del req["id"]
|
||||||
|
del req["messages"]
|
||||||
|
|
||||||
|
if not conv.reference:
|
||||||
|
conv.reference = []
|
||||||
|
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||||
|
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||||
|
|
||||||
|
def fillin_conv(ans):
|
||||||
|
nonlocal conv, message_id
|
||||||
|
if not conv.reference:
|
||||||
|
conv.reference.append(ans["reference"])
|
||||||
|
else:
|
||||||
|
conv.reference[-1] = ans["reference"]
|
||||||
|
conv.message[-1] = {"role": "assistant", "content": ans["answer"],
|
||||||
|
"id": message_id, "prompt": ans.get("prompt", "")}
|
||||||
|
ans["id"] = message_id
|
||||||
|
|
||||||
|
def stream():
|
||||||
|
nonlocal dia, msg, req, conv
|
||||||
|
try:
|
||||||
|
for ans in chat(dia, msg, **req):
|
||||||
|
fillin_conv(ans)
|
||||||
|
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||||
|
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||||
|
except Exception as e:
|
||||||
|
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
|
||||||
|
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||||
|
ensure_ascii=False) + "\n\n"
|
||||||
|
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||||
|
|
||||||
|
if req.get("stream", True):
|
||||||
|
resp = Response(stream(), mimetype="text/event-stream")
|
||||||
|
resp.headers.add_header("Cache-control", "no-cache")
|
||||||
|
resp.headers.add_header("Connection", "keep-alive")
|
||||||
|
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||||
|
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||||
|
return resp
|
||||||
|
|
||||||
|
else:
|
||||||
|
answer = None
|
||||||
|
for ans in chat(dia, msg, **req):
|
||||||
|
answer = ans
|
||||||
|
fillin_conv(ans)
|
||||||
|
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||||
|
break
|
||||||
|
return get_json_result(data=answer)
|
@ -1,9 +1,12 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
from .base import Base
|
from .base import Base
|
||||||
|
from .session import Session, Message
|
||||||
|
|
||||||
|
|
||||||
class Assistant(Base):
|
class Assistant(Base):
|
||||||
def __init__(self, rag, res_dict):
|
def __init__(self, rag, res_dict):
|
||||||
self.id=""
|
self.id = ""
|
||||||
self.name = "assistant"
|
self.name = "assistant"
|
||||||
self.avatar = "path/to/avatar"
|
self.avatar = "path/to/avatar"
|
||||||
self.knowledgebases = ["kb1"]
|
self.knowledgebases = ["kb1"]
|
||||||
@ -41,8 +44,8 @@ class Assistant(Base):
|
|||||||
|
|
||||||
def save(self) -> bool:
|
def save(self) -> bool:
|
||||||
res = self.post('/assistant/save',
|
res = self.post('/assistant/save',
|
||||||
{"id": self.id, "name": self.name, "avatar": self.avatar, "knowledgebases":self.knowledgebases,
|
{"id": self.id, "name": self.name, "avatar": self.avatar, "knowledgebases": self.knowledgebases,
|
||||||
"llm":self.llm.to_json(),"prompt":self.prompt.to_json()
|
"llm": self.llm.to_json(), "prompt": self.prompt.to_json()
|
||||||
})
|
})
|
||||||
res = res.json()
|
res = res.json()
|
||||||
if res.get("retmsg") == "success": return True
|
if res.get("retmsg") == "success": return True
|
||||||
@ -54,3 +57,15 @@ class Assistant(Base):
|
|||||||
res = res.json()
|
res = res.json()
|
||||||
if res.get("retmsg") == "success": return True
|
if res.get("retmsg") == "success": return True
|
||||||
raise Exception(res["retmsg"])
|
raise Exception(res["retmsg"])
|
||||||
|
|
||||||
|
def create_session(self, name: str = "New session", messages: List[Message] = [
|
||||||
|
{"role": "assistant", "reference": [],
|
||||||
|
"content": "您好,我是您的助手小樱,长得可爱又善良,can I help you?"}]) -> Session:
|
||||||
|
res = self.post("/session/save", {"name": name, "messages": messages, "assistant_id": self.id, })
|
||||||
|
res = res.json()
|
||||||
|
if res.get("retmsg") == "success":
|
||||||
|
return Session(self.rag, res['data'])
|
||||||
|
raise Exception(res["retmsg"])
|
||||||
|
|
||||||
|
def get_prologue(self):
|
||||||
|
return self.prompt.opener
|
||||||
|
64
sdk/python/ragflow/modules/session.py
Normal file
64
sdk/python/ragflow/modules/session.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from .base import Base
|
||||||
|
|
||||||
|
|
||||||
|
class Session(Base):
|
||||||
|
def __init__(self, rag, res_dict):
|
||||||
|
self.id = None
|
||||||
|
self.name = "New session"
|
||||||
|
self.messages = [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}]
|
||||||
|
|
||||||
|
self.assistant_id = None
|
||||||
|
super().__init__(rag, res_dict)
|
||||||
|
|
||||||
|
def chat(self, question: str, stream: bool = False):
|
||||||
|
res = self.post("/session/completion",
|
||||||
|
{"id": self.id, "question": question, "stream": stream, "messages": self.messages})
|
||||||
|
res = res.text
|
||||||
|
response_lines = res.splitlines()
|
||||||
|
message_list = []
|
||||||
|
for line in response_lines:
|
||||||
|
if line.startswith("data:"):
|
||||||
|
json_data = json.loads(line[5:])
|
||||||
|
if json_data["data"] != True:
|
||||||
|
answer = json_data["data"]["answer"]
|
||||||
|
reference = json_data["data"]["reference"]
|
||||||
|
temp_dict = {
|
||||||
|
"content": answer,
|
||||||
|
"role": "assistant",
|
||||||
|
"reference": reference
|
||||||
|
}
|
||||||
|
message = Message(self.rag, temp_dict)
|
||||||
|
message_list.append(message)
|
||||||
|
return message_list
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
res = self.post("/session/save",
|
||||||
|
{"id": self.id, "dialog_id": self.assistant_id, "name": self.name, "message": self.messages})
|
||||||
|
res = res.json()
|
||||||
|
if res.get("retmsg") == "success": return True
|
||||||
|
raise Exception(res.get("retmsg"))
|
||||||
|
|
||||||
|
class Message(Base):
|
||||||
|
def __init__(self, rag, res_dict):
|
||||||
|
self.content = "您好,我是您的助手小樱,长得可爱又善良,can I help you?"
|
||||||
|
self.reference = []
|
||||||
|
self.role = "assistant"
|
||||||
|
self.prompt=None
|
||||||
|
super().__init__(rag, res_dict)
|
||||||
|
|
||||||
|
|
||||||
|
class Chunk(Base):
|
||||||
|
def __init__(self, rag, res_dict):
|
||||||
|
self.id = None
|
||||||
|
self.content = None
|
||||||
|
self.document_id = None
|
||||||
|
self.document_name = None
|
||||||
|
self.knowledgebase_id = None
|
||||||
|
self.image_id = None
|
||||||
|
self.similarity = None
|
||||||
|
self.vector_similarity = None
|
||||||
|
self.term_similarity = None
|
||||||
|
self.positions = None
|
||||||
|
super().__init__(rag, res_dict)
|
@ -17,7 +17,6 @@ from typing import List
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
||||||
from .modules.chat_assistant import Assistant
|
from .modules.chat_assistant import Assistant
|
||||||
from .modules.dataset import DataSet
|
from .modules.dataset import DataSet
|
||||||
|
|
||||||
@ -88,7 +87,7 @@ class RAGFlow:
|
|||||||
datasets.append(dataset.to_json())
|
datasets.append(dataset.to_json())
|
||||||
|
|
||||||
if llm is None:
|
if llm is None:
|
||||||
llm = Assistant.LLM(self, {"model_name": "deepseek-chat",
|
llm = Assistant.LLM(self, {"model_name": None,
|
||||||
"temperature": 0.1,
|
"temperature": 0.1,
|
||||||
"top_p": 0.3,
|
"top_p": 0.3,
|
||||||
"presence_penalty": 0.4,
|
"presence_penalty": 0.4,
|
||||||
@ -142,4 +141,4 @@ class RAGFlow:
|
|||||||
for data in res['data']:
|
for data in res['data']:
|
||||||
result_list.append(Assistant(self, data))
|
result_list.append(Assistant(self, data))
|
||||||
return result_list
|
return result_list
|
||||||
raise Exception(res["retmsg"])
|
raise Exception(res["retmsg"])
|
||||||
|
@ -10,10 +10,10 @@ class TestAssistant(TestSdk):
|
|||||||
Test creating an assistant with success
|
Test creating an assistant with success
|
||||||
"""
|
"""
|
||||||
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||||
kb = rag.get_dataset(name="God")
|
kb = rag.create_dataset(name="test_create_assistant")
|
||||||
assistant = rag.create_assistant("God",knowledgebases=[kb])
|
assistant = rag.create_assistant("test_create", knowledgebases=[kb])
|
||||||
if isinstance(assistant, Assistant):
|
if isinstance(assistant, Assistant):
|
||||||
assert assistant.name == "God", "Name does not match."
|
assert assistant.name == "test_create", "Name does not match."
|
||||||
else:
|
else:
|
||||||
assert False, f"Failed to create assistant, error: {assistant}"
|
assert False, f"Failed to create assistant, error: {assistant}"
|
||||||
|
|
||||||
@ -22,11 +22,11 @@ class TestAssistant(TestSdk):
|
|||||||
Test updating an assistant with success.
|
Test updating an assistant with success.
|
||||||
"""
|
"""
|
||||||
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||||
kb = rag.get_dataset(name="God")
|
kb = rag.create_dataset(name="test_update_assistant")
|
||||||
assistant = rag.create_assistant("ABC",knowledgebases=[kb])
|
assistant = rag.create_assistant("test_update", knowledgebases=[kb])
|
||||||
if isinstance(assistant, Assistant):
|
if isinstance(assistant, Assistant):
|
||||||
assert assistant.name == "ABC", "Name does not match."
|
assert assistant.name == "test_update", "Name does not match."
|
||||||
assistant.name = 'DEF'
|
assistant.name = 'new_assistant'
|
||||||
res = assistant.save()
|
res = assistant.save()
|
||||||
assert res is True, f"Failed to update assistant, error: {res}"
|
assert res is True, f"Failed to update assistant, error: {res}"
|
||||||
else:
|
else:
|
||||||
@ -37,10 +37,10 @@ class TestAssistant(TestSdk):
|
|||||||
Test deleting an assistant with success
|
Test deleting an assistant with success
|
||||||
"""
|
"""
|
||||||
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||||
kb = rag.get_dataset(name="God")
|
kb = rag.create_dataset(name="test_delete_assistant")
|
||||||
assistant = rag.create_assistant("MA",knowledgebases=[kb])
|
assistant = rag.create_assistant("test_delete", knowledgebases=[kb])
|
||||||
if isinstance(assistant, Assistant):
|
if isinstance(assistant, Assistant):
|
||||||
assert assistant.name == "MA", "Name does not match."
|
assert assistant.name == "test_delete", "Name does not match."
|
||||||
res = assistant.delete()
|
res = assistant.delete()
|
||||||
assert res is True, f"Failed to delete assistant, error: {res}"
|
assert res is True, f"Failed to delete assistant, error: {res}"
|
||||||
else:
|
else:
|
||||||
@ -61,6 +61,8 @@ class TestAssistant(TestSdk):
|
|||||||
Test getting an assistant's detail with success
|
Test getting an assistant's detail with success
|
||||||
"""
|
"""
|
||||||
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||||
assistant = rag.get_assistant(name="God")
|
kb = rag.create_dataset(name="test_get_assistant")
|
||||||
|
rag.create_assistant("test_get_assistant", knowledgebases=[kb])
|
||||||
|
assistant = rag.get_assistant(name="test_get_assistant")
|
||||||
assert isinstance(assistant, Assistant), f"Failed to get assistant, error: {assistant}."
|
assert isinstance(assistant, Assistant), f"Failed to get assistant, error: {assistant}."
|
||||||
assert assistant.name == "God", "Name does not match"
|
assert assistant.name == "test_get_assistant", "Name does not match"
|
||||||
|
27
sdk/python/test/t_session.py
Normal file
27
sdk/python/test/t_session.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from ragflow import RAGFlow
|
||||||
|
|
||||||
|
from common import API_KEY, HOST_ADDRESS
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatSession:
|
||||||
|
def test_create_session(self):
|
||||||
|
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||||
|
kb = rag.create_dataset(name="test_create_session")
|
||||||
|
assistant = rag.create_assistant(name="test_create_session", knowledgebases=[kb])
|
||||||
|
session = assistant.create_session()
|
||||||
|
assert assistant is not None, "Failed to get the assistant."
|
||||||
|
assert session is not None, "Failed to create a session."
|
||||||
|
|
||||||
|
def test_create_chat_with_success(self):
|
||||||
|
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||||
|
kb = rag.create_dataset(name="test_create_chat")
|
||||||
|
assistant = rag.create_assistant(name="test_create_chat", knowledgebases=[kb])
|
||||||
|
session = assistant.create_session()
|
||||||
|
assert session is not None, "Failed to create a session."
|
||||||
|
prologue = assistant.get_prologue()
|
||||||
|
assert isinstance(prologue, str), "Prologue is not a string."
|
||||||
|
assert len(prologue) > 0, "Prologue is empty."
|
||||||
|
question = "What is AI"
|
||||||
|
ans = session.chat(question, stream=True)
|
||||||
|
response = ans[-1].content
|
||||||
|
assert len(response) > 0, "Assistant did not return any response."
|
Loading…
x
Reference in New Issue
Block a user