Refa: change LLM chat output from full to delta (incremental) (#6534)

### What problem does this PR solve?

Change LLM chat output from full to delta (incremental)

### Type of change

- [x] Refactoring
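
In practice this means the streaming endpoints stop re-sending the whole accumulated answer on every event and emit only the newly generated suffix. A minimal sketch of the idea (illustrative only, not code from this PR; `cumulative_chunks` is a hypothetical stand-in for a model stream that yields full answers):

```python
def to_delta(cumulative_chunks):
    """Turn cumulative answers ("Hel", "Hello", ...) into deltas ("Hel", "lo", ...)."""
    cache = ""
    for answer in cumulative_chunks:
        # Same idea as the endpoint's cache diffing below: if the new answer
        # extends the cached one, emit only the new suffix; otherwise re-emit it.
        incremental = answer[len(cache):] if answer.startswith(cache) else answer
        cache = answer
        if incremental:
            yield incremental

# list(to_delta(["Hel", "Hello", "Hello world"])) == ["Hel", "lo", " world"]
```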
Yongteng Lei authored on 2025-03-26 19:33:14 +08:00, committed by GitHub
parent 6599db1e99
commit df3890827d
3 changed files with 277 additions and 399 deletions

#### Changed file 1 of 3: SDK session and completion endpoints

```diff
@@ -13,31 +13,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import re
 import json
+import re
 import time
-from api.db import LLMType
-from api.db.services.conversation_service import ConversationService, iframe_completion
-from api.db.services.conversation_service import completion as rag_completion
-from api.db.services.canvas_service import completion as agent_completion
-from api.db.services.dialog_service import ask, chat
+from flask import Response, jsonify, request
 from agent.canvas import Canvas
-from api.db import StatusEnum
+from api.db import LLMType, StatusEnum
 from api.db.db_models import APIToken
 from api.db.services.api_service import API4ConversationService
 from api.db.services.canvas_service import UserCanvasService
-from api.db.services.dialog_service import DialogService
-from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.utils import get_uuid
-from api.utils.api_utils import get_error_data_result, validate_request
-from api.utils.api_utils import get_result, token_required
-from api.db.services.llm_service import LLMBundle
+from api.db.services.canvas_service import completion as agent_completion
+from api.db.services.conversation_service import ConversationService, iframe_completion
+from api.db.services.conversation_service import completion as rag_completion
+from api.db.services.dialog_service import DialogService, ask, chat
 from api.db.services.file_service import FileService
+from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.llm_service import LLMBundle
+from api.utils import get_uuid
+from api.utils.api_utils import get_error_data_result, get_result, token_required, validate_request
-from flask import jsonify, request, Response


-@manager.route('/chats/<chat_id>/sessions', methods=['POST'])  # noqa: F821
+@manager.route("/chats/<chat_id>/sessions", methods=["POST"])  # noqa: F821
 @token_required
 def create(tenant_id, chat_id):
     req = request.json
@@ -50,7 +48,7 @@ def create(tenant_id, chat_id):
         "dialog_id": req["dialog_id"],
         "name": req.get("name", "New session"),
         "message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}],
-        "user_id": req.get("user_id", "")
+        "user_id": req.get("user_id", ""),
     }
     if not conv.get("name"):
         return get_error_data_result(message="`name` can not be empty.")
@@ -59,20 +57,20 @@
     if not e:
         return get_error_data_result(message="Fail to create a session!")
     conv = conv.to_dict()
-    conv['messages'] = conv.pop("message")
+    conv["messages"] = conv.pop("message")
     conv["chat_id"] = conv.pop("dialog_id")
     del conv["reference"]
     return get_result(data=conv)


-@manager.route('/agents/<agent_id>/sessions', methods=['POST'])  # noqa: F821
+@manager.route("/agents/<agent_id>/sessions", methods=["POST"])  # noqa: F821
 @token_required
 def create_agent_session(tenant_id, agent_id):
     req = request.json
     if not request.is_json:
         req = request.form
     files = request.files
-    user_id = request.args.get('user_id', '')
+    user_id = request.args.get("user_id", "")
     e, cvs = UserCanvasService.get_by_id(agent_id)
     if not e:
@@ -113,7 +111,7 @@ def create_agent_session(tenant_id, agent_id):
                 ele.pop("value")
         else:
             if req is not None and req.get(ele["key"]):
-                ele["value"] = req[ele['key']]
+                ele["value"] = req[ele["key"]]
             else:
                 if "value" in ele:
                     ele.pop("value")
@@ -121,20 +119,13 @@
     for ans in canvas.run(stream=False):
         pass
     cvs.dsl = json.loads(str(canvas))
-    conv = {
-        "id": get_uuid(),
-        "dialog_id": cvs.id,
-        "user_id": user_id,
-        "message": [{"role": "assistant", "content": canvas.get_prologue()}],
-        "source": "agent",
-        "dsl": cvs.dsl
-    }
+    conv = {"id": get_uuid(), "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
     API4ConversationService.save(**conv)
     conv["agent_id"] = conv.pop("dialog_id")
     return get_result(data=conv)


-@manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT'])  # noqa: F821
+@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"])  # noqa: F821
 @token_required
 def update(tenant_id, chat_id, session_id):
     req = request.json
@@ -156,7 +147,7 @@ def update(tenant_id, chat_id, session_id):
     return get_result()


-@manager.route('/chats/<chat_id>/completions', methods=['POST'])  # noqa: F821
+@manager.route("/chats/<chat_id>/completions", methods=["POST"])  # noqa: F821
 @token_required
 def chat_completion(tenant_id, chat_id):
     req = request.json
@@ -185,7 +176,7 @@ def chat_completion(tenant_id, chat_id):
     return get_result(data=answer)


-@manager.route('/chats_openai/<chat_id>/chat/completions', methods=['POST'])  # noqa: F821
+@manager.route("/chats_openai/<chat_id>/chat/completions", methods=["POST"])  # noqa: F821
 @validate_request("model", "messages")  # noqa: F821
 @token_required
 def chat_completion_openai_like(tenant_id, chat_id):
@@ -260,35 +251,60 @@ def chat_completion_openai_like(tenant_id, chat_id):
     def streamed_response_generator(chat_id, dia, msg):
         token_used = 0
         answer_cache = ""
+        reasoning_cache = ""
         response = {
             "id": f"chatcmpl-{chat_id}",
-            "choices": [
-                {
-                    "delta": {
-                        "content": "",
-                        "role": "assistant",
-                        "function_call": None,
-                        "tool_calls": None
-                    },
-                    "finish_reason": None,
-                    "index": 0,
-                    "logprobs": None
-                }
-            ],
+            "choices": [{"delta": {"content": "", "role": "assistant", "function_call": None, "tool_calls": None, "reasoning_content": ""}, "finish_reason": None, "index": 0, "logprobs": None}],
             "created": int(time.time()),
             "model": "model",
             "object": "chat.completion.chunk",
             "system_fingerprint": "",
-            "usage": None
+            "usage": None,
         }

         try:
             for ans in chat(dia, msg, True):
                 answer = ans["answer"]
-                incremental = answer.replace(answer_cache, "", 1)
-                answer_cache = answer.rstrip("</think>")
-                token_used += len(incremental)
-                response["choices"][0]["delta"]["content"] = incremental
+
+                reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
+                if reasoning_match:
+                    reasoning_part = reasoning_match.group(1)
+                    content_part = answer[reasoning_match.end() :]
+                else:
+                    reasoning_part = ""
+                    content_part = answer
+
+                reasoning_incremental = ""
+                if reasoning_part:
+                    if reasoning_part.startswith(reasoning_cache):
+                        reasoning_incremental = reasoning_part.replace(reasoning_cache, "", 1)
+                    else:
+                        reasoning_incremental = reasoning_part
+                    reasoning_cache = reasoning_part
+
+                content_incremental = ""
+                if content_part:
+                    if content_part.startswith(answer_cache):
+                        content_incremental = content_part.replace(answer_cache, "", 1)
+                    else:
+                        content_incremental = content_part
+                    answer_cache = content_part
+
+                token_used += len(reasoning_incremental) + len(content_incremental)
+
+                if not any([reasoning_incremental, content_incremental]):
+                    continue
+
+                if reasoning_incremental:
+                    response["choices"][0]["delta"]["reasoning_content"] = reasoning_incremental
+                else:
+                    response["choices"][0]["delta"]["reasoning_content"] = None
+
+                if content_incremental:
+                    response["choices"][0]["delta"]["content"] = content_incremental
+                else:
+                    response["choices"][0]["delta"]["content"] = None
+
                 yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
         except Exception as e:
             response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e)
@@ -296,16 +312,12 @@ def chat_completion_openai_like(tenant_id, chat_id):
         # The last chunk
         response["choices"][0]["delta"]["content"] = None
+        response["choices"][0]["delta"]["reasoning_content"] = None
         response["choices"][0]["finish_reason"] = "stop"
-        response["usage"] = {
-            "prompt_tokens": len(prompt),
-            "completion_tokens": token_used,
-            "total_tokens": len(prompt) + token_used
-        }
+        response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
         yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
         yield "data:[DONE]\n\n"

     resp = Response(streamed_response_generator(chat_id, dia, msg), mimetype="text/event-stream")
     resp.headers.add_header("Cache-control", "no-cache")
     resp.headers.add_header("Connection", "keep-alive")
@@ -332,25 +344,15 @@ def chat_completion_openai_like(tenant_id, chat_id):
             "completion_tokens_details": {
                 "reasoning_tokens": context_token_used,
                 "accepted_prediction_tokens": len(content),
-                "rejected_prediction_tokens": 0  # 0 for simplicity
-            }
+                "rejected_prediction_tokens": 0,  # 0 for simplicity
+            },
         },
-        "choices": [
-            {
-                "message": {
-                    "role": "assistant",
-                    "content": content
-                },
-                "logprobs": None,
-                "finish_reason": "stop",
-                "index": 0
-            }
-        ]
+        "choices": [{"message": {"role": "assistant", "content": content}, "logprobs": None, "finish_reason": "stop", "index": 0}],
     }
     return jsonify(response)


-@manager.route('/agents/<agent_id>/completions', methods=['POST'])  # noqa: F821
+@manager.route("/agents/<agent_id>/completions", methods=["POST"])  # noqa: F821
 @token_required
 def agent_completions(tenant_id, agent_id):
     req = request.json
@@ -376,9 +378,7 @@ def agent_completions(tenant_id, agent_id):
             states = {field: current_dsl.get(field, []) for field in state_fields}
             current_dsl.update(new_dsl)
             current_dsl.update(states)
-            API4ConversationService.update_by_id(req["session_id"], {
-                "dsl": current_dsl
-            })
+            API4ConversationService.update_by_id(req["session_id"], {"dsl": current_dsl})
     else:
         req["question"] = ""
     if req.get("stream", True):
@@ -395,7 +395,7 @@
         return get_error_data_result(str(e))


-@manager.route('/chats/<chat_id>/sessions', methods=['GET'])  # noqa: F821
+@manager.route("/chats/<chat_id>/sessions", methods=["GET"])  # noqa: F821
 @token_required
 def list_session(tenant_id, chat_id):
     if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
@@ -414,7 +414,7 @@
     if not convs:
         return get_result(data=[])
     for conv in convs:
-        conv['messages'] = conv.pop("message")
+        conv["messages"] = conv.pop("message")
         infos = conv["messages"]
         for info in infos:
             if "prompt" in info:
@@ -448,7 +448,7 @@
     return get_result(data=convs)


-@manager.route('/agents/<agent_id>/sessions', methods=['GET'])  # noqa: F821
+@manager.route("/agents/<agent_id>/sessions", methods=["GET"])  # noqa: F821
 @token_required
 def list_agent_session(tenant_id, agent_id):
     if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
@@ -464,12 +464,11 @@
         desc = True
     # dsl defaults to True in all cases except for False and false
     include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
-    convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id,
-                                             user_id, include_dsl)
+    convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id, user_id, include_dsl)
     if not convs:
         return get_result(data=[])
     for conv in convs:
-        conv['messages'] = conv.pop("message")
+        conv["messages"] = conv.pop("message")
         infos = conv["messages"]
         for info in infos:
             if "prompt" in info:
@@ -502,7 +501,7 @@
     return get_result(data=convs)


-@manager.route('/chats/<chat_id>/sessions', methods=["DELETE"])  # noqa: F821
+@manager.route("/chats/<chat_id>/sessions", methods=["DELETE"])  # noqa: F821
 @token_required
 def delete(tenant_id, chat_id):
     if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
@@ -528,7 +527,7 @@
     return get_result()


-@manager.route('/agents/<agent_id>/sessions', methods=["DELETE"])  # noqa: F821
+@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"])  # noqa: F821
 @token_required
 def delete_agent_session(tenant_id, agent_id):
     req = request.json
@@ -560,7 +559,7 @@
     return get_result()


-@manager.route('/sessions/ask', methods=['POST'])  # noqa: F821
+@manager.route("/sessions/ask", methods=["POST"])  # noqa: F821
 @token_required
 def ask_about(tenant_id):
     req = request.json
@@ -586,9 +585,7 @@ def ask_about(tenant_id):
             for ans in ask(req["question"], req["kb_ids"], uid):
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
         except Exception as e:
-            yield "data:" + json.dumps({"code": 500, "message": str(e),
-                                        "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
-                                       ensure_ascii=False) + "\n\n"
+            yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
        yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

     resp = Response(stream(), mimetype="text/event-stream")
@@ -599,7 +596,7 @@ def ask_about(tenant_id):
     return resp


-@manager.route('/sessions/related_questions', methods=['POST'])  # noqa: F821
+@manager.route("/sessions/related_questions", methods=["POST"])  # noqa: F821
 @token_required
 def related_questions(tenant_id):
     req = request.json
@@ -631,18 +628,27 @@ Reason:
 - At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
 """
-    ans = chat_mdl.chat(prompt, [{"role": "user", "content": f"""
+    ans = chat_mdl.chat(
+        prompt,
+        [
+            {
+                "role": "user",
+                "content": f"""
 Keywords: {question}
 Related search terms:
-    """}], {"temperature": 0.9})
+    """,
+            }
+        ],
+        {"temperature": 0.9},
+    )
     return get_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])


-@manager.route('/chatbots/<dialog_id>/completions', methods=['POST'])  # noqa: F821
+@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"])  # noqa: F821
 def chatbot_completions(dialog_id):
     req = request.json

-    token = request.headers.get('Authorization').split()
+    token = request.headers.get("Authorization").split()
     if len(token) != 2:
         return get_error_data_result(message='Authorization is not valid!"')
     token = token[1]
@@ -665,11 +671,11 @@ def chatbot_completions(dialog_id):
     return get_result(data=answer)


-@manager.route('/agentbots/<agent_id>/completions', methods=['POST'])  # noqa: F821
+@manager.route("/agentbots/<agent_id>/completions", methods=["POST"])  # noqa: F821
 def agent_bot_completions(agent_id):
     req = request.json

-    token = request.headers.get('Authorization').split()
+    token = request.headers.get("Authorization").split()
     if len(token) != 2:
         return get_error_data_result(message='Authorization is not valid!"')
     token = token[1]
```
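
With this change, the OpenAI-compatible stream above emits true `chat.completion.chunk` deltas, with reasoning and answer text split into `reasoning_content` and `content` (each set to `None` when it carries no increment for that event). A sketch of a consumer, under the assumption of a plain SSE-over-HTTP client; the base URL, chat id, and API key are placeholders, not values from this PR:

```python
import json
import requests

# Placeholders: substitute your deployment's base URL, chat id, and API key.
url = "<base_url>/chats_openai/<chat_id>/chat/completions"
headers = {"Authorization": "Bearer <api_key>"}
body = {"model": "model", "messages": [{"role": "user", "content": "Hi!"}], "stream": True}

with requests.post(url, headers=headers, json=body, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        # Each event is a "data:<json>" line; the stream ends with "data:[DONE]".
        if not line or not line.startswith("data:"):
            continue
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":
            break
        delta = json.loads(payload)["choices"][0]["delta"]
        if delta.get("reasoning_content"):
            print(delta["reasoning_content"], end="", flush=True)
        if delta.get("content"):
            print(delta["content"], end="", flush=True)
```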

#### Changed file 2 of 3: LLMBundle.chat_streamly

```diff
@@ -324,15 +324,18 @@ class LLMBundle:
         if self.langfuse:
             generation = self.trace.generation(name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})

-        output = ""
+        ans = ""
         for txt in self.mdl.chat_streamly(system, history, gen_conf):
             if isinstance(txt, int):
                 if self.langfuse:
-                    generation.end(output={"output": output})
+                    generation.end(output={"output": ans})

                 if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
                     logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
-                return
+                return ans

-            output = txt
-            yield txt
+            if txt.endswith("</think>"):
+                ans = ans.rstrip("</think>")
+            ans += txt
+            yield ans
```
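
Note the division of labor implied by this diff and the next one: the provider wrappers (file 3) now yield bare per-chunk deltas, while `LLMBundle.chat_streamly` re-accumulates them so its internal callers keep seeing a growing answer. A toy illustration of that layering, where the fake `provider_stream` stands in for `self.mdl.chat_streamly(...)`:

```python
def provider_stream():
    # Stands in for a model wrapper: yields text deltas, then an int token count.
    yield "Hel"
    yield "lo"
    yield 2

def bundle_stream():
    # Mirrors LLMBundle.chat_streamly's shape: accumulate deltas for callers.
    ans = ""
    for txt in provider_stream():
        if isinstance(txt, int):  # end of stream: token usage accounting
            return
        ans += txt
        yield ans  # "Hel", then "Hello"
```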

View File

@ -1,5 +1,5 @@
# #
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. # Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,25 +13,25 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import re import asyncio
import json
import logging
import os
import random import random
import re
import time
from abc import ABC
import openai
import requests
from dashscope import Generation
from ollama import Client
from openai import OpenAI
from openai.lib.azure import AzureOpenAI from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI from zhipuai import ZhipuAI
from dashscope import Generation
from abc import ABC
from openai import OpenAI
import openai
from ollama import Client
from rag.nlp import is_chinese, is_english from rag.nlp import is_chinese, is_english
from rag.utils import num_tokens_from_string from rag.utils import num_tokens_from_string
import os
import json
import requests
import asyncio
import logging
import time
# Error message constants # Error message constants
ERROR_PREFIX = "**ERROR**" ERROR_PREFIX = "**ERROR**"
@ -53,12 +53,12 @@ LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to
class Base(ABC): class Base(ABC):
def __init__(self, key, model_name, base_url): def __init__(self, key, model_name, base_url):
timeout = int(os.environ.get('LM_TIMEOUT_SECONDS', 600)) timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout) self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
self.model_name = model_name self.model_name = model_name
# Configure retry parameters # Configure retry parameters
self.max_retries = int(os.environ.get('LLM_MAX_RETRIES', 5)) self.max_retries = int(os.environ.get("LLM_MAX_RETRIES", 5))
self.base_delay = float(os.environ.get('LLM_BASE_DELAY', 2.0)) self.base_delay = float(os.environ.get("LLM_BASE_DELAY", 2.0))
def _get_delay(self, attempt): def _get_delay(self, attempt):
"""Calculate retry delay time""" """Calculate retry delay time"""
@ -98,10 +98,7 @@ class Base(ABC):
# Implement exponential backoff retry strategy # Implement exponential backoff retry strategy
for attempt in range(self.max_retries): for attempt in range(self.max_retries):
try: try:
response = self.client.chat.completions.create( response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
model=self.model_name,
messages=history,
**gen_conf)
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]): if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
return "", 0 return "", 0
@ -136,24 +133,23 @@ class Base(ABC):
del gen_conf["max_tokens"] del gen_conf["max_tokens"]
ans = "" ans = ""
total_tokens = 0 total_tokens = 0
reasoning_start = False
try: try:
response = self.client.chat.completions.create( response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
model=self.model_name,
messages=history,
stream=True,
**gen_conf)
for resp in response: for resp in response:
if not resp.choices: if not resp.choices:
continue continue
if not resp.choices[0].delta.content: if not resp.choices[0].delta.content:
resp.choices[0].delta.content = "" resp.choices[0].delta.content = ""
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content: if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
if ans.find("<think>") < 0: ans = ""
ans += "<think>" if not reasoning_start:
ans = ans.replace("</think>", "") reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>" ans += resp.choices[0].delta.reasoning_content + "</think>"
else: else:
ans += resp.choices[0].delta.content reasoning_start = False
ans = resp.choices[0].delta.content
tol = self.total_token_count(resp) tol = self.total_token_count(resp)
if not tol: if not tol:
@ -221,7 +217,7 @@ class ModelScopeChat(Base):
def __init__(self, key=None, model_name="", base_url=""): def __init__(self, key=None, model_name="", base_url=""):
if not base_url: if not base_url:
raise ValueError("Local llm url cannot be None") raise ValueError("Local llm url cannot be None")
base_url = base_url.rstrip('/') base_url = base_url.rstrip("/")
if base_url.split("/")[-1] != "v1": if base_url.split("/")[-1] != "v1":
base_url = os.path.join(base_url, "v1") base_url = os.path.join(base_url, "v1")
super().__init__(key, model_name.split("___")[0], base_url) super().__init__(key, model_name.split("___")[0], base_url)
@ -236,8 +232,8 @@ class DeepSeekChat(Base):
class AzureChat(Base): class AzureChat(Base):
def __init__(self, key, model_name, **kwargs): def __init__(self, key, model_name, **kwargs):
api_key = json.loads(key).get('api_key', '') api_key = json.loads(key).get("api_key", "")
api_version = json.loads(key).get('api_version', '2024-02-01') api_version = json.loads(key).get("api_version", "2024-02-01")
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version) self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
self.model_name = model_name self.model_name = model_name
@ -264,16 +260,9 @@ class BaiChuanChat(Base):
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=self.model_name, model=self.model_name,
messages=history, messages=history,
extra_body={ extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
"tools": [{ **self._format_params(gen_conf),
"type": "web_search", )
"web_search": {
"enable": True,
"search_mode": "performance_first"
}
}]
},
**self._format_params(gen_conf))
ans = response.choices[0].message.content.strip() ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length": if response.choices[0].finish_reason == "length":
if is_chinese([ans]): if is_chinese([ans]):
@ -295,23 +284,16 @@ class BaiChuanChat(Base):
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=self.model_name, model=self.model_name,
messages=history, messages=history,
extra_body={ extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
"tools": [{
"type": "web_search",
"web_search": {
"enable": True,
"search_mode": "performance_first"
}
}]
},
stream=True, stream=True,
**self._format_params(gen_conf)) **self._format_params(gen_conf),
)
for resp in response: for resp in response:
if not resp.choices: if not resp.choices:
continue continue
if not resp.choices[0].delta.content: if not resp.choices[0].delta.content:
resp.choices[0].delta.content = "" resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content ans = resp.choices[0].delta.content
tol = self.total_token_count(resp) tol = self.total_token_count(resp)
if not tol: if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content) total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
@ -333,6 +315,7 @@ class BaiChuanChat(Base):
class QWenChat(Base): class QWenChat(Base):
def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs): def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
import dashscope import dashscope
dashscope.api_key = key dashscope.api_key = key
self.model_name = model_name self.model_name = model_name
if self.is_reasoning_model(self.model_name): if self.is_reasoning_model(self.model_name):
@ -344,22 +327,18 @@ class QWenChat(Base):
if self.is_reasoning_model(self.model_name): if self.is_reasoning_model(self.model_name):
return super().chat(system, history, gen_conf) return super().chat(system, history, gen_conf)
stream_flag = str(os.environ.get('QWEN_CHAT_BY_STREAM', 'true')).lower() == 'true' stream_flag = str(os.environ.get("QWEN_CHAT_BY_STREAM", "true")).lower() == "true"
if not stream_flag: if not stream_flag:
from http import HTTPStatus from http import HTTPStatus
if system: if system:
history.insert(0, {"role": "system", "content": system}) history.insert(0, {"role": "system", "content": system})
response = Generation.call( response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
self.model_name,
messages=history,
result_format='message',
**gen_conf
)
ans = "" ans = ""
tk_count = 0 tk_count = 0
if response.status_code == HTTPStatus.OK: if response.status_code == HTTPStatus.OK:
ans += response.output.choices[0]['message']['content'] ans += response.output.choices[0]["message"]["content"]
tk_count += self.total_token_count(response) tk_count += self.total_token_count(response)
if response.output.choices[0].get("finish_reason", "") == "length": if response.output.choices[0].get("finish_reason", "") == "length":
if is_chinese([ans]): if is_chinese([ans]):
@ -378,8 +357,9 @@ class QWenChat(Base):
else: else:
return "".join(result_list[:-1]), result_list[-1] return "".join(result_list[:-1]), result_list[-1]
def _chat_streamly(self, system, history, gen_conf, incremental_output=False): def _chat_streamly(self, system, history, gen_conf, incremental_output=True):
from http import HTTPStatus from http import HTTPStatus
if system: if system:
history.insert(0, {"role": "system", "content": system}) history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf: if "max_tokens" in gen_conf:
@ -387,17 +367,10 @@ class QWenChat(Base):
ans = "" ans = ""
tk_count = 0 tk_count = 0
try: try:
response = Generation.call( response = Generation.call(self.model_name, messages=history, result_format="message", stream=True, incremental_output=incremental_output, **gen_conf)
self.model_name,
messages=history,
result_format='message',
stream=True,
incremental_output=incremental_output,
**gen_conf
)
for resp in response: for resp in response:
if resp.status_code == HTTPStatus.OK: if resp.status_code == HTTPStatus.OK:
ans = resp.output.choices[0]['message']['content'] ans = resp.output.choices[0]["message"]["content"]
tk_count = self.total_token_count(resp) tk_count = self.total_token_count(resp)
if resp.output.choices[0].get("finish_reason", "") == "length": if resp.output.choices[0].get("finish_reason", "") == "length":
if is_chinese(ans): if is_chinese(ans):
@ -406,8 +379,11 @@ class QWenChat(Base):
ans += LENGTH_NOTIFICATION_EN ans += LENGTH_NOTIFICATION_EN
yield ans yield ans
else: else:
yield ans + "\n**ERROR**: " + resp.message if not re.search(r" (key|quota)", yield (
str(resp.message).lower()) else "Out of credit. Please set the API key in **settings > Model providers.**" ans + "\n**ERROR**: " + resp.message
if not re.search(r" (key|quota)", str(resp.message).lower())
else "Out of credit. Please set the API key in **settings > Model providers.**"
)
except Exception as e: except Exception as e:
yield ans + "\n**ERROR**: " + str(e) yield ans + "\n**ERROR**: " + str(e)
@ -423,10 +399,12 @@ class QWenChat(Base):
@staticmethod @staticmethod
def is_reasoning_model(model_name: str) -> bool: def is_reasoning_model(model_name: str) -> bool:
return any([ return any(
[
model_name.lower().find("deepseek") >= 0, model_name.lower().find("deepseek") >= 0,
model_name.lower().find("qwq") >= 0 and model_name.lower() != 'qwq-32b-preview', model_name.lower().find("qwq") >= 0 and model_name.lower() != "qwq-32b-preview",
]) ]
)
class ZhipuChat(Base): class ZhipuChat(Base):
@ -444,11 +422,7 @@ class ZhipuChat(Base):
del gen_conf["presence_penalty"] del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"] del gen_conf["frequency_penalty"]
response = self.client.chat.completions.create( response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
model=self.model_name,
messages=history,
**gen_conf
)
ans = response.choices[0].message.content.strip() ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length": if response.choices[0].finish_reason == "length":
if is_chinese(ans): if is_chinese(ans):
@ -471,17 +445,12 @@ class ZhipuChat(Base):
ans = "" ans = ""
tk_count = 0 tk_count = 0
try: try:
response = self.client.chat.completions.create( response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
model=self.model_name,
messages=history,
stream=True,
**gen_conf
)
for resp in response: for resp in response:
if not resp.choices[0].delta.content: if not resp.choices[0].delta.content:
continue continue
delta = resp.choices[0].delta.content delta = resp.choices[0].delta.content
ans += delta ans = delta
if resp.choices[0].finish_reason == "length": if resp.choices[0].finish_reason == "length":
if is_chinese(ans): if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN ans += LENGTH_NOTIFICATION_CN
@ -499,8 +468,7 @@ class ZhipuChat(Base):
class OllamaChat(Base): class OllamaChat(Base):
def __init__(self, key, model_name, **kwargs): def __init__(self, key, model_name, **kwargs):
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else \ self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
self.model_name = model_name self.model_name = model_name
def chat(self, system, history, gen_conf): def chat(self, system, history, gen_conf):
@ -509,9 +477,7 @@ class OllamaChat(Base):
if "max_tokens" in gen_conf: if "max_tokens" in gen_conf:
del gen_conf["max_tokens"] del gen_conf["max_tokens"]
try: try:
options = { options = {"num_ctx": 32768}
"num_ctx": 32768
}
if "temperature" in gen_conf: if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"] options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf: if "max_tokens" in gen_conf:
@ -522,12 +488,7 @@ class OllamaChat(Base):
options["presence_penalty"] = gen_conf["presence_penalty"] options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"] options["frequency_penalty"] = gen_conf["frequency_penalty"]
response = self.client.chat( response = self.client.chat(model=self.model_name, messages=history, options=options, keep_alive=-1)
model=self.model_name,
messages=history,
options=options,
keep_alive=-1
)
ans = response["message"]["content"].strip() ans = response["message"]["content"].strip()
return ans, response.get("eval_count", 0) + response.get("prompt_eval_count", 0) return ans, response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
except Exception as e: except Exception as e:
@ -551,17 +512,11 @@ class OllamaChat(Base):
options["frequency_penalty"] = gen_conf["frequency_penalty"] options["frequency_penalty"] = gen_conf["frequency_penalty"]
ans = "" ans = ""
try: try:
response = self.client.chat( response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=-1)
model=self.model_name,
messages=history,
stream=True,
options=options,
keep_alive=-1
)
for resp in response: for resp in response:
if resp["done"]: if resp["done"]:
yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0) yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
ans += resp["message"]["content"] ans = resp["message"]["content"]
yield ans yield ans
except Exception as e: except Exception as e:
yield ans + "\n**ERROR**: " + str(e) yield ans + "\n**ERROR**: " + str(e)
@ -588,9 +543,7 @@ class LocalLLM(Base):
def __conn(self): def __conn(self):
from multiprocessing.connection import Client from multiprocessing.connection import Client
self._connection = Client( self._connection = Client((self.host, self.port), authkey=b"infiniflow-token4kevinhu")
(self.host, self.port), authkey=b"infiniflow-token4kevinhu"
)
def __getattr__(self, name): def __getattr__(self, name):
import pickle import pickle
@ -613,17 +566,17 @@ class LocalLLM(Base):
def _prepare_prompt(self, system, history, gen_conf): def _prepare_prompt(self, system, history, gen_conf):
from rag.svr.jina_server import Prompt from rag.svr.jina_server import Prompt
if system: if system:
history.insert(0, {"role": "system", "content": system}) history.insert(0, {"role": "system", "content": system})
return Prompt(message=history, gen_conf=gen_conf) return Prompt(message=history, gen_conf=gen_conf)
def _stream_response(self, endpoint, prompt): def _stream_response(self, endpoint, prompt):
from rag.svr.jina_server import Generation from rag.svr.jina_server import Generation
answer = "" answer = ""
try: try:
res = self.client.stream_doc( res = self.client.stream_doc(on=endpoint, inputs=prompt, return_type=Generation)
on=endpoint, inputs=prompt, return_type=Generation
)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
try: try:
while True: while True:
@ -652,15 +605,15 @@ class LocalLLM(Base):
class VolcEngineChat(Base): class VolcEngineChat(Base):
def __init__(self, key, model_name, base_url='https://ark.cn-beijing.volces.com/api/v3'): def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
""" """
Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special, Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use
model_name is for display only model_name is for display only
""" """
base_url = base_url if base_url else 'https://ark.cn-beijing.volces.com/api/v3' base_url = base_url if base_url else "https://ark.cn-beijing.volces.com/api/v3"
ark_api_key = json.loads(key).get('ark_api_key', '') ark_api_key = json.loads(key).get("ark_api_key", "")
model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '') model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
super().__init__(ark_api_key, model_name, base_url) super().__init__(ark_api_key, model_name, base_url)
@ -687,13 +640,9 @@ class MiniMaxChat(Base):
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json", "Content-Type": "application/json",
} }
payload = json.dumps( payload = json.dumps({"model": self.model_name, "messages": history, **gen_conf})
{"model": self.model_name, "messages": history, **gen_conf}
)
try: try:
response = requests.request( response = requests.request("POST", url=self.base_url, headers=headers, data=payload)
"POST", url=self.base_url, headers=headers, data=payload
)
response = response.json() response = response.json()
ans = response["choices"][0]["message"]["content"].strip() ans = response["choices"][0]["message"]["content"].strip()
if response["choices"][0]["finish_reason"] == "length": if response["choices"][0]["finish_reason"] == "length":
@ -737,7 +686,7 @@ class MiniMaxChat(Base):
text = "" text = ""
if "choices" in resp and "delta" in resp["choices"][0]: if "choices" in resp and "delta" in resp["choices"][0]:
text = resp["choices"][0]["delta"]["content"] text = resp["choices"][0]["delta"]["content"]
ans += text ans = text
tol = self.total_token_count(resp) tol = self.total_token_count(resp)
if not tol: if not tol:
total_tokens += num_tokens_from_string(text) total_tokens += num_tokens_from_string(text)
@ -752,9 +701,9 @@ class MiniMaxChat(Base):
class MistralChat(Base): class MistralChat(Base):
def __init__(self, key, model_name, base_url=None): def __init__(self, key, model_name, base_url=None):
from mistralai.client import MistralClient from mistralai.client import MistralClient
self.client = MistralClient(api_key=key) self.client = MistralClient(api_key=key)
self.model_name = model_name self.model_name = model_name
@ -765,10 +714,7 @@ class MistralChat(Base):
if k not in ["temperature", "top_p", "max_tokens"]: if k not in ["temperature", "top_p", "max_tokens"]:
del gen_conf[k] del gen_conf[k]
try: try:
response = self.client.chat( response = self.client.chat(model=self.model_name, messages=history, **gen_conf)
model=self.model_name,
messages=history,
**gen_conf)
ans = response.choices[0].message.content ans = response.choices[0].message.content
if response.choices[0].finish_reason == "length": if response.choices[0].finish_reason == "length":
if is_chinese(ans): if is_chinese(ans):
@ -788,14 +734,11 @@ class MistralChat(Base):
ans = "" ans = ""
total_tokens = 0 total_tokens = 0
try: try:
response = self.client.chat_stream( response = self.client.chat_stream(model=self.model_name, messages=history, **gen_conf)
model=self.model_name,
messages=history,
**gen_conf)
for resp in response: for resp in response:
if not resp.choices or not resp.choices[0].delta.content: if not resp.choices or not resp.choices[0].delta.content:
continue continue
ans += resp.choices[0].delta.content ans = resp.choices[0].delta.content
total_tokens += 1 total_tokens += 1
if resp.choices[0].finish_reason == "length": if resp.choices[0].finish_reason == "length":
if is_chinese(ans): if is_chinese(ans):
@ -811,23 +754,23 @@ class MistralChat(Base):
class BedrockChat(Base): class BedrockChat(Base):
def __init__(self, key, model_name, **kwargs): def __init__(self, key, model_name, **kwargs):
import boto3 import boto3
self.bedrock_ak = json.loads(key).get('bedrock_ak', '')
self.bedrock_sk = json.loads(key).get('bedrock_sk', '') self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
self.bedrock_region = json.loads(key).get('bedrock_region', '') self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
self.bedrock_region = json.loads(key).get("bedrock_region", "")
self.model_name = model_name self.model_name = model_name
if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '': if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.) # Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
self.client = boto3.client('bedrock-runtime') self.client = boto3.client("bedrock-runtime")
else: else:
self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region, self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
def chat(self, system, history, gen_conf): def chat(self, system, history, gen_conf):
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
for k in list(gen_conf.keys()): for k in list(gen_conf.keys()):
if k not in ["temperature"]: if k not in ["temperature"]:
del gen_conf[k] del gen_conf[k]
@ -853,6 +796,7 @@ class BedrockChat(Base):
def chat_streamly(self, system, history, gen_conf): def chat_streamly(self, system, history, gen_conf):
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
for k in list(gen_conf.keys()): for k in list(gen_conf.keys()):
if k not in ["temperature"]: if k not in ["temperature"]:
del gen_conf[k] del gen_conf[k]
@ -860,14 +804,9 @@ class BedrockChat(Base):
if not isinstance(item["content"], list) and not isinstance(item["content"], tuple): if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
item["content"] = [{"text": item["content"]}] item["content"] = [{"text": item["content"]}]
if self.model_name.split('.')[0] == 'ai21': if self.model_name.split(".")[0] == "ai21":
try: try:
response = self.client.converse( response = self.client.converse(modelId=self.model_name, messages=history, inferenceConfig=gen_conf, system=[{"text": (system if system else "Answer the user's message.")}])
modelId=self.model_name,
messages=history,
inferenceConfig=gen_conf,
system=[{"text": (system if system else "Answer the user's message.")}]
)
ans = response["output"]["message"]["content"][0]["text"] ans = response["output"]["message"]["content"][0]["text"]
return ans, num_tokens_from_string(ans) return ans, num_tokens_from_string(ans)
@ -878,16 +817,13 @@ class BedrockChat(Base):
try: try:
# Send the message to the model, using a basic inference configuration. # Send the message to the model, using a basic inference configuration.
streaming_response = self.client.converse_stream( streaming_response = self.client.converse_stream(
modelId=self.model_name, modelId=self.model_name, messages=history, inferenceConfig=gen_conf, system=[{"text": (system if system else "Answer the user's message.")}]
messages=history,
inferenceConfig=gen_conf,
system=[{"text": (system if system else "Answer the user's message.")}]
) )
# Extract and print the streamed response text in real-time. # Extract and print the streamed response text in real-time.
for resp in streaming_response["stream"]: for resp in streaming_response["stream"]:
if "contentBlockDelta" in resp: if "contentBlockDelta" in resp:
ans += resp["contentBlockDelta"]["delta"]["text"] ans = resp["contentBlockDelta"]["delta"]["text"]
yield ans yield ans
except (ClientError, Exception) as e: except (ClientError, Exception) as e:
@ -897,13 +833,12 @@ class BedrockChat(Base):
class GeminiChat(Base): class GeminiChat(Base):
def __init__(self, key, model_name, base_url=None): def __init__(self, key, model_name, base_url=None):
from google.generativeai import client, GenerativeModel from google.generativeai import GenerativeModel, client
client.configure(api_key=key) client.configure(api_key=key)
_client = client.get_default_generative_client() _client = client.get_default_generative_client()
self.model_name = 'models/' + model_name self.model_name = "models/" + model_name
self.model = GenerativeModel(model_name=self.model_name) self.model = GenerativeModel(model_name=self.model_name)
self.model._client = _client self.model._client = _client
@ -916,17 +851,15 @@ class GeminiChat(Base):
if k not in ["temperature", "top_p", "max_tokens"]: if k not in ["temperature", "top_p", "max_tokens"]:
del gen_conf[k] del gen_conf[k]
for item in history: for item in history:
if 'role' in item and item['role'] == 'assistant': if "role" in item and item["role"] == "assistant":
item['role'] = 'model' item["role"] = "model"
if 'role' in item and item['role'] == 'system': if "role" in item and item["role"] == "system":
item['role'] = 'user' item["role"] = "user"
if 'content' in item: if "content" in item:
item['parts'] = item.pop('content') item["parts"] = item.pop("content")
try: try:
response = self.model.generate_content( response = self.model.generate_content(history, generation_config=gen_conf)
history,
generation_config=gen_conf)
ans = response.text ans = response.text
return ans, response.usage_metadata.total_token_count return ans, response.usage_metadata.total_token_count
except Exception as e: except Exception as e:
@ -941,17 +874,15 @@ class GeminiChat(Base):
if k not in ["temperature", "top_p", "max_tokens"]: if k not in ["temperature", "top_p", "max_tokens"]:
del gen_conf[k] del gen_conf[k]
for item in history: for item in history:
if 'role' in item and item['role'] == 'assistant': if "role" in item and item["role"] == "assistant":
item['role'] = 'model' item["role"] = "model"
if 'content' in item: if "content" in item:
item['parts'] = item.pop('content') item["parts"] = item.pop("content")
ans = "" ans = ""
try: try:
response = self.model.generate_content( response = self.model.generate_content(history, generation_config=gen_conf, stream=True)
history,
generation_config=gen_conf, stream=True)
for resp in response: for resp in response:
ans += resp.text ans = resp.text
yield ans yield ans
yield response._chunks[-1].usage_metadata.total_token_count yield response._chunks[-1].usage_metadata.total_token_count
@ -962,8 +893,9 @@ class GeminiChat(Base):
class GroqChat(Base): class GroqChat(Base):
def __init__(self, key, model_name, base_url=''): def __init__(self, key, model_name, base_url=""):
from groq import Groq from groq import Groq
self.client = Groq(api_key=key) self.client = Groq(api_key=key)
self.model_name = model_name self.model_name = model_name
@ -975,11 +907,7 @@ class GroqChat(Base):
del gen_conf[k] del gen_conf[k]
ans = "" ans = ""
try: try:
response = self.client.chat.completions.create( response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
model=self.model_name,
messages=history,
**gen_conf
)
ans = response.choices[0].message.content ans = response.choices[0].message.content
if response.choices[0].finish_reason == "length": if response.choices[0].finish_reason == "length":
if is_chinese(ans): if is_chinese(ans):
@ -999,16 +927,11 @@ class GroqChat(Base):
ans = "" ans = ""
total_tokens = 0 total_tokens = 0
try: try:
response = self.client.chat.completions.create( response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
model=self.model_name,
messages=history,
stream=True,
**gen_conf
)
for resp in response: for resp in response:
if not resp.choices or not resp.choices[0].delta.content: if not resp.choices or not resp.choices[0].delta.content:
continue continue
ans += resp.choices[0].delta.content ans = resp.choices[0].delta.content
total_tokens += 1 total_tokens += 1
if resp.choices[0].finish_reason == "length": if resp.choices[0].finish_reason == "length":
if is_chinese(ans): if is_chinese(ans):
@ -1096,16 +1019,10 @@ class CoHereChat(Base):
mes = history.pop()["message"] mes = history.pop()["message"]
ans = "" ans = ""
try: try:
response = self.client.chat( response = self.client.chat(model=self.model_name, chat_history=history, message=mes, **gen_conf)
model=self.model_name, chat_history=history, message=mes, **gen_conf
)
ans = response.text ans = response.text
if response.finish_reason == "MAX_TOKENS": if response.finish_reason == "MAX_TOKENS":
ans += ( ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
"...\nFor the content length reason, it stopped, continue?"
if is_english([ans])
else "······\n由于长度的原因,回答被截断了,要继续吗?"
)
return ( return (
ans, ans,
response.meta.tokens.input_tokens + response.meta.tokens.output_tokens, response.meta.tokens.input_tokens + response.meta.tokens.output_tokens,
@ -1133,20 +1050,14 @@ class CoHereChat(Base):
ans = "" ans = ""
total_tokens = 0 total_tokens = 0
try: try:
response = self.client.chat_stream( response = self.client.chat_stream(model=self.model_name, chat_history=history, message=mes, **gen_conf)
model=self.model_name, chat_history=history, message=mes, **gen_conf
)
for resp in response: for resp in response:
if resp.event_type == "text-generation": if resp.event_type == "text-generation":
ans += resp.text ans = resp.text
total_tokens += num_tokens_from_string(resp.text) total_tokens += num_tokens_from_string(resp.text)
elif resp.event_type == "stream-end": elif resp.event_type == "stream-end":
if resp.finish_reason == "MAX_TOKENS": if resp.finish_reason == "MAX_TOKENS":
ans += ( ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
"...\nFor the content length reason, it stopped, continue?"
if is_english([ans])
else "······\n由于长度的原因,回答被截断了,要继续吗?"
)
yield ans yield ans
except Exception as e: except Exception as e:
@ -1217,9 +1128,7 @@ class ReplicateChat(Base):
del gen_conf["max_tokens"] del gen_conf["max_tokens"]
if system: if system:
self.system = system self.system = system
prompt = "\n".join( prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]])
[item["role"] + ":" + item["content"] for item in history[-5:]]
)
ans = "" ans = ""
try: try:
response = self.client.run( response = self.client.run(
@ -1236,9 +1145,7 @@ class ReplicateChat(Base):
del gen_conf["max_tokens"] del gen_conf["max_tokens"]
if system: if system:
self.system = system self.system = system
prompt = "\n".join( prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]])
[item["role"] + ":" + item["content"] for item in history[-5:]]
)
ans = "" ans = ""
try: try:
response = self.client.run( response = self.client.run(
@ -1246,7 +1153,7 @@ class ReplicateChat(Base):
input={"system_prompt": self.system, "prompt": prompt, **gen_conf}, input={"system_prompt": self.system, "prompt": prompt, **gen_conf},
) )
for resp in response: for resp in response:
ans += resp ans = resp
yield ans yield ans
except Exception as e: except Exception as e:
@ -1268,10 +1175,10 @@ class HunyuanChat(Base):
self.client = hunyuan_client.HunyuanClient(cred, "") self.client = hunyuan_client.HunyuanClient(cred, "")
def chat(self, system, history, gen_conf): def chat(self, system, history, gen_conf):
from tencentcloud.hunyuan.v20230901 import models
from tencentcloud.common.exception.tencent_cloud_sdk_exception import ( from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
TencentCloudSDKException, TencentCloudSDKException,
) )
from tencentcloud.hunyuan.v20230901 import models
_gen_conf = {} _gen_conf = {}
_history = [{k.capitalize(): v for k, v in item.items()} for item in history] _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
@ -1296,10 +1203,10 @@ class HunyuanChat(Base):
return ans + "\n**ERROR**: " + str(e), 0 return ans + "\n**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf): def chat_streamly(self, system, history, gen_conf):
from tencentcloud.hunyuan.v20230901 import models
from tencentcloud.common.exception.tencent_cloud_sdk_exception import ( from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
TencentCloudSDKException, TencentCloudSDKException,
) )
from tencentcloud.hunyuan.v20230901 import models
_gen_conf = {} _gen_conf = {}
_history = [{k.capitalize(): v for k, v in item.items()} for item in history] _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
@ -1327,7 +1234,7 @@ class HunyuanChat(Base):
resp = json.loads(resp["data"]) resp = json.loads(resp["data"])
if not resp["Choices"] or not resp["Choices"][0]["Delta"]["Content"]: if not resp["Choices"] or not resp["Choices"][0]["Delta"]["Content"]:
continue continue
ans += resp["Choices"][0]["Delta"]["Content"] ans = resp["Choices"][0]["Delta"]["Content"]
total_tokens += 1 total_tokens += 1
yield ans yield ans
@ -1339,9 +1246,7 @@ class HunyuanChat(Base):
class SparkChat(Base): class SparkChat(Base):
def __init__( def __init__(self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"):
self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
):
if not base_url: if not base_url:
base_url = "https://spark-api-open.xf-yun.com/v1" base_url = "https://spark-api-open.xf-yun.com/v1"
model2version = { model2version = {
@ -1374,22 +1279,14 @@ class BaiduYiyanChat(Base):
def chat(self, system, history, gen_conf): def chat(self, system, history, gen_conf):
if system: if system:
self.system = system self.system = system
gen_conf["penalty_score"] = ( gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
(gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
0)) / 2
) + 1
if "max_tokens" in gen_conf: if "max_tokens" in gen_conf:
del gen_conf["max_tokens"] del gen_conf["max_tokens"]
ans = "" ans = ""
try: try:
response = self.client.do( response = self.client.do(model=self.model_name, messages=history, system=self.system, **gen_conf).body
model=self.model_name, ans = response["result"]
messages=history,
system=self.system,
**gen_conf
).body
ans = response['result']
return ans, self.total_token_count(response) return ans, self.total_token_count(response)
except Exception as e: except Exception as e:
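The reflowed `penalty_score` line maps the two OpenAI-style penalties onto Yiyan's single parameter by averaging them and shifting into its expected range; a quick worked example:

gen_conf = {"presence_penalty": 0.4, "frequency_penalty": 0.2}
# Average the two penalties, then offset by 1, as in the line above.
penalty_score = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
# ((0.4 + 0.2) / 2) + 1 = 1.3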
@@ -1398,26 +1295,17 @@ class BaiduYiyanChat(Base):
     def chat_streamly(self, system, history, gen_conf):
         if system:
             self.system = system
-        gen_conf["penalty_score"] = (
-            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
-            0)) / 2
-        ) + 1
+        gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
         if "max_tokens" in gen_conf:
             del gen_conf["max_tokens"]
         ans = ""
         total_tokens = 0

         try:
-            response = self.client.do(
-                model=self.model_name,
-                messages=history,
-                system=self.system,
-                stream=True,
-                **gen_conf
-            )
+            response = self.client.do(model=self.model_name, messages=history, system=self.system, stream=True, **gen_conf)

             for resp in response:
                 resp = resp.body
-                ans += resp['result']
+                ans = resp["result"]
                 total_tokens = self.total_token_count(resp)

                 yield ans
@@ -1458,11 +1346,7 @@ class AnthropicChat(Base):
             ).to_dict()
             ans = response["content"][0]["text"]
             if response["stop_reason"] == "max_tokens":
-                ans += (
-                    "...\nFor the content length reason, it stopped, continue?"
-                    if is_english([ans])
-                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
-                )
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
             return (
                 ans,
                 response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
@@ -1483,6 +1367,7 @@ class AnthropicChat(Base):
         ans = ""
         total_tokens = 0
+        reasoning_start = False
         try:
             response = self.client.messages.create(
                 model=self.model_name,
@@ -1492,15 +1377,17 @@ class AnthropicChat(Base):
                 **gen_conf,
             )
             for res in response:
-                if res.type == 'content_block_delta':
+                if res.type == "content_block_delta":
                     if res.delta.type == "thinking_delta" and res.delta.thinking:
-                        if ans.find("<think>") < 0:
-                            ans += "<think>"
-                        ans = ans.replace("</think>", "")
+                        ans = ""
+                        if not reasoning_start:
+                            reasoning_start = True
+                            ans = "<think>"
                         ans += res.delta.thinking + "</think>"
                     else:
+                        reasoning_start = False
                         text = res.delta.text
-                        ans += text
+                        ans = text
                         total_tokens += num_tokens_from_string(text)
                     yield ans
         except Exception as e:
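The new `reasoning_start` flag adapts the `<think>` handling to delta output: the opening tag is emitted exactly once, at the first thinking delta, while every thinking chunk is closed with `</think>` so each yielded piece stays well-formed on its own. A standalone sketch of that state machine, under the assumption that events arrive as (kind, text) pairs:

def wrap_thinking(events):
    reasoning_start = False
    for kind, text in events:
        if kind == "thinking":
            ans = ""
            if not reasoning_start:
                reasoning_start = True
                ans = "<think>"  # open the tag once per reasoning run
            yield ans + text + "</think>"  # close each delta so it renders standalone
        else:
            reasoning_start = False  # a plain content delta ends the reasoning run
            yield text

# list(wrap_thinking([("thinking", "a"), ("thinking", "b"), ("text", "c")]))
# -> ["<think>a</think>", "b</think>", "c"]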
@@ -1511,13 +1398,12 @@ class AnthropicChat(Base):
 class GoogleChat(Base):
     def __init__(self, key, model_name, base_url=None):
-        from google.oauth2 import service_account
         import base64
+        from google.oauth2 import service_account

         key = json.loads(key)
-        access_token = json.loads(
-            base64.b64decode(key.get("google_service_account_key", ""))
-        )
+        access_token = json.loads(base64.b64decode(key.get("google_service_account_key", "")))

         project_id = key.get("google_project_id", "")
         region = key.get("google_region", "")
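For context, the collapsed call above decodes a base64-encoded service-account JSON blob carried inside the stored key; a minimal sketch with a made-up payload (the key contents here are hypothetical):

import base64
import json

# Hypothetical stored key: the service-account JSON is base64-encoded inside it.
raw = base64.b64encode(json.dumps({"type": "service_account"}).encode())
key = {"google_service_account_key": raw}
access_token = json.loads(base64.b64decode(key.get("google_service_account_key", "")))
# -> {"type": "service_account"}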
@@ -1530,28 +1416,20 @@ class GoogleChat(Base):
             from google.auth.transport.requests import Request

             if access_token:
-                credits = service_account.Credentials.from_service_account_info(
-                    access_token, scopes=scopes
-                )
+                credits = service_account.Credentials.from_service_account_info(access_token, scopes=scopes)
                 request = Request()
                 credits.refresh(request)
                 token = credits.token
-                self.client = AnthropicVertex(
-                    region=region, project_id=project_id, access_token=token
-                )
+                self.client = AnthropicVertex(region=region, project_id=project_id, access_token=token)
             else:
                 self.client = AnthropicVertex(region=region, project_id=project_id)
         else:
-            from google.cloud import aiplatform
             import vertexai.generative_models as glm
+            from google.cloud import aiplatform

             if access_token:
-                credits = service_account.Credentials.from_service_account_info(
-                    access_token
-                )
-                aiplatform.init(
-                    credentials=credits, project=project_id, location=region
-                )
+                credits = service_account.Credentials.from_service_account_info(access_token)
+                aiplatform.init(credentials=credits, project=project_id, location=region)
             else:
                 aiplatform.init(project=project_id, location=region)
             self.client = glm.GenerativeModel(model_name=self.model_name)
@@ -1573,15 +1451,10 @@ class GoogleChat(Base):
             ).json()
             ans = response["content"][0]["text"]
             if response["stop_reason"] == "max_tokens":
-                ans += (
-                    "...\nFor the content length reason, it stopped, continue?"
-                    if is_english([ans])
-                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
-                )
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
             return (
                 ans,
-                response["usage"]["input_tokens"]
-                + response["usage"]["output_tokens"],
+                response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
             )
         except Exception as e:
             return "\n**ERROR**: " + str(e), 0
@@ -1598,9 +1471,7 @@ class GoogleChat(Base):
             if "content" in item:
                 item["parts"] = item.pop("content")
         try:
-            response = self.client.generate_content(
-                history, generation_config=gen_conf
-            )
+            response = self.client.generate_content(history, generation_config=gen_conf)
             ans = response.text
             return ans, response.usage_metadata.total_token_count
         except Exception as e:
@@ -1627,7 +1498,7 @@ class GoogleChat(Base):
                 res = res.decode("utf-8")
                 if "content_block_delta" in res and "data" in res:
                     text = json.loads(res[6:])["delta"]["text"]
-                    ans += text
+                    ans = text
                     total_tokens += num_tokens_from_string(text)
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
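Since `ans` no longer accumulates, token usage is tracked per delta via `num_tokens_from_string`; a rough sketch of the pattern (the counter below is a stand-in for the project's tokenizer-based helper, not its real implementation):

def num_tokens_from_string(text):
    # Stand-in: the real helper counts tokens with a tokenizer, not whitespace.
    return len(text.split())

total_tokens = 0
for text in ["The sky", " is", " blue."]:
    total_tokens += num_tokens_from_string(text)  # count each chunk as it streams
print(total_tokens)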
@@ -1647,11 +1518,9 @@ class GoogleChat(Base):
                 item["parts"] = item.pop("content")
         ans = ""
         try:
-            response = self.model.generate_content(
-                history, generation_config=gen_conf, stream=True
-            )
+            response = self.model.generate_content(history, generation_config=gen_conf, stream=True)
             for resp in response:
-                ans += resp.text
+                ans = resp.text
                 yield ans
         except Exception as e: