Mirror of https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git (synced 2025-04-22 14:10:01 +08:00)

Refa: change LLM chat output from full to delta (incremental) (#6534)

### What problem does this PR solve?

Change LLM chat output from full to delta (incremental).

### Type of change

- [x] Refactoring

This commit is contained in:
parent 6599db1e99
commit df3890827d
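Note on the core idea before the diff: the streaming contract moves from "each chunk carries the full answer accumulated so far" to "each chunk carries only the newly generated text". The sketch below only illustrates that difference; the generator name and chunk values are hypothetical and are not code from this PR.

```python
# Hypothetical sketch: full-output streaming vs. delta (incremental) streaming.

def full_chunks():
    # Old behavior: every yield repeats the whole answer generated so far.
    yield "Hel"
    yield "Hello, wor"
    yield "Hello, world!"

def to_deltas(chunks):
    """Convert full-output chunks into incremental deltas."""
    previous = ""
    for chunk in chunks:
        # Emit only the suffix that was not present in the previous chunk.
        delta = chunk[len(previous):] if chunk.startswith(previous) else chunk
        previous = chunk
        if delta:
            yield delta

print(list(to_deltas(full_chunks())))  # ['Hel', 'lo, wor', 'ld!']
```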
@@ -13,31 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import json
import re
import time

from api.db import LLMType
from api.db.services.conversation_service import ConversationService, iframe_completion
from api.db.services.conversation_service import completion as rag_completion
from api.db.services.canvas_service import completion as agent_completion
from api.db.services.dialog_service import ask, chat
from flask import Response, jsonify, request

from agent.canvas import Canvas
from api.db import StatusEnum
from api.db import LLMType, StatusEnum
from api.db.db_models import APIToken
from api.db.services.api_service import API4ConversationService
from api.db.services.canvas_service import UserCanvasService
from api.db.services.dialog_service import DialogService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result, validate_request
from api.utils.api_utils import get_result, token_required
from api.db.services.llm_service import LLMBundle
from api.db.services.canvas_service import completion as agent_completion
from api.db.services.conversation_service import ConversationService, iframe_completion
from api.db.services.conversation_service import completion as rag_completion
from api.db.services.dialog_service import DialogService, ask, chat
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result, get_result, token_required, validate_request

from flask import jsonify, request, Response


@manager.route('/chats/<chat_id>/sessions', methods=['POST'])  # noqa: F821
@manager.route("/chats/<chat_id>/sessions", methods=["POST"])  # noqa: F821
@token_required
def create(tenant_id, chat_id):
    req = request.json
@ -50,7 +48,7 @@ def create(tenant_id, chat_id):
|
||||
"dialog_id": req["dialog_id"],
|
||||
"name": req.get("name", "New session"),
|
||||
"message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}],
|
||||
"user_id": req.get("user_id", "")
|
||||
"user_id": req.get("user_id", ""),
|
||||
}
|
||||
if not conv.get("name"):
|
||||
return get_error_data_result(message="`name` can not be empty.")
|
||||
@ -59,20 +57,20 @@ def create(tenant_id, chat_id):
|
||||
if not e:
|
||||
return get_error_data_result(message="Fail to create a session!")
|
||||
conv = conv.to_dict()
|
||||
conv['messages'] = conv.pop("message")
|
||||
conv["messages"] = conv.pop("message")
|
||||
conv["chat_id"] = conv.pop("dialog_id")
|
||||
del conv["reference"]
|
||||
return get_result(data=conv)
|
||||
|
||||
|
||||
@manager.route('/agents/<agent_id>/sessions', methods=['POST']) # noqa: F821
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create_agent_session(tenant_id, agent_id):
|
||||
req = request.json
|
||||
if not request.is_json:
|
||||
req = request.form
|
||||
files = request.files
|
||||
user_id = request.args.get('user_id', '')
|
||||
user_id = request.args.get("user_id", "")
|
||||
|
||||
e, cvs = UserCanvasService.get_by_id(agent_id)
|
||||
if not e:
|
||||
@ -113,7 +111,7 @@ def create_agent_session(tenant_id, agent_id):
|
||||
ele.pop("value")
|
||||
else:
|
||||
if req is not None and req.get(ele["key"]):
|
||||
ele["value"] = req[ele['key']]
|
||||
ele["value"] = req[ele["key"]]
|
||||
else:
|
||||
if "value" in ele:
|
||||
ele.pop("value")
|
||||
@ -121,20 +119,13 @@ def create_agent_session(tenant_id, agent_id):
|
||||
for ans in canvas.run(stream=False):
|
||||
pass
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
conv = {
|
||||
"id": get_uuid(),
|
||||
"dialog_id": cvs.id,
|
||||
"user_id": user_id,
|
||||
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
||||
"source": "agent",
|
||||
"dsl": cvs.dsl
|
||||
}
|
||||
conv = {"id": get_uuid(), "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
|
||||
API4ConversationService.save(**conv)
|
||||
conv["agent_id"] = conv.pop("dialog_id")
|
||||
return get_result(data=conv)
|
||||
|
||||
|
||||
@manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT']) # noqa: F821
|
||||
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update(tenant_id, chat_id, session_id):
|
||||
req = request.json
|
||||
@ -156,14 +147,14 @@ def update(tenant_id, chat_id, session_id):
|
||||
return get_result()
|
||||
|
||||
|
||||
@manager.route('/chats/<chat_id>/completions', methods=['POST']) # noqa: F821
|
||||
@manager.route("/chats/<chat_id>/completions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def chat_completion(tenant_id, chat_id):
|
||||
req = request.json
|
||||
if not req:
|
||||
req = {"question": ""}
|
||||
if not req.get("session_id"):
|
||||
req["question"]=""
|
||||
req["question"] = ""
|
||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(f"You don't own the chat {chat_id}")
|
||||
if req.get("session_id"):
|
||||
@ -185,7 +176,7 @@ def chat_completion(tenant_id, chat_id):
|
||||
return get_result(data=answer)
|
||||
|
||||
|
||||
@manager.route('/chats_openai/<chat_id>/chat/completions', methods=['POST']) # noqa: F821
|
||||
@manager.route("/chats_openai/<chat_id>/chat/completions", methods=["POST"]) # noqa: F821
|
||||
@validate_request("model", "messages") # noqa: F821
|
||||
@token_required
|
||||
def chat_completion_openai_like(tenant_id, chat_id):
|
||||
@@ -260,35 +251,60 @@ def chat_completion_openai_like(tenant_id, chat_id):
    def streamed_response_generator(chat_id, dia, msg):
        token_used = 0
        answer_cache = ""
        reasoning_cache = ""
        response = {
            "id": f"chatcmpl-{chat_id}",
            "choices": [
                {
                    "delta": {
                        "content": "",
                        "role": "assistant",
                        "function_call": None,
                        "tool_calls": None
                    },
                    "finish_reason": None,
                    "index": 0,
                    "logprobs": None
                }
            ],
            "choices": [{"delta": {"content": "", "role": "assistant", "function_call": None, "tool_calls": None, "reasoning_content": ""}, "finish_reason": None, "index": 0, "logprobs": None}],
            "created": int(time.time()),
            "model": "model",
            "object": "chat.completion.chunk",
            "system_fingerprint": "",
            "usage": None
            "usage": None,
        }

        try:
            for ans in chat(dia, msg, True):
                answer = ans["answer"]
                incremental = answer.replace(answer_cache, "", 1)
                answer_cache = answer.rstrip("</think>")
                token_used += len(incremental)
                response["choices"][0]["delta"]["content"] = incremental

                reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
                if reasoning_match:
                    reasoning_part = reasoning_match.group(1)
                    content_part = answer[reasoning_match.end() :]
                else:
                    reasoning_part = ""
                    content_part = answer

                reasoning_incremental = ""
                if reasoning_part:
                    if reasoning_part.startswith(reasoning_cache):
                        reasoning_incremental = reasoning_part.replace(reasoning_cache, "", 1)
                    else:
                        reasoning_incremental = reasoning_part
                    reasoning_cache = reasoning_part

                content_incremental = ""
                if content_part:
                    if content_part.startswith(answer_cache):
                        content_incremental = content_part.replace(answer_cache, "", 1)
                    else:
                        content_incremental = content_part
                    answer_cache = content_part

                token_used += len(reasoning_incremental) + len(content_incremental)

                if not any([reasoning_incremental, content_incremental]):
                    continue

                if reasoning_incremental:
                    response["choices"][0]["delta"]["reasoning_content"] = reasoning_incremental
                else:
                    response["choices"][0]["delta"]["reasoning_content"] = None

                if content_incremental:
                    response["choices"][0]["delta"]["content"] = content_incremental
                else:
                    response["choices"][0]["delta"]["content"] = None

                yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
        except Exception as e:
            response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e)
@@ -296,16 +312,12 @@ def chat_completion_openai_like(tenant_id, chat_id):

        # The last chunk
        response["choices"][0]["delta"]["content"] = None
        response["choices"][0]["delta"]["reasoning_content"] = None
        response["choices"][0]["finish_reason"] = "stop"
        response["usage"] = {
            "prompt_tokens": len(prompt),
            "completion_tokens": token_used,
            "total_tokens": len(prompt) + token_used
        }
        response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
        yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
        yield "data:[DONE]\n\n"


    resp = Response(streamed_response_generator(chat_id, dia, msg), mimetype="text/event-stream")
    resp.headers.add_header("Cache-control", "no-cache")
    resp.headers.add_header("Connection", "keep-alive")
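With the endpoint now emitting OpenAI-style delta chunks (including an optional `reasoning_content` field), a client can consume it with the standard OpenAI SDK. The snippet below is an illustrative sketch, not part of this PR; the base URL path, host, and key placeholder are assumptions about a typical RAGFlow deployment.

```python
from openai import OpenAI

# Assumptions: the RAGFlow server exposes the chats_openai route shown in this diff,
# and the key is a RAGFlow API token sent as Bearer auth.
client = OpenAI(api_key="<RAGFLOW_API_KEY>", base_url="http://<ragflow-host>/api/v1/chats_openai/<chat_id>")

stream = client.chat.completions.create(
    model="model",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)

answer = ""
for chunk in stream:
    delta = chunk.choices[0].delta
    # reasoning_content is a non-standard field; it may be absent on some chunks.
    reasoning = getattr(delta, "reasoning_content", None)
    if reasoning:
        print(reasoning, end="", flush=True)
    if delta.content:
        answer += delta.content  # deltas are incremental now, so concatenate
        print(delta.content, end="", flush=True)
```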
@ -320,7 +332,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
break
|
||||
content = answer["answer"]
|
||||
|
||||
response = {
|
||||
response = {
|
||||
"id": f"chatcmpl-{chat_id}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
@ -332,25 +344,15 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
||||
"completion_tokens_details": {
|
||||
"reasoning_tokens": context_token_used,
|
||||
"accepted_prediction_tokens": len(content),
|
||||
"rejected_prediction_tokens": 0 # 0 for simplicity
|
||||
}
|
||||
"rejected_prediction_tokens": 0, # 0 for simplicity
|
||||
},
|
||||
},
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": content
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop",
|
||||
"index": 0
|
||||
}
|
||||
]
|
||||
"choices": [{"message": {"role": "assistant", "content": content}, "logprobs": None, "finish_reason": "stop", "index": 0}],
|
||||
}
|
||||
return jsonify(response)
|
||||
|
||||
|
||||
@manager.route('/agents/<agent_id>/completions', methods=['POST']) # noqa: F821
|
||||
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def agent_completions(tenant_id, agent_id):
|
||||
req = request.json
|
||||
@ -361,8 +363,8 @@ def agent_completions(tenant_id, agent_id):
|
||||
dsl = cvs[0].dsl
|
||||
if not isinstance(dsl, str):
|
||||
dsl = json.dumps(dsl)
|
||||
#canvas = Canvas(dsl, tenant_id)
|
||||
#if canvas.get_preset_param():
|
||||
# canvas = Canvas(dsl, tenant_id)
|
||||
# if canvas.get_preset_param():
|
||||
# req["question"] = ""
|
||||
conv = API4ConversationService.query(id=req["session_id"], dialog_id=agent_id)
|
||||
if not conv:
|
||||
@ -376,9 +378,7 @@ def agent_completions(tenant_id, agent_id):
|
||||
states = {field: current_dsl.get(field, []) for field in state_fields}
|
||||
current_dsl.update(new_dsl)
|
||||
current_dsl.update(states)
|
||||
API4ConversationService.update_by_id(req["session_id"], {
|
||||
"dsl": current_dsl
|
||||
})
|
||||
API4ConversationService.update_by_id(req["session_id"], {"dsl": current_dsl})
|
||||
else:
|
||||
req["question"] = ""
|
||||
if req.get("stream", True):
|
||||
@ -395,7 +395,7 @@ def agent_completions(tenant_id, agent_id):
|
||||
return get_error_data_result(str(e))
|
||||
|
||||
|
||||
@manager.route('/chats/<chat_id>/sessions', methods=['GET']) # noqa: F821
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def list_session(tenant_id, chat_id):
|
||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
||||
@ -414,7 +414,7 @@ def list_session(tenant_id, chat_id):
|
||||
if not convs:
|
||||
return get_result(data=[])
|
||||
for conv in convs:
|
||||
conv['messages'] = conv.pop("message")
|
||||
conv["messages"] = conv.pop("message")
|
||||
infos = conv["messages"]
|
||||
for info in infos:
|
||||
if "prompt" in info:
|
||||
@ -448,7 +448,7 @@ def list_session(tenant_id, chat_id):
|
||||
return get_result(data=convs)
|
||||
|
||||
|
||||
@manager.route('/agents/<agent_id>/sessions', methods=['GET']) # noqa: F821
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def list_agent_session(tenant_id, agent_id):
|
||||
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
|
||||
@ -464,12 +464,11 @@ def list_agent_session(tenant_id, agent_id):
|
||||
desc = True
|
||||
# dsl defaults to True in all cases except for False and false
|
||||
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
|
||||
convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id,
|
||||
user_id, include_dsl)
|
||||
convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id, user_id, include_dsl)
|
||||
if not convs:
|
||||
return get_result(data=[])
|
||||
for conv in convs:
|
||||
conv['messages'] = conv.pop("message")
|
||||
conv["messages"] = conv.pop("message")
|
||||
infos = conv["messages"]
|
||||
for info in infos:
|
||||
if "prompt" in info:
|
||||
@ -502,7 +501,7 @@ def list_agent_session(tenant_id, agent_id):
|
||||
return get_result(data=convs)
|
||||
|
||||
|
||||
@manager.route('/chats/<chat_id>/sessions', methods=["DELETE"]) # noqa: F821
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete(tenant_id, chat_id):
|
||||
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
||||
@ -528,14 +527,14 @@ def delete(tenant_id, chat_id):
|
||||
return get_result()
|
||||
|
||||
|
||||
@manager.route('/agents/<agent_id>/sessions', methods=["DELETE"]) # noqa: F821
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete_agent_session(tenant_id, agent_id):
|
||||
req = request.json
|
||||
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
|
||||
if not cvs:
|
||||
return get_error_data_result(f"You don't own the agent {agent_id}")
|
||||
|
||||
|
||||
convs = API4ConversationService.query(dialog_id=agent_id)
|
||||
if not convs:
|
||||
return get_error_data_result(f"Agent {agent_id} has no sessions")
|
||||
@ -551,16 +550,16 @@ def delete_agent_session(tenant_id, agent_id):
|
||||
conv_list.append(conv.id)
|
||||
else:
|
||||
conv_list = ids
|
||||
|
||||
|
||||
for session_id in conv_list:
|
||||
conv = API4ConversationService.query(id=session_id, dialog_id=agent_id)
|
||||
if not conv:
|
||||
return get_error_data_result(f"The agent doesn't own the session ${session_id}")
|
||||
API4ConversationService.delete_by_id(session_id)
|
||||
return get_result()
|
||||
|
||||
|
||||
@manager.route('/sessions/ask', methods=['POST']) # noqa: F821
|
||||
|
||||
@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def ask_about(tenant_id):
|
||||
req = request.json
|
||||
@ -586,9 +585,7 @@ def ask_about(tenant_id):
|
||||
for ans in ask(req["question"], req["kb_ids"], uid):
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
resp = Response(stream(), mimetype="text/event-stream")
|
||||
@ -599,7 +596,7 @@ def ask_about(tenant_id):
|
||||
return resp
|
||||
|
||||
|
||||
@manager.route('/sessions/related_questions', methods=['POST']) # noqa: F821
|
||||
@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def related_questions(tenant_id):
|
||||
req = request.json
|
||||
@ -631,18 +628,27 @@ Reason:
|
||||
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
|
||||
|
||||
"""
|
||||
ans = chat_mdl.chat(prompt, [{"role": "user", "content": f"""
|
||||
ans = chat_mdl.chat(
|
||||
prompt,
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
Keywords: {question}
|
||||
Related search terms:
|
||||
"""}], {"temperature": 0.9})
|
||||
""",
|
||||
}
|
||||
],
|
||||
{"temperature": 0.9},
|
||||
)
|
||||
return get_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
|
||||
|
||||
|
||||
@manager.route('/chatbots/<dialog_id>/completions', methods=['POST']) # noqa: F821
|
||||
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
|
||||
def chatbot_completions(dialog_id):
|
||||
req = request.json
|
||||
|
||||
token = request.headers.get('Authorization').split()
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
token = token[1]
|
||||
@ -665,11 +671,11 @@ def chatbot_completions(dialog_id):
|
||||
return get_result(data=answer)
|
||||
|
||||
|
||||
@manager.route('/agentbots/<agent_id>/completions', methods=['POST']) # noqa: F821
|
||||
@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
|
||||
def agent_bot_completions(agent_id):
|
||||
req = request.json
|
||||
|
||||
token = request.headers.get('Authorization').split()
|
||||
token = request.headers.get("Authorization").split()
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!"')
|
||||
token = token[1]
|
||||
|
@@ -324,15 +324,18 @@ class LLMBundle:
        if self.langfuse:
            generation = self.trace.generation(name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})

        output = ""
        ans = ""
        for txt in self.mdl.chat_streamly(system, history, gen_conf):
            if isinstance(txt, int):
                if self.langfuse:
                    generation.end(output={"output": output})
                    generation.end(output={"output": ans})

                if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
                    logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
                return
                return ans

            output = txt
            yield txt
            if txt.endswith("</think>"):
                ans = ans.rstrip("</think>")

            ans += txt
            yield ans
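To keep downstream callers unchanged, `LLMBundle.chat_streamly` now re-accumulates the per-chunk deltas coming from the underlying model adapter and still yields the full answer so far, finishing with the token count. A minimal sketch of that adapter pattern (names simplified, not the exact RAGFlow code):

```python
from typing import Iterator, Union

def accumulate_deltas(delta_stream: Iterator[Union[str, int]]) -> Iterator[Union[str, int]]:
    """Adapter sketch: the inner stream yields text deltas and, as its last item,
    a token count (int). This wrapper yields the accumulated answer after each
    delta, preserving the old 'full answer so far' contract for consumers."""
    answer = ""
    for item in delta_stream:
        if isinstance(item, int):  # final item: total token usage
            yield item
            return
        answer += item
        yield answer

# Example with a hypothetical delta stream:
print(list(accumulate_deltas(iter(["Hel", "lo", 42]))))  # ['Hel', 'Hello', 42]
```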
@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -13,25 +13,25 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import re
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from abc import ABC
|
||||
|
||||
import openai
|
||||
import requests
|
||||
from dashscope import Generation
|
||||
from ollama import Client
|
||||
from openai import OpenAI
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
from zhipuai import ZhipuAI
|
||||
from dashscope import Generation
|
||||
from abc import ABC
|
||||
from openai import OpenAI
|
||||
import openai
|
||||
from ollama import Client
|
||||
|
||||
from rag.nlp import is_chinese, is_english
|
||||
from rag.utils import num_tokens_from_string
|
||||
import os
|
||||
import json
|
||||
import requests
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
|
||||
|
||||
# Error message constants
|
||||
ERROR_PREFIX = "**ERROR**"
|
||||
@ -53,21 +53,21 @@ LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name, base_url):
|
||||
timeout = int(os.environ.get('LM_TIMEOUT_SECONDS', 600))
|
||||
timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
|
||||
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
|
||||
self.model_name = model_name
|
||||
# Configure retry parameters
|
||||
self.max_retries = int(os.environ.get('LLM_MAX_RETRIES', 5))
|
||||
self.base_delay = float(os.environ.get('LLM_BASE_DELAY', 2.0))
|
||||
self.max_retries = int(os.environ.get("LLM_MAX_RETRIES", 5))
|
||||
self.base_delay = float(os.environ.get("LLM_BASE_DELAY", 2.0))
|
||||
|
||||
def _get_delay(self, attempt):
|
||||
"""Calculate retry delay time"""
|
||||
return self.base_delay * (2 ** attempt) + random.uniform(0, 0.5)
|
||||
|
||||
return self.base_delay * (2**attempt) + random.uniform(0, 0.5)
|
||||
|
||||
def _classify_error(self, error):
|
||||
"""Classify error based on error message content"""
|
||||
error_str = str(error).lower()
|
||||
|
||||
|
||||
if "rate limit" in error_str or "429" in error_str or "tpm limit" in error_str or "too many requests" in error_str or "requests per minute" in error_str:
|
||||
return ERROR_RATE_LIMIT
|
||||
elif "auth" in error_str or "key" in error_str or "apikey" in error_str or "401" in error_str or "forbidden" in error_str or "permission" in error_str:
|
||||
@ -98,11 +98,8 @@ class Base(ABC):
|
||||
# Implement exponential backoff retry strategy
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
**gen_conf)
|
||||
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
|
||||
|
||||
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
|
||||
return "", 0
|
||||
ans = response.choices[0].message.content.strip()
|
||||
@ -111,17 +108,17 @@ class Base(ABC):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
return ans, self.total_token_count(response)
|
||||
return ans, self.total_token_count(response)
|
||||
except Exception as e:
|
||||
# Classify the error
|
||||
error_code = self._classify_error(e)
|
||||
|
||||
|
||||
# Check if it's a rate limit error or server error and not the last attempt
|
||||
should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1
|
||||
|
||||
|
||||
if should_retry:
|
||||
delay = self._get_delay(attempt)
|
||||
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt+1}/{self.max_retries})")
|
||||
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
|
||||
time.sleep(delay)
|
||||
else:
|
||||
# For non-rate limit errors or the last attempt, return an error message
|
||||
@@ -136,24 +133,23 @@ class Base(ABC):
            del gen_conf["max_tokens"]
        ans = ""
        total_tokens = 0
        reasoning_start = False
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf)
            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
            for resp in response:
                if not resp.choices:
                    continue
                if not resp.choices[0].delta.content:
                    resp.choices[0].delta.content = ""
                if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
                    if ans.find("<think>") < 0:
                        ans += "<think>"
                    ans = ans.replace("</think>", "")
                    ans = ""
                    if not reasoning_start:
                        reasoning_start = True
                        ans = "<think>"
                    ans += resp.choices[0].delta.reasoning_content + "</think>"
                else:
                    ans += resp.choices[0].delta.content
                    reasoning_start = False
                    ans = resp.choices[0].delta.content

                tol = self.total_token_count(resp)
                if not tol:
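In the model adapters, each yielded value is now just the text of the current chunk, with reasoning deltas wrapped in `<think>…</think>` as they arrive, instead of the whole answer so far. The sketch below imitates that per-chunk wrapping on a hypothetical delta stream; it is illustrative only, not the adapter code itself.

```python
def wrap_deltas(chunks):
    """chunks: iterable of (reasoning_delta, content_delta) pairs (hypothetical input).
    Yields per-chunk text, opening <think> on the first reasoning delta and closing
    it on every reasoning chunk so that downstream accumulation can strip the
    trailing </think> before appending the next piece."""
    reasoning_start = False
    for reasoning, content in chunks:
        if reasoning:
            prefix = "" if reasoning_start else "<think>"
            reasoning_start = True
            yield prefix + reasoning + "</think>"
        else:
            reasoning_start = False
            yield content or ""

chunks = [("Let me", None), (" think", None), (None, "The answer"), (None, " is 42.")]
print(list(wrap_deltas(chunks)))
# ['<think>Let me</think>', ' think</think>', 'The answer', ' is 42.']
```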
@ -221,7 +217,7 @@ class ModelScopeChat(Base):
|
||||
def __init__(self, key=None, model_name="", base_url=""):
|
||||
if not base_url:
|
||||
raise ValueError("Local llm url cannot be None")
|
||||
base_url = base_url.rstrip('/')
|
||||
base_url = base_url.rstrip("/")
|
||||
if base_url.split("/")[-1] != "v1":
|
||||
base_url = os.path.join(base_url, "v1")
|
||||
super().__init__(key, model_name.split("___")[0], base_url)
|
||||
@ -236,8 +232,8 @@ class DeepSeekChat(Base):
|
||||
|
||||
class AzureChat(Base):
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
api_key = json.loads(key).get('api_key', '')
|
||||
api_version = json.loads(key).get('api_version', '2024-02-01')
|
||||
api_key = json.loads(key).get("api_key", "")
|
||||
api_version = json.loads(key).get("api_version", "2024-02-01")
|
||||
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
|
||||
self.model_name = model_name
|
||||
|
||||
@ -264,16 +260,9 @@ class BaiChuanChat(Base):
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
extra_body={
|
||||
"tools": [{
|
||||
"type": "web_search",
|
||||
"web_search": {
|
||||
"enable": True,
|
||||
"search_mode": "performance_first"
|
||||
}
|
||||
}]
|
||||
},
|
||||
**self._format_params(gen_conf))
|
||||
extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
|
||||
**self._format_params(gen_conf),
|
||||
)
|
||||
ans = response.choices[0].message.content.strip()
|
||||
if response.choices[0].finish_reason == "length":
|
||||
if is_chinese([ans]):
|
||||
@ -295,23 +284,16 @@ class BaiChuanChat(Base):
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
extra_body={
|
||||
"tools": [{
|
||||
"type": "web_search",
|
||||
"web_search": {
|
||||
"enable": True,
|
||||
"search_mode": "performance_first"
|
||||
}
|
||||
}]
|
||||
},
|
||||
extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
|
||||
stream=True,
|
||||
**self._format_params(gen_conf))
|
||||
**self._format_params(gen_conf),
|
||||
)
|
||||
for resp in response:
|
||||
if not resp.choices:
|
||||
continue
|
||||
if not resp.choices[0].delta.content:
|
||||
resp.choices[0].delta.content = ""
|
||||
ans += resp.choices[0].delta.content
|
||||
ans = resp.choices[0].delta.content
|
||||
tol = self.total_token_count(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
|
||||
@ -333,6 +315,7 @@ class BaiChuanChat(Base):
|
||||
class QWenChat(Base):
|
||||
def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
|
||||
import dashscope
|
||||
|
||||
dashscope.api_key = key
|
||||
self.model_name = model_name
|
||||
if self.is_reasoning_model(self.model_name):
|
||||
@ -344,22 +327,18 @@ class QWenChat(Base):
|
||||
if self.is_reasoning_model(self.model_name):
|
||||
return super().chat(system, history, gen_conf)
|
||||
|
||||
stream_flag = str(os.environ.get('QWEN_CHAT_BY_STREAM', 'true')).lower() == 'true'
|
||||
stream_flag = str(os.environ.get("QWEN_CHAT_BY_STREAM", "true")).lower() == "true"
|
||||
if not stream_flag:
|
||||
from http import HTTPStatus
|
||||
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
|
||||
response = Generation.call(
|
||||
self.model_name,
|
||||
messages=history,
|
||||
result_format='message',
|
||||
**gen_conf
|
||||
)
|
||||
response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
|
||||
ans = ""
|
||||
tk_count = 0
|
||||
if response.status_code == HTTPStatus.OK:
|
||||
ans += response.output.choices[0]['message']['content']
|
||||
ans += response.output.choices[0]["message"]["content"]
|
||||
tk_count += self.total_token_count(response)
|
||||
if response.output.choices[0].get("finish_reason", "") == "length":
|
||||
if is_chinese([ans]):
|
||||
@ -378,8 +357,9 @@ class QWenChat(Base):
|
||||
else:
|
||||
return "".join(result_list[:-1]), result_list[-1]
|
||||
|
||||
def _chat_streamly(self, system, history, gen_conf, incremental_output=False):
|
||||
def _chat_streamly(self, system, history, gen_conf, incremental_output=True):
|
||||
from http import HTTPStatus
|
||||
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
@ -387,17 +367,10 @@ class QWenChat(Base):
|
||||
ans = ""
|
||||
tk_count = 0
|
||||
try:
|
||||
response = Generation.call(
|
||||
self.model_name,
|
||||
messages=history,
|
||||
result_format='message',
|
||||
stream=True,
|
||||
incremental_output=incremental_output,
|
||||
**gen_conf
|
||||
)
|
||||
response = Generation.call(self.model_name, messages=history, result_format="message", stream=True, incremental_output=incremental_output, **gen_conf)
|
||||
for resp in response:
|
||||
if resp.status_code == HTTPStatus.OK:
|
||||
ans = resp.output.choices[0]['message']['content']
|
||||
ans = resp.output.choices[0]["message"]["content"]
|
||||
tk_count = self.total_token_count(resp)
|
||||
if resp.output.choices[0].get("finish_reason", "") == "length":
|
||||
if is_chinese(ans):
|
||||
@ -406,8 +379,11 @@ class QWenChat(Base):
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
yield ans
|
||||
else:
|
||||
yield ans + "\n**ERROR**: " + resp.message if not re.search(r" (key|quota)",
|
||||
str(resp.message).lower()) else "Out of credit. Please set the API key in **settings > Model providers.**"
|
||||
yield (
|
||||
ans + "\n**ERROR**: " + resp.message
|
||||
if not re.search(r" (key|quota)", str(resp.message).lower())
|
||||
else "Out of credit. Please set the API key in **settings > Model providers.**"
|
||||
)
|
||||
except Exception as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
|
||||
@ -423,10 +399,12 @@ class QWenChat(Base):
|
||||
|
||||
@staticmethod
|
||||
def is_reasoning_model(model_name: str) -> bool:
|
||||
return any([
|
||||
model_name.lower().find("deepseek") >= 0,
|
||||
model_name.lower().find("qwq") >= 0 and model_name.lower() != 'qwq-32b-preview',
|
||||
])
|
||||
return any(
|
||||
[
|
||||
model_name.lower().find("deepseek") >= 0,
|
||||
model_name.lower().find("qwq") >= 0 and model_name.lower() != "qwq-32b-preview",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ZhipuChat(Base):
|
||||
@ -444,11 +422,7 @@ class ZhipuChat(Base):
|
||||
del gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf:
|
||||
del gen_conf["frequency_penalty"]
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
**gen_conf
|
||||
)
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
|
||||
ans = response.choices[0].message.content.strip()
|
||||
if response.choices[0].finish_reason == "length":
|
||||
if is_chinese(ans):
|
||||
@ -471,17 +445,12 @@ class ZhipuChat(Base):
|
||||
ans = ""
|
||||
tk_count = 0
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
stream=True,
|
||||
**gen_conf
|
||||
)
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
|
||||
for resp in response:
|
||||
if not resp.choices[0].delta.content:
|
||||
continue
|
||||
delta = resp.choices[0].delta.content
|
||||
ans += delta
|
||||
ans = delta
|
||||
if resp.choices[0].finish_reason == "length":
|
||||
if is_chinese(ans):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
@ -499,8 +468,7 @@ class ZhipuChat(Base):
|
||||
|
||||
class OllamaChat(Base):
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else \
|
||||
Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
|
||||
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
|
||||
self.model_name = model_name
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
@ -509,9 +477,7 @@ class OllamaChat(Base):
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
try:
|
||||
options = {
|
||||
"num_ctx": 32768
|
||||
}
|
||||
options = {"num_ctx": 32768}
|
||||
if "temperature" in gen_conf:
|
||||
options["temperature"] = gen_conf["temperature"]
|
||||
if "max_tokens" in gen_conf:
|
||||
@ -522,12 +488,7 @@ class OllamaChat(Base):
|
||||
options["presence_penalty"] = gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf:
|
||||
options["frequency_penalty"] = gen_conf["frequency_penalty"]
|
||||
response = self.client.chat(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
options=options,
|
||||
keep_alive=-1
|
||||
)
|
||||
response = self.client.chat(model=self.model_name, messages=history, options=options, keep_alive=-1)
|
||||
ans = response["message"]["content"].strip()
|
||||
return ans, response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
|
||||
except Exception as e:
|
||||
@ -551,17 +512,11 @@ class OllamaChat(Base):
|
||||
options["frequency_penalty"] = gen_conf["frequency_penalty"]
|
||||
ans = ""
|
||||
try:
|
||||
response = self.client.chat(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
stream=True,
|
||||
options=options,
|
||||
keep_alive=-1
|
||||
)
|
||||
response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=-1)
|
||||
for resp in response:
|
||||
if resp["done"]:
|
||||
yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
|
||||
ans += resp["message"]["content"]
|
||||
ans = resp["message"]["content"]
|
||||
yield ans
|
||||
except Exception as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
@ -588,9 +543,7 @@ class LocalLLM(Base):
|
||||
def __conn(self):
|
||||
from multiprocessing.connection import Client
|
||||
|
||||
self._connection = Client(
|
||||
(self.host, self.port), authkey=b"infiniflow-token4kevinhu"
|
||||
)
|
||||
self._connection = Client((self.host, self.port), authkey=b"infiniflow-token4kevinhu")
|
||||
|
||||
def __getattr__(self, name):
|
||||
import pickle
|
||||
@ -613,17 +566,17 @@ class LocalLLM(Base):
|
||||
|
||||
def _prepare_prompt(self, system, history, gen_conf):
|
||||
from rag.svr.jina_server import Prompt
|
||||
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
return Prompt(message=history, gen_conf=gen_conf)
|
||||
|
||||
def _stream_response(self, endpoint, prompt):
|
||||
from rag.svr.jina_server import Generation
|
||||
|
||||
answer = ""
|
||||
try:
|
||||
res = self.client.stream_doc(
|
||||
on=endpoint, inputs=prompt, return_type=Generation
|
||||
)
|
||||
res = self.client.stream_doc(on=endpoint, inputs=prompt, return_type=Generation)
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
while True:
|
||||
@ -652,24 +605,24 @@ class LocalLLM(Base):
|
||||
|
||||
|
||||
class VolcEngineChat(Base):
|
||||
def __init__(self, key, model_name, base_url='https://ark.cn-beijing.volces.com/api/v3'):
|
||||
def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
|
||||
"""
|
||||
Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
|
||||
Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use
|
||||
model_name is for display only
|
||||
"""
|
||||
base_url = base_url if base_url else 'https://ark.cn-beijing.volces.com/api/v3'
|
||||
ark_api_key = json.loads(key).get('ark_api_key', '')
|
||||
model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '')
|
||||
base_url = base_url if base_url else "https://ark.cn-beijing.volces.com/api/v3"
|
||||
ark_api_key = json.loads(key).get("ark_api_key", "")
|
||||
model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
|
||||
super().__init__(ark_api_key, model_name, base_url)
|
||||
|
||||
|
||||
class MiniMaxChat(Base):
|
||||
def __init__(
|
||||
self,
|
||||
key,
|
||||
model_name,
|
||||
base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
|
||||
self,
|
||||
key,
|
||||
model_name,
|
||||
base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
|
||||
):
|
||||
if not base_url:
|
||||
base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
|
||||
@ -687,13 +640,9 @@ class MiniMaxChat(Base):
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = json.dumps(
|
||||
{"model": self.model_name, "messages": history, **gen_conf}
|
||||
)
|
||||
payload = json.dumps({"model": self.model_name, "messages": history, **gen_conf})
|
||||
try:
|
||||
response = requests.request(
|
||||
"POST", url=self.base_url, headers=headers, data=payload
|
||||
)
|
||||
response = requests.request("POST", url=self.base_url, headers=headers, data=payload)
|
||||
response = response.json()
|
||||
ans = response["choices"][0]["message"]["content"].strip()
|
||||
if response["choices"][0]["finish_reason"] == "length":
|
||||
@ -737,7 +686,7 @@ class MiniMaxChat(Base):
|
||||
text = ""
|
||||
if "choices" in resp and "delta" in resp["choices"][0]:
|
||||
text = resp["choices"][0]["delta"]["content"]
|
||||
ans += text
|
||||
ans = text
|
||||
tol = self.total_token_count(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(text)
|
||||
@ -752,9 +701,9 @@ class MiniMaxChat(Base):
|
||||
|
||||
|
||||
class MistralChat(Base):
|
||||
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
from mistralai.client import MistralClient
|
||||
|
||||
self.client = MistralClient(api_key=key)
|
||||
self.model_name = model_name
|
||||
|
||||
@ -765,10 +714,7 @@ class MistralChat(Base):
|
||||
if k not in ["temperature", "top_p", "max_tokens"]:
|
||||
del gen_conf[k]
|
||||
try:
|
||||
response = self.client.chat(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
**gen_conf)
|
||||
response = self.client.chat(model=self.model_name, messages=history, **gen_conf)
|
||||
ans = response.choices[0].message.content
|
||||
if response.choices[0].finish_reason == "length":
|
||||
if is_chinese(ans):
|
||||
@ -788,14 +734,11 @@ class MistralChat(Base):
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
try:
|
||||
response = self.client.chat_stream(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
**gen_conf)
|
||||
response = self.client.chat_stream(model=self.model_name, messages=history, **gen_conf)
|
||||
for resp in response:
|
||||
if not resp.choices or not resp.choices[0].delta.content:
|
||||
continue
|
||||
ans += resp.choices[0].delta.content
|
||||
ans = resp.choices[0].delta.content
|
||||
total_tokens += 1
|
||||
if resp.choices[0].finish_reason == "length":
|
||||
if is_chinese(ans):
|
||||
@ -811,23 +754,23 @@ class MistralChat(Base):
|
||||
|
||||
|
||||
class BedrockChat(Base):
|
||||
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
import boto3
|
||||
self.bedrock_ak = json.loads(key).get('bedrock_ak', '')
|
||||
self.bedrock_sk = json.loads(key).get('bedrock_sk', '')
|
||||
self.bedrock_region = json.loads(key).get('bedrock_region', '')
|
||||
|
||||
self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
|
||||
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
|
||||
self.bedrock_region = json.loads(key).get("bedrock_region", "")
|
||||
self.model_name = model_name
|
||||
|
||||
if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '':
|
||||
if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
|
||||
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
|
||||
self.client = boto3.client('bedrock-runtime')
|
||||
self.client = boto3.client("bedrock-runtime")
|
||||
else:
|
||||
self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
|
||||
aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
|
||||
self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
for k in list(gen_conf.keys()):
|
||||
if k not in ["temperature"]:
|
||||
del gen_conf[k]
|
||||
@ -853,6 +796,7 @@ class BedrockChat(Base):
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
for k in list(gen_conf.keys()):
|
||||
if k not in ["temperature"]:
|
||||
del gen_conf[k]
|
||||
@ -860,14 +804,9 @@ class BedrockChat(Base):
|
||||
if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
|
||||
item["content"] = [{"text": item["content"]}]
|
||||
|
||||
if self.model_name.split('.')[0] == 'ai21':
|
||||
if self.model_name.split(".")[0] == "ai21":
|
||||
try:
|
||||
response = self.client.converse(
|
||||
modelId=self.model_name,
|
||||
messages=history,
|
||||
inferenceConfig=gen_conf,
|
||||
system=[{"text": (system if system else "Answer the user's message.")}]
|
||||
)
|
||||
response = self.client.converse(modelId=self.model_name, messages=history, inferenceConfig=gen_conf, system=[{"text": (system if system else "Answer the user's message.")}])
|
||||
ans = response["output"]["message"]["content"][0]["text"]
|
||||
return ans, num_tokens_from_string(ans)
|
||||
|
||||
@ -878,16 +817,13 @@ class BedrockChat(Base):
|
||||
try:
|
||||
# Send the message to the model, using a basic inference configuration.
|
||||
streaming_response = self.client.converse_stream(
|
||||
modelId=self.model_name,
|
||||
messages=history,
|
||||
inferenceConfig=gen_conf,
|
||||
system=[{"text": (system if system else "Answer the user's message.")}]
|
||||
modelId=self.model_name, messages=history, inferenceConfig=gen_conf, system=[{"text": (system if system else "Answer the user's message.")}]
|
||||
)
|
||||
|
||||
# Extract and print the streamed response text in real-time.
|
||||
for resp in streaming_response["stream"]:
|
||||
if "contentBlockDelta" in resp:
|
||||
ans += resp["contentBlockDelta"]["delta"]["text"]
|
||||
ans = resp["contentBlockDelta"]["delta"]["text"]
|
||||
yield ans
|
||||
|
||||
except (ClientError, Exception) as e:
|
||||
@ -897,13 +833,12 @@ class BedrockChat(Base):
|
||||
|
||||
|
||||
class GeminiChat(Base):
|
||||
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
from google.generativeai import client, GenerativeModel
|
||||
from google.generativeai import GenerativeModel, client
|
||||
|
||||
client.configure(api_key=key)
|
||||
_client = client.get_default_generative_client()
|
||||
self.model_name = 'models/' + model_name
|
||||
self.model_name = "models/" + model_name
|
||||
self.model = GenerativeModel(model_name=self.model_name)
|
||||
self.model._client = _client
|
||||
|
||||
@ -916,17 +851,15 @@ class GeminiChat(Base):
|
||||
if k not in ["temperature", "top_p", "max_tokens"]:
|
||||
del gen_conf[k]
|
||||
for item in history:
|
||||
if 'role' in item and item['role'] == 'assistant':
|
||||
item['role'] = 'model'
|
||||
if 'role' in item and item['role'] == 'system':
|
||||
item['role'] = 'user'
|
||||
if 'content' in item:
|
||||
item['parts'] = item.pop('content')
|
||||
if "role" in item and item["role"] == "assistant":
|
||||
item["role"] = "model"
|
||||
if "role" in item and item["role"] == "system":
|
||||
item["role"] = "user"
|
||||
if "content" in item:
|
||||
item["parts"] = item.pop("content")
|
||||
|
||||
try:
|
||||
response = self.model.generate_content(
|
||||
history,
|
||||
generation_config=gen_conf)
|
||||
response = self.model.generate_content(history, generation_config=gen_conf)
|
||||
ans = response.text
|
||||
return ans, response.usage_metadata.total_token_count
|
||||
except Exception as e:
|
||||
@ -941,17 +874,15 @@ class GeminiChat(Base):
|
||||
if k not in ["temperature", "top_p", "max_tokens"]:
|
||||
del gen_conf[k]
|
||||
for item in history:
|
||||
if 'role' in item and item['role'] == 'assistant':
|
||||
item['role'] = 'model'
|
||||
if 'content' in item:
|
||||
item['parts'] = item.pop('content')
|
||||
if "role" in item and item["role"] == "assistant":
|
||||
item["role"] = "model"
|
||||
if "content" in item:
|
||||
item["parts"] = item.pop("content")
|
||||
ans = ""
|
||||
try:
|
||||
response = self.model.generate_content(
|
||||
history,
|
||||
generation_config=gen_conf, stream=True)
|
||||
response = self.model.generate_content(history, generation_config=gen_conf, stream=True)
|
||||
for resp in response:
|
||||
ans += resp.text
|
||||
ans = resp.text
|
||||
yield ans
|
||||
|
||||
yield response._chunks[-1].usage_metadata.total_token_count
|
||||
@ -962,8 +893,9 @@ class GeminiChat(Base):
|
||||
|
||||
|
||||
class GroqChat(Base):
|
||||
def __init__(self, key, model_name, base_url=''):
|
||||
def __init__(self, key, model_name, base_url=""):
|
||||
from groq import Groq
|
||||
|
||||
self.client = Groq(api_key=key)
|
||||
self.model_name = model_name
|
||||
|
||||
@ -975,11 +907,7 @@ class GroqChat(Base):
|
||||
del gen_conf[k]
|
||||
ans = ""
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
**gen_conf
|
||||
)
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
|
||||
ans = response.choices[0].message.content
|
||||
if response.choices[0].finish_reason == "length":
|
||||
if is_chinese(ans):
|
||||
@ -999,16 +927,11 @@ class GroqChat(Base):
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
stream=True,
|
||||
**gen_conf
|
||||
)
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
|
||||
for resp in response:
|
||||
if not resp.choices or not resp.choices[0].delta.content:
|
||||
continue
|
||||
ans += resp.choices[0].delta.content
|
||||
ans = resp.choices[0].delta.content
|
||||
total_tokens += 1
|
||||
if resp.choices[0].finish_reason == "length":
|
||||
if is_chinese(ans):
|
||||
@ -1096,16 +1019,10 @@ class CoHereChat(Base):
|
||||
mes = history.pop()["message"]
|
||||
ans = ""
|
||||
try:
|
||||
response = self.client.chat(
|
||||
model=self.model_name, chat_history=history, message=mes, **gen_conf
|
||||
)
|
||||
response = self.client.chat(model=self.model_name, chat_history=history, message=mes, **gen_conf)
|
||||
ans = response.text
|
||||
if response.finish_reason == "MAX_TOKENS":
|
||||
ans += (
|
||||
"...\nFor the content length reason, it stopped, continue?"
|
||||
if is_english([ans])
|
||||
else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||
)
|
||||
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||
return (
|
||||
ans,
|
||||
response.meta.tokens.input_tokens + response.meta.tokens.output_tokens,
|
||||
@ -1133,20 +1050,14 @@ class CoHereChat(Base):
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
try:
|
||||
response = self.client.chat_stream(
|
||||
model=self.model_name, chat_history=history, message=mes, **gen_conf
|
||||
)
|
||||
response = self.client.chat_stream(model=self.model_name, chat_history=history, message=mes, **gen_conf)
|
||||
for resp in response:
|
||||
if resp.event_type == "text-generation":
|
||||
ans += resp.text
|
||||
ans = resp.text
|
||||
total_tokens += num_tokens_from_string(resp.text)
|
||||
elif resp.event_type == "stream-end":
|
||||
if resp.finish_reason == "MAX_TOKENS":
|
||||
ans += (
|
||||
"...\nFor the content length reason, it stopped, continue?"
|
||||
if is_english([ans])
|
||||
else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||
)
|
||||
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||
yield ans
|
||||
|
||||
except Exception as e:
|
||||
@ -1217,9 +1128,7 @@ class ReplicateChat(Base):
|
||||
del gen_conf["max_tokens"]
|
||||
if system:
|
||||
self.system = system
|
||||
prompt = "\n".join(
|
||||
[item["role"] + ":" + item["content"] for item in history[-5:]]
|
||||
)
|
||||
prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]])
|
||||
ans = ""
|
||||
try:
|
||||
response = self.client.run(
|
||||
@ -1236,9 +1145,7 @@ class ReplicateChat(Base):
|
||||
del gen_conf["max_tokens"]
|
||||
if system:
|
||||
self.system = system
|
||||
prompt = "\n".join(
|
||||
[item["role"] + ":" + item["content"] for item in history[-5:]]
|
||||
)
|
||||
prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]])
|
||||
ans = ""
|
||||
try:
|
||||
response = self.client.run(
|
||||
@ -1246,7 +1153,7 @@ class ReplicateChat(Base):
|
||||
input={"system_prompt": self.system, "prompt": prompt, **gen_conf},
|
||||
)
|
||||
for resp in response:
|
||||
ans += resp
|
||||
ans = resp
|
||||
yield ans
|
||||
|
||||
except Exception as e:
|
||||
@ -1268,10 +1175,10 @@ class HunyuanChat(Base):
|
||||
self.client = hunyuan_client.HunyuanClient(cred, "")
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
from tencentcloud.hunyuan.v20230901 import models
|
||||
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
|
||||
TencentCloudSDKException,
|
||||
)
|
||||
from tencentcloud.hunyuan.v20230901 import models
|
||||
|
||||
_gen_conf = {}
|
||||
_history = [{k.capitalize(): v for k, v in item.items()} for item in history]
|
||||
@ -1296,10 +1203,10 @@ class HunyuanChat(Base):
|
||||
return ans + "\n**ERROR**: " + str(e), 0
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
from tencentcloud.hunyuan.v20230901 import models
|
||||
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
|
||||
TencentCloudSDKException,
|
||||
)
|
||||
from tencentcloud.hunyuan.v20230901 import models
|
||||
|
||||
_gen_conf = {}
|
||||
_history = [{k.capitalize(): v for k, v in item.items()} for item in history]
|
||||
@ -1327,7 +1234,7 @@ class HunyuanChat(Base):
|
||||
resp = json.loads(resp["data"])
|
||||
if not resp["Choices"] or not resp["Choices"][0]["Delta"]["Content"]:
|
||||
continue
|
||||
ans += resp["Choices"][0]["Delta"]["Content"]
|
||||
ans = resp["Choices"][0]["Delta"]["Content"]
|
||||
total_tokens += 1
|
||||
|
||||
yield ans
|
||||
@ -1339,9 +1246,7 @@ class HunyuanChat(Base):
|
||||
|
||||
|
||||
class SparkChat(Base):
|
||||
def __init__(
|
||||
self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
|
||||
):
|
||||
def __init__(self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://spark-api-open.xf-yun.com/v1"
|
||||
model2version = {
|
||||
@ -1374,22 +1279,14 @@ class BaiduYiyanChat(Base):
|
||||
def chat(self, system, history, gen_conf):
|
||||
if system:
|
||||
self.system = system
|
||||
gen_conf["penalty_score"] = (
|
||||
(gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
|
||||
0)) / 2
|
||||
) + 1
|
||||
gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
ans = ""
|
||||
|
||||
try:
|
||||
response = self.client.do(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
system=self.system,
|
||||
**gen_conf
|
||||
).body
|
||||
ans = response['result']
|
||||
response = self.client.do(model=self.model_name, messages=history, system=self.system, **gen_conf).body
|
||||
ans = response["result"]
|
||||
return ans, self.total_token_count(response)
|
||||
|
||||
except Exception as e:
|
||||
@ -1398,26 +1295,17 @@ class BaiduYiyanChat(Base):
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
if system:
|
||||
self.system = system
|
||||
gen_conf["penalty_score"] = (
|
||||
(gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
|
||||
0)) / 2
|
||||
) + 1
|
||||
gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
|
||||
try:
|
||||
response = self.client.do(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
system=self.system,
|
||||
stream=True,
|
||||
**gen_conf
|
||||
)
|
||||
response = self.client.do(model=self.model_name, messages=history, system=self.system, stream=True, **gen_conf)
|
||||
for resp in response:
|
||||
resp = resp.body
|
||||
ans += resp['result']
|
||||
ans = resp["result"]
|
||||
total_tokens = self.total_token_count(resp)
|
||||
|
||||
yield ans
|
||||
@ -1458,11 +1346,7 @@ class AnthropicChat(Base):
|
||||
).to_dict()
|
||||
ans = response["content"][0]["text"]
|
||||
if response["stop_reason"] == "max_tokens":
|
||||
ans += (
|
||||
"...\nFor the content length reason, it stopped, continue?"
|
||||
if is_english([ans])
|
||||
else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||
)
|
||||
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||
return (
|
||||
ans,
|
||||
response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
|
||||
@ -1483,6 +1367,7 @@ class AnthropicChat(Base):
|
||||
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
reasoning_start = False
|
||||
try:
|
||||
response = self.client.messages.create(
|
||||
model=self.model_name,
|
||||
@ -1492,15 +1377,17 @@ class AnthropicChat(Base):
|
||||
**gen_conf,
|
||||
)
|
||||
for res in response:
|
||||
if res.type == 'content_block_delta':
|
||||
if res.type == "content_block_delta":
|
||||
if res.delta.type == "thinking_delta" and res.delta.thinking:
|
||||
if ans.find("<think>") < 0:
|
||||
ans += "<think>"
|
||||
ans = ans.replace("</think>", "")
|
||||
ans = ""
|
||||
if not reasoning_start:
|
||||
reasoning_start = True
|
||||
ans = "<think>"
|
||||
ans += res.delta.thinking + "</think>"
|
||||
else:
|
||||
reasoning_start = False
|
||||
text = res.delta.text
|
||||
ans += text
|
||||
ans = text
|
||||
total_tokens += num_tokens_from_string(text)
|
||||
yield ans
|
||||
except Exception as e:
|
||||
@ -1511,13 +1398,12 @@ class AnthropicChat(Base):
|
||||
|
||||
class GoogleChat(Base):
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
from google.oauth2 import service_account
|
||||
import base64
|
||||
|
||||
from google.oauth2 import service_account
|
||||
|
||||
key = json.loads(key)
|
||||
access_token = json.loads(
|
||||
base64.b64decode(key.get("google_service_account_key", ""))
|
||||
)
|
||||
access_token = json.loads(base64.b64decode(key.get("google_service_account_key", "")))
|
||||
project_id = key.get("google_project_id", "")
|
||||
region = key.get("google_region", "")
|
||||
|
||||
@ -1530,28 +1416,20 @@ class GoogleChat(Base):
|
||||
from google.auth.transport.requests import Request
|
||||
|
||||
if access_token:
|
||||
credits = service_account.Credentials.from_service_account_info(
|
||||
access_token, scopes=scopes
|
||||
)
|
||||
credits = service_account.Credentials.from_service_account_info(access_token, scopes=scopes)
|
||||
request = Request()
|
||||
credits.refresh(request)
|
||||
token = credits.token
|
||||
self.client = AnthropicVertex(
|
||||
region=region, project_id=project_id, access_token=token
|
||||
)
|
||||
self.client = AnthropicVertex(region=region, project_id=project_id, access_token=token)
|
||||
else:
|
||||
self.client = AnthropicVertex(region=region, project_id=project_id)
|
||||
else:
|
||||
from google.cloud import aiplatform
|
||||
import vertexai.generative_models as glm
|
||||
from google.cloud import aiplatform
|
||||
|
||||
if access_token:
|
||||
credits = service_account.Credentials.from_service_account_info(
|
||||
access_token
|
||||
)
|
||||
aiplatform.init(
|
||||
credentials=credits, project=project_id, location=region
|
||||
)
|
||||
credits = service_account.Credentials.from_service_account_info(access_token)
|
||||
aiplatform.init(credentials=credits, project=project_id, location=region)
|
||||
else:
|
||||
aiplatform.init(project=project_id, location=region)
|
||||
self.client = glm.GenerativeModel(model_name=self.model_name)
|
||||
@ -1573,15 +1451,10 @@ class GoogleChat(Base):
|
||||
).json()
|
||||
ans = response["content"][0]["text"]
|
||||
if response["stop_reason"] == "max_tokens":
|
||||
ans += (
|
||||
"...\nFor the content length reason, it stopped, continue?"
|
||||
if is_english([ans])
|
||||
else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||
)
|
||||
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||
return (
|
||||
ans,
|
||||
response["usage"]["input_tokens"]
|
||||
+ response["usage"]["output_tokens"],
|
||||
response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
|
||||
)
|
||||
except Exception as e:
|
||||
return "\n**ERROR**: " + str(e), 0
|
||||
@ -1598,9 +1471,7 @@ class GoogleChat(Base):
|
||||
if "content" in item:
|
||||
item["parts"] = item.pop("content")
|
||||
try:
|
||||
response = self.client.generate_content(
|
||||
history, generation_config=gen_conf
|
||||
)
|
||||
response = self.client.generate_content(history, generation_config=gen_conf)
|
||||
ans = response.text
|
||||
return ans, response.usage_metadata.total_token_count
|
||||
except Exception as e:
|
||||
@ -1627,7 +1498,7 @@ class GoogleChat(Base):
|
||||
res = res.decode("utf-8")
|
||||
if "content_block_delta" in res and "data" in res:
|
||||
text = json.loads(res[6:])["delta"]["text"]
|
||||
ans += text
|
||||
ans = text
|
||||
total_tokens += num_tokens_from_string(text)
|
||||
except Exception as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
@ -1647,11 +1518,9 @@ class GoogleChat(Base):
|
||||
item["parts"] = item.pop("content")
|
||||
ans = ""
|
||||
try:
|
||||
response = self.model.generate_content(
|
||||
history, generation_config=gen_conf, stream=True
|
||||
)
|
||||
response = self.model.generate_content(history, generation_config=gen_conf, stream=True)
|
||||
for resp in response:
|
||||
ans += resp.text
|
||||
ans = resp.text
|
||||
yield ans
|
||||
|
||||
except Exception as e:
|
||||
|