Refa: change LLM chat output from full to delta (incremental) (#6534)

### What problem does this PR solve?

Change streamed LLM chat output from full (cumulative) responses to delta (incremental) chunks, so each streamed message carries only the newly generated text.
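
For context, a minimal sketch of the incremental extraction this PR applies in `streamed_response_generator` (first file in the diff below): each chunk from `chat()` is still the cumulative answer, optionally prefixed by `<think>...</think>` reasoning, and only the unseen suffix of each part is forwarded. The `split_deltas` helper name and the sample chunks are illustrative only, not part of the change.

```python
import re

def split_deltas(chunks):
    """Turn cumulative answers (optionally wrapped in <think>...</think>)
    into (reasoning_delta, content_delta) pairs."""
    reasoning_cache, answer_cache = "", ""
    for answer in chunks:
        match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
        reasoning_part = match.group(1) if match else ""
        content_part = answer[match.end():] if match else answer

        reasoning_delta = ""
        if reasoning_part:
            # Send only the unseen suffix; if the prefix no longer matches the
            # cache (e.g. the stream restarted), resend the whole part.
            reasoning_delta = reasoning_part[len(reasoning_cache):] if reasoning_part.startswith(reasoning_cache) else reasoning_part
            reasoning_cache = reasoning_part

        content_delta = ""
        if content_part:
            content_delta = content_part[len(answer_cache):] if content_part.startswith(answer_cache) else content_part
            answer_cache = content_part

        if reasoning_delta or content_delta:
            yield reasoning_delta, content_delta

# Two cumulative chunks produce two incremental deltas:
# the first emits the reasoning and "Hel", the second emits only "lo!".
for r, c in split_deltas(["<think>plan</think>Hel", "<think>plan</think>Hello!"]):
    print(repr(r), repr(c))
```
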

### Type of change

- [x] Refactoring
Yongteng Lei authored on 2025-03-26 19:33:14 +08:00, committed by GitHub
commit df3890827d (parent 6599db1e99)
3 changed files with 277 additions and 399 deletions


@ -13,31 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import json
import re
import time
from api.db import LLMType
from api.db.services.conversation_service import ConversationService, iframe_completion
from api.db.services.conversation_service import completion as rag_completion
from api.db.services.canvas_service import completion as agent_completion
from api.db.services.dialog_service import ask, chat
from flask import Response, jsonify, request
from agent.canvas import Canvas
from api.db import StatusEnum
from api.db import LLMType, StatusEnum
from api.db.db_models import APIToken
from api.db.services.api_service import API4ConversationService
from api.db.services.canvas_service import UserCanvasService
from api.db.services.dialog_service import DialogService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result, validate_request
from api.utils.api_utils import get_result, token_required
from api.db.services.llm_service import LLMBundle
from api.db.services.canvas_service import completion as agent_completion
from api.db.services.conversation_service import ConversationService, iframe_completion
from api.db.services.conversation_service import completion as rag_completion
from api.db.services.dialog_service import DialogService, ask, chat
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result, get_result, token_required, validate_request
from flask import jsonify, request, Response
@manager.route('/chats/<chat_id>/sessions', methods=['POST']) # noqa: F821
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
@token_required
def create(tenant_id, chat_id):
req = request.json
@ -50,7 +48,7 @@ def create(tenant_id, chat_id):
"dialog_id": req["dialog_id"],
"name": req.get("name", "New session"),
"message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}],
"user_id": req.get("user_id", "")
"user_id": req.get("user_id", ""),
}
if not conv.get("name"):
return get_error_data_result(message="`name` can not be empty.")
@ -59,20 +57,20 @@ def create(tenant_id, chat_id):
if not e:
return get_error_data_result(message="Fail to create a session!")
conv = conv.to_dict()
conv['messages'] = conv.pop("message")
conv["messages"] = conv.pop("message")
conv["chat_id"] = conv.pop("dialog_id")
del conv["reference"]
return get_result(data=conv)
@manager.route('/agents/<agent_id>/sessions', methods=['POST']) # noqa: F821
@manager.route("/agents/<agent_id>/sessions", methods=["POST"]) # noqa: F821
@token_required
def create_agent_session(tenant_id, agent_id):
req = request.json
if not request.is_json:
req = request.form
files = request.files
user_id = request.args.get('user_id', '')
user_id = request.args.get("user_id", "")
e, cvs = UserCanvasService.get_by_id(agent_id)
if not e:
@ -113,7 +111,7 @@ def create_agent_session(tenant_id, agent_id):
ele.pop("value")
else:
if req is not None and req.get(ele["key"]):
ele["value"] = req[ele['key']]
ele["value"] = req[ele["key"]]
else:
if "value" in ele:
ele.pop("value")
@ -121,20 +119,13 @@ def create_agent_session(tenant_id, agent_id):
for ans in canvas.run(stream=False):
pass
cvs.dsl = json.loads(str(canvas))
conv = {
"id": get_uuid(),
"dialog_id": cvs.id,
"user_id": user_id,
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
"source": "agent",
"dsl": cvs.dsl
}
conv = {"id": get_uuid(), "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
API4ConversationService.save(**conv)
conv["agent_id"] = conv.pop("dialog_id")
return get_result(data=conv)
@manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT']) # noqa: F821
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821
@token_required
def update(tenant_id, chat_id, session_id):
req = request.json
@ -156,14 +147,14 @@ def update(tenant_id, chat_id, session_id):
return get_result()
@manager.route('/chats/<chat_id>/completions', methods=['POST']) # noqa: F821
@manager.route("/chats/<chat_id>/completions", methods=["POST"]) # noqa: F821
@token_required
def chat_completion(tenant_id, chat_id):
req = request.json
if not req:
req = {"question": ""}
if not req.get("session_id"):
req["question"]=""
req["question"] = ""
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
return get_error_data_result(f"You don't own the chat {chat_id}")
if req.get("session_id"):
@ -185,7 +176,7 @@ def chat_completion(tenant_id, chat_id):
return get_result(data=answer)
@manager.route('/chats_openai/<chat_id>/chat/completions', methods=['POST']) # noqa: F821
@manager.route("/chats_openai/<chat_id>/chat/completions", methods=["POST"]) # noqa: F821
@validate_request("model", "messages") # noqa: F821
@token_required
def chat_completion_openai_like(tenant_id, chat_id):
@ -260,35 +251,60 @@ def chat_completion_openai_like(tenant_id, chat_id):
def streamed_response_generator(chat_id, dia, msg):
token_used = 0
answer_cache = ""
reasoning_cache = ""
response = {
"id": f"chatcmpl-{chat_id}",
"choices": [
{
"delta": {
"content": "",
"role": "assistant",
"function_call": None,
"tool_calls": None
},
"finish_reason": None,
"index": 0,
"logprobs": None
}
],
"choices": [{"delta": {"content": "", "role": "assistant", "function_call": None, "tool_calls": None, "reasoning_content": ""}, "finish_reason": None, "index": 0, "logprobs": None}],
"created": int(time.time()),
"model": "model",
"object": "chat.completion.chunk",
"system_fingerprint": "",
"usage": None
"usage": None,
}
try:
for ans in chat(dia, msg, True):
answer = ans["answer"]
incremental = answer.replace(answer_cache, "", 1)
answer_cache = answer.rstrip("</think>")
token_used += len(incremental)
response["choices"][0]["delta"]["content"] = incremental
reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
if reasoning_match:
reasoning_part = reasoning_match.group(1)
content_part = answer[reasoning_match.end() :]
else:
reasoning_part = ""
content_part = answer
reasoning_incremental = ""
if reasoning_part:
if reasoning_part.startswith(reasoning_cache):
reasoning_incremental = reasoning_part.replace(reasoning_cache, "", 1)
else:
reasoning_incremental = reasoning_part
reasoning_cache = reasoning_part
content_incremental = ""
if content_part:
if content_part.startswith(answer_cache):
content_incremental = content_part.replace(answer_cache, "", 1)
else:
content_incremental = content_part
answer_cache = content_part
token_used += len(reasoning_incremental) + len(content_incremental)
if not any([reasoning_incremental, content_incremental]):
continue
if reasoning_incremental:
response["choices"][0]["delta"]["reasoning_content"] = reasoning_incremental
else:
response["choices"][0]["delta"]["reasoning_content"] = None
if content_incremental:
response["choices"][0]["delta"]["content"] = content_incremental
else:
response["choices"][0]["delta"]["content"] = None
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
except Exception as e:
response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e)
@ -296,16 +312,12 @@ def chat_completion_openai_like(tenant_id, chat_id):
# The last chunk
response["choices"][0]["delta"]["content"] = None
response["choices"][0]["delta"]["reasoning_content"] = None
response["choices"][0]["finish_reason"] = "stop"
response["usage"] = {
"prompt_tokens": len(prompt),
"completion_tokens": token_used,
"total_tokens": len(prompt) + token_used
}
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
yield "data:[DONE]\n\n"
resp = Response(streamed_response_generator(chat_id, dia, msg), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
@ -320,7 +332,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
break
content = answer["answer"]
response = {
response = {
"id": f"chatcmpl-{chat_id}",
"object": "chat.completion",
"created": int(time.time()),
@ -332,25 +344,15 @@ def chat_completion_openai_like(tenant_id, chat_id):
"completion_tokens_details": {
"reasoning_tokens": context_token_used,
"accepted_prediction_tokens": len(content),
"rejected_prediction_tokens": 0 # 0 for simplicity
}
"rejected_prediction_tokens": 0, # 0 for simplicity
},
},
"choices": [
{
"message": {
"role": "assistant",
"content": content
},
"logprobs": None,
"finish_reason": "stop",
"index": 0
}
]
"choices": [{"message": {"role": "assistant", "content": content}, "logprobs": None, "finish_reason": "stop", "index": 0}],
}
return jsonify(response)
@manager.route('/agents/<agent_id>/completions', methods=['POST']) # noqa: F821
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
@token_required
def agent_completions(tenant_id, agent_id):
req = request.json
@ -361,8 +363,8 @@ def agent_completions(tenant_id, agent_id):
dsl = cvs[0].dsl
if not isinstance(dsl, str):
dsl = json.dumps(dsl)
#canvas = Canvas(dsl, tenant_id)
#if canvas.get_preset_param():
# canvas = Canvas(dsl, tenant_id)
# if canvas.get_preset_param():
# req["question"] = ""
conv = API4ConversationService.query(id=req["session_id"], dialog_id=agent_id)
if not conv:
@ -376,9 +378,7 @@ def agent_completions(tenant_id, agent_id):
states = {field: current_dsl.get(field, []) for field in state_fields}
current_dsl.update(new_dsl)
current_dsl.update(states)
API4ConversationService.update_by_id(req["session_id"], {
"dsl": current_dsl
})
API4ConversationService.update_by_id(req["session_id"], {"dsl": current_dsl})
else:
req["question"] = ""
if req.get("stream", True):
@ -395,7 +395,7 @@ def agent_completions(tenant_id, agent_id):
return get_error_data_result(str(e))
@manager.route('/chats/<chat_id>/sessions', methods=['GET']) # noqa: F821
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
@token_required
def list_session(tenant_id, chat_id):
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
@ -414,7 +414,7 @@ def list_session(tenant_id, chat_id):
if not convs:
return get_result(data=[])
for conv in convs:
conv['messages'] = conv.pop("message")
conv["messages"] = conv.pop("message")
infos = conv["messages"]
for info in infos:
if "prompt" in info:
@ -448,7 +448,7 @@ def list_session(tenant_id, chat_id):
return get_result(data=convs)
@manager.route('/agents/<agent_id>/sessions', methods=['GET']) # noqa: F821
@manager.route("/agents/<agent_id>/sessions", methods=["GET"]) # noqa: F821
@token_required
def list_agent_session(tenant_id, agent_id):
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
@ -464,12 +464,11 @@ def list_agent_session(tenant_id, agent_id):
desc = True
# dsl defaults to True in all cases except for False and false
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id,
user_id, include_dsl)
convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id, user_id, include_dsl)
if not convs:
return get_result(data=[])
for conv in convs:
conv['messages'] = conv.pop("message")
conv["messages"] = conv.pop("message")
infos = conv["messages"]
for info in infos:
if "prompt" in info:
@ -502,7 +501,7 @@ def list_agent_session(tenant_id, agent_id):
return get_result(data=convs)
@manager.route('/chats/<chat_id>/sessions', methods=["DELETE"]) # noqa: F821
@manager.route("/chats/<chat_id>/sessions", methods=["DELETE"]) # noqa: F821
@token_required
def delete(tenant_id, chat_id):
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
@ -528,14 +527,14 @@ def delete(tenant_id, chat_id):
return get_result()
@manager.route('/agents/<agent_id>/sessions', methods=["DELETE"]) # noqa: F821
@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"]) # noqa: F821
@token_required
def delete_agent_session(tenant_id, agent_id):
req = request.json
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
if not cvs:
return get_error_data_result(f"You don't own the agent {agent_id}")
convs = API4ConversationService.query(dialog_id=agent_id)
if not convs:
return get_error_data_result(f"Agent {agent_id} has no sessions")
@ -551,16 +550,16 @@ def delete_agent_session(tenant_id, agent_id):
conv_list.append(conv.id)
else:
conv_list = ids
for session_id in conv_list:
conv = API4ConversationService.query(id=session_id, dialog_id=agent_id)
if not conv:
return get_error_data_result(f"The agent doesn't own the session ${session_id}")
API4ConversationService.delete_by_id(session_id)
return get_result()
@manager.route('/sessions/ask', methods=['POST']) # noqa: F821
@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821
@token_required
def ask_about(tenant_id):
req = request.json
@ -586,9 +585,7 @@ def ask_about(tenant_id):
for ans in ask(req["question"], req["kb_ids"], uid):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
resp = Response(stream(), mimetype="text/event-stream")
@ -599,7 +596,7 @@ def ask_about(tenant_id):
return resp
@manager.route('/sessions/related_questions', methods=['POST']) # noqa: F821
@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821
@token_required
def related_questions(tenant_id):
req = request.json
@ -631,18 +628,27 @@ Reason:
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
"""
ans = chat_mdl.chat(prompt, [{"role": "user", "content": f"""
ans = chat_mdl.chat(
prompt,
[
{
"role": "user",
"content": f"""
Keywords: {question}
Related search terms:
"""}], {"temperature": 0.9})
""",
}
],
{"temperature": 0.9},
)
return get_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
@manager.route('/chatbots/<dialog_id>/completions', methods=['POST']) # noqa: F821
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
def chatbot_completions(dialog_id):
req = request.json
token = request.headers.get('Authorization').split()
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"')
token = token[1]
@ -665,11 +671,11 @@ def chatbot_completions(dialog_id):
return get_result(data=answer)
@manager.route('/agentbots/<agent_id>/completions', methods=['POST']) # noqa: F821
@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
def agent_bot_completions(agent_id):
req = request.json
token = request.headers.get('Authorization').split()
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"')
token = token[1]


@ -324,15 +324,18 @@ class LLMBundle:
if self.langfuse:
generation = self.trace.generation(name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
output = ""
ans = ""
for txt in self.mdl.chat_streamly(system, history, gen_conf):
if isinstance(txt, int):
if self.langfuse:
generation.end(output={"output": output})
generation.end(output={"output": ans})
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
return
return ans
output = txt
yield txt
if txt.endswith("</think>"):
ans = ans.rstrip("</think>")
ans += txt
yield ans


@ -1,5 +1,5 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,25 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import asyncio
import json
import logging
import os
import random
import re
import time
from abc import ABC
import openai
import requests
from dashscope import Generation
from ollama import Client
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
from dashscope import Generation
from abc import ABC
from openai import OpenAI
import openai
from ollama import Client
from rag.nlp import is_chinese, is_english
from rag.utils import num_tokens_from_string
import os
import json
import requests
import asyncio
import logging
import time
# Error message constants
ERROR_PREFIX = "**ERROR**"
@ -53,21 +53,21 @@ LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to
class Base(ABC):
def __init__(self, key, model_name, base_url):
timeout = int(os.environ.get('LM_TIMEOUT_SECONDS', 600))
timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
self.model_name = model_name
# Configure retry parameters
self.max_retries = int(os.environ.get('LLM_MAX_RETRIES', 5))
self.base_delay = float(os.environ.get('LLM_BASE_DELAY', 2.0))
self.max_retries = int(os.environ.get("LLM_MAX_RETRIES", 5))
self.base_delay = float(os.environ.get("LLM_BASE_DELAY", 2.0))
def _get_delay(self, attempt):
"""Calculate retry delay time"""
return self.base_delay * (2 ** attempt) + random.uniform(0, 0.5)
return self.base_delay * (2**attempt) + random.uniform(0, 0.5)
def _classify_error(self, error):
"""Classify error based on error message content"""
error_str = str(error).lower()
if "rate limit" in error_str or "429" in error_str or "tpm limit" in error_str or "too many requests" in error_str or "requests per minute" in error_str:
return ERROR_RATE_LIMIT
elif "auth" in error_str or "key" in error_str or "apikey" in error_str or "401" in error_str or "forbidden" in error_str or "permission" in error_str:
@ -98,11 +98,8 @@ class Base(ABC):
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries):
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
**gen_conf)
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
return "", 0
ans = response.choices[0].message.content.strip()
@ -111,17 +108,17 @@ class Base(ABC):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, self.total_token_count(response)
return ans, self.total_token_count(response)
except Exception as e:
# Classify the error
error_code = self._classify_error(e)
# Check if it's a rate limit error or server error and not the last attempt
should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1
if should_retry:
delay = self._get_delay(attempt)
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt+1}/{self.max_retries})")
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
time.sleep(delay)
else:
# For non-rate limit errors or the last attempt, return an error message
@ -136,24 +133,23 @@ class Base(ABC):
del gen_conf["max_tokens"]
ans = ""
total_tokens = 0
reasoning_start = False
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
stream=True,
**gen_conf)
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
if ans.find("<think>") < 0:
ans += "<think>"
ans = ans.replace("</think>", "")
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
else:
ans += resp.choices[0].delta.content
reasoning_start = False
ans = resp.choices[0].delta.content
tol = self.total_token_count(resp)
if not tol:
@ -221,7 +217,7 @@ class ModelScopeChat(Base):
def __init__(self, key=None, model_name="", base_url=""):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = base_url.rstrip('/')
base_url = base_url.rstrip("/")
if base_url.split("/")[-1] != "v1":
base_url = os.path.join(base_url, "v1")
super().__init__(key, model_name.split("___")[0], base_url)
@ -236,8 +232,8 @@ class DeepSeekChat(Base):
class AzureChat(Base):
def __init__(self, key, model_name, **kwargs):
api_key = json.loads(key).get('api_key', '')
api_version = json.loads(key).get('api_version', '2024-02-01')
api_key = json.loads(key).get("api_key", "")
api_version = json.loads(key).get("api_version", "2024-02-01")
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
self.model_name = model_name
@ -264,16 +260,9 @@ class BaiChuanChat(Base):
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
extra_body={
"tools": [{
"type": "web_search",
"web_search": {
"enable": True,
"search_mode": "performance_first"
}
}]
},
**self._format_params(gen_conf))
extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
**self._format_params(gen_conf),
)
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
if is_chinese([ans]):
@ -295,23 +284,16 @@ class BaiChuanChat(Base):
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
extra_body={
"tools": [{
"type": "web_search",
"web_search": {
"enable": True,
"search_mode": "performance_first"
}
}]
},
extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
stream=True,
**self._format_params(gen_conf))
**self._format_params(gen_conf),
)
for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content
ans = resp.choices[0].delta.content
tol = self.total_token_count(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
@ -333,6 +315,7 @@ class BaiChuanChat(Base):
class QWenChat(Base):
def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
import dashscope
dashscope.api_key = key
self.model_name = model_name
if self.is_reasoning_model(self.model_name):
@ -344,22 +327,18 @@ class QWenChat(Base):
if self.is_reasoning_model(self.model_name):
return super().chat(system, history, gen_conf)
stream_flag = str(os.environ.get('QWEN_CHAT_BY_STREAM', 'true')).lower() == 'true'
stream_flag = str(os.environ.get("QWEN_CHAT_BY_STREAM", "true")).lower() == "true"
if not stream_flag:
from http import HTTPStatus
if system:
history.insert(0, {"role": "system", "content": system})
response = Generation.call(
self.model_name,
messages=history,
result_format='message',
**gen_conf
)
response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
ans = ""
tk_count = 0
if response.status_code == HTTPStatus.OK:
ans += response.output.choices[0]['message']['content']
ans += response.output.choices[0]["message"]["content"]
tk_count += self.total_token_count(response)
if response.output.choices[0].get("finish_reason", "") == "length":
if is_chinese([ans]):
@ -378,8 +357,9 @@ class QWenChat(Base):
else:
return "".join(result_list[:-1]), result_list[-1]
def _chat_streamly(self, system, history, gen_conf, incremental_output=False):
def _chat_streamly(self, system, history, gen_conf, incremental_output=True):
from http import HTTPStatus
if system:
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
@ -387,17 +367,10 @@ class QWenChat(Base):
ans = ""
tk_count = 0
try:
response = Generation.call(
self.model_name,
messages=history,
result_format='message',
stream=True,
incremental_output=incremental_output,
**gen_conf
)
response = Generation.call(self.model_name, messages=history, result_format="message", stream=True, incremental_output=incremental_output, **gen_conf)
for resp in response:
if resp.status_code == HTTPStatus.OK:
ans = resp.output.choices[0]['message']['content']
ans = resp.output.choices[0]["message"]["content"]
tk_count = self.total_token_count(resp)
if resp.output.choices[0].get("finish_reason", "") == "length":
if is_chinese(ans):
@ -406,8 +379,11 @@ class QWenChat(Base):
ans += LENGTH_NOTIFICATION_EN
yield ans
else:
yield ans + "\n**ERROR**: " + resp.message if not re.search(r" (key|quota)",
str(resp.message).lower()) else "Out of credit. Please set the API key in **settings > Model providers.**"
yield (
ans + "\n**ERROR**: " + resp.message
if not re.search(r" (key|quota)", str(resp.message).lower())
else "Out of credit. Please set the API key in **settings > Model providers.**"
)
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@ -423,10 +399,12 @@ class QWenChat(Base):
@staticmethod
def is_reasoning_model(model_name: str) -> bool:
return any([
model_name.lower().find("deepseek") >= 0,
model_name.lower().find("qwq") >= 0 and model_name.lower() != 'qwq-32b-preview',
])
return any(
[
model_name.lower().find("deepseek") >= 0,
model_name.lower().find("qwq") >= 0 and model_name.lower() != "qwq-32b-preview",
]
)
class ZhipuChat(Base):
@ -444,11 +422,7 @@ class ZhipuChat(Base):
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
**gen_conf
)
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
if is_chinese(ans):
@ -471,17 +445,12 @@ class ZhipuChat(Base):
ans = ""
tk_count = 0
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
stream=True,
**gen_conf
)
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
ans += delta
ans = delta
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
@ -499,8 +468,7 @@ class ZhipuChat(Base):
class OllamaChat(Base):
def __init__(self, key, model_name, **kwargs):
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else \
Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
self.model_name = model_name
def chat(self, system, history, gen_conf):
@ -509,9 +477,7 @@ class OllamaChat(Base):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
try:
options = {
"num_ctx": 32768
}
options = {"num_ctx": 32768}
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
@ -522,12 +488,7 @@ class OllamaChat(Base):
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
response = self.client.chat(
model=self.model_name,
messages=history,
options=options,
keep_alive=-1
)
response = self.client.chat(model=self.model_name, messages=history, options=options, keep_alive=-1)
ans = response["message"]["content"].strip()
return ans, response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
except Exception as e:
@ -551,17 +512,11 @@ class OllamaChat(Base):
options["frequency_penalty"] = gen_conf["frequency_penalty"]
ans = ""
try:
response = self.client.chat(
model=self.model_name,
messages=history,
stream=True,
options=options,
keep_alive=-1
)
response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=-1)
for resp in response:
if resp["done"]:
yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
ans += resp["message"]["content"]
ans = resp["message"]["content"]
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@ -588,9 +543,7 @@ class LocalLLM(Base):
def __conn(self):
from multiprocessing.connection import Client
self._connection = Client(
(self.host, self.port), authkey=b"infiniflow-token4kevinhu"
)
self._connection = Client((self.host, self.port), authkey=b"infiniflow-token4kevinhu")
def __getattr__(self, name):
import pickle
@ -613,17 +566,17 @@ class LocalLLM(Base):
def _prepare_prompt(self, system, history, gen_conf):
from rag.svr.jina_server import Prompt
if system:
history.insert(0, {"role": "system", "content": system})
return Prompt(message=history, gen_conf=gen_conf)
def _stream_response(self, endpoint, prompt):
from rag.svr.jina_server import Generation
answer = ""
try:
res = self.client.stream_doc(
on=endpoint, inputs=prompt, return_type=Generation
)
res = self.client.stream_doc(on=endpoint, inputs=prompt, return_type=Generation)
loop = asyncio.get_event_loop()
try:
while True:
@ -652,24 +605,24 @@ class LocalLLM(Base):
class VolcEngineChat(Base):
def __init__(self, key, model_name, base_url='https://ark.cn-beijing.volces.com/api/v3'):
def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
"""
Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use
model_name is for display only
"""
base_url = base_url if base_url else 'https://ark.cn-beijing.volces.com/api/v3'
ark_api_key = json.loads(key).get('ark_api_key', '')
model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '')
base_url = base_url if base_url else "https://ark.cn-beijing.volces.com/api/v3"
ark_api_key = json.loads(key).get("ark_api_key", "")
model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
super().__init__(ark_api_key, model_name, base_url)
class MiniMaxChat(Base):
def __init__(
self,
key,
model_name,
base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
self,
key,
model_name,
base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
):
if not base_url:
base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
@ -687,13 +640,9 @@ class MiniMaxChat(Base):
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
payload = json.dumps(
{"model": self.model_name, "messages": history, **gen_conf}
)
payload = json.dumps({"model": self.model_name, "messages": history, **gen_conf})
try:
response = requests.request(
"POST", url=self.base_url, headers=headers, data=payload
)
response = requests.request("POST", url=self.base_url, headers=headers, data=payload)
response = response.json()
ans = response["choices"][0]["message"]["content"].strip()
if response["choices"][0]["finish_reason"] == "length":
@ -737,7 +686,7 @@ class MiniMaxChat(Base):
text = ""
if "choices" in resp and "delta" in resp["choices"][0]:
text = resp["choices"][0]["delta"]["content"]
ans += text
ans = text
tol = self.total_token_count(resp)
if not tol:
total_tokens += num_tokens_from_string(text)
@ -752,9 +701,9 @@ class MiniMaxChat(Base):
class MistralChat(Base):
def __init__(self, key, model_name, base_url=None):
from mistralai.client import MistralClient
self.client = MistralClient(api_key=key)
self.model_name = model_name
@ -765,10 +714,7 @@ class MistralChat(Base):
if k not in ["temperature", "top_p", "max_tokens"]:
del gen_conf[k]
try:
response = self.client.chat(
model=self.model_name,
messages=history,
**gen_conf)
response = self.client.chat(model=self.model_name, messages=history, **gen_conf)
ans = response.choices[0].message.content
if response.choices[0].finish_reason == "length":
if is_chinese(ans):
@ -788,14 +734,11 @@ class MistralChat(Base):
ans = ""
total_tokens = 0
try:
response = self.client.chat_stream(
model=self.model_name,
messages=history,
**gen_conf)
response = self.client.chat_stream(model=self.model_name, messages=history, **gen_conf)
for resp in response:
if not resp.choices or not resp.choices[0].delta.content:
continue
ans += resp.choices[0].delta.content
ans = resp.choices[0].delta.content
total_tokens += 1
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
@ -811,23 +754,23 @@ class MistralChat(Base):
class BedrockChat(Base):
def __init__(self, key, model_name, **kwargs):
import boto3
self.bedrock_ak = json.loads(key).get('bedrock_ak', '')
self.bedrock_sk = json.loads(key).get('bedrock_sk', '')
self.bedrock_region = json.loads(key).get('bedrock_region', '')
self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
self.bedrock_region = json.loads(key).get("bedrock_region", "")
self.model_name = model_name
if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '':
if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
self.client = boto3.client('bedrock-runtime')
self.client = boto3.client("bedrock-runtime")
else:
self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
def chat(self, system, history, gen_conf):
from botocore.exceptions import ClientError
for k in list(gen_conf.keys()):
if k not in ["temperature"]:
del gen_conf[k]
@ -853,6 +796,7 @@ class BedrockChat(Base):
def chat_streamly(self, system, history, gen_conf):
from botocore.exceptions import ClientError
for k in list(gen_conf.keys()):
if k not in ["temperature"]:
del gen_conf[k]
@ -860,14 +804,9 @@ class BedrockChat(Base):
if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
item["content"] = [{"text": item["content"]}]
if self.model_name.split('.')[0] == 'ai21':
if self.model_name.split(".")[0] == "ai21":
try:
response = self.client.converse(
modelId=self.model_name,
messages=history,
inferenceConfig=gen_conf,
system=[{"text": (system if system else "Answer the user's message.")}]
)
response = self.client.converse(modelId=self.model_name, messages=history, inferenceConfig=gen_conf, system=[{"text": (system if system else "Answer the user's message.")}])
ans = response["output"]["message"]["content"][0]["text"]
return ans, num_tokens_from_string(ans)
@ -878,16 +817,13 @@ class BedrockChat(Base):
try:
# Send the message to the model, using a basic inference configuration.
streaming_response = self.client.converse_stream(
modelId=self.model_name,
messages=history,
inferenceConfig=gen_conf,
system=[{"text": (system if system else "Answer the user's message.")}]
modelId=self.model_name, messages=history, inferenceConfig=gen_conf, system=[{"text": (system if system else "Answer the user's message.")}]
)
# Extract and print the streamed response text in real-time.
for resp in streaming_response["stream"]:
if "contentBlockDelta" in resp:
ans += resp["contentBlockDelta"]["delta"]["text"]
ans = resp["contentBlockDelta"]["delta"]["text"]
yield ans
except (ClientError, Exception) as e:
@ -897,13 +833,12 @@ class BedrockChat(Base):
class GeminiChat(Base):
def __init__(self, key, model_name, base_url=None):
from google.generativeai import client, GenerativeModel
from google.generativeai import GenerativeModel, client
client.configure(api_key=key)
_client = client.get_default_generative_client()
self.model_name = 'models/' + model_name
self.model_name = "models/" + model_name
self.model = GenerativeModel(model_name=self.model_name)
self.model._client = _client
@ -916,17 +851,15 @@ class GeminiChat(Base):
if k not in ["temperature", "top_p", "max_tokens"]:
del gen_conf[k]
for item in history:
if 'role' in item and item['role'] == 'assistant':
item['role'] = 'model'
if 'role' in item and item['role'] == 'system':
item['role'] = 'user'
if 'content' in item:
item['parts'] = item.pop('content')
if "role" in item and item["role"] == "assistant":
item["role"] = "model"
if "role" in item and item["role"] == "system":
item["role"] = "user"
if "content" in item:
item["parts"] = item.pop("content")
try:
response = self.model.generate_content(
history,
generation_config=gen_conf)
response = self.model.generate_content(history, generation_config=gen_conf)
ans = response.text
return ans, response.usage_metadata.total_token_count
except Exception as e:
@ -941,17 +874,15 @@ class GeminiChat(Base):
if k not in ["temperature", "top_p", "max_tokens"]:
del gen_conf[k]
for item in history:
if 'role' in item and item['role'] == 'assistant':
item['role'] = 'model'
if 'content' in item:
item['parts'] = item.pop('content')
if "role" in item and item["role"] == "assistant":
item["role"] = "model"
if "content" in item:
item["parts"] = item.pop("content")
ans = ""
try:
response = self.model.generate_content(
history,
generation_config=gen_conf, stream=True)
response = self.model.generate_content(history, generation_config=gen_conf, stream=True)
for resp in response:
ans += resp.text
ans = resp.text
yield ans
yield response._chunks[-1].usage_metadata.total_token_count
@ -962,8 +893,9 @@ class GeminiChat(Base):
class GroqChat(Base):
def __init__(self, key, model_name, base_url=''):
def __init__(self, key, model_name, base_url=""):
from groq import Groq
self.client = Groq(api_key=key)
self.model_name = model_name
@ -975,11 +907,7 @@ class GroqChat(Base):
del gen_conf[k]
ans = ""
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
**gen_conf
)
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
ans = response.choices[0].message.content
if response.choices[0].finish_reason == "length":
if is_chinese(ans):
@ -999,16 +927,11 @@ class GroqChat(Base):
ans = ""
total_tokens = 0
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
stream=True,
**gen_conf
)
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if not resp.choices or not resp.choices[0].delta.content:
continue
ans += resp.choices[0].delta.content
ans = resp.choices[0].delta.content
total_tokens += 1
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
@ -1096,16 +1019,10 @@ class CoHereChat(Base):
mes = history.pop()["message"]
ans = ""
try:
response = self.client.chat(
model=self.model_name, chat_history=history, message=mes, **gen_conf
)
response = self.client.chat(model=self.model_name, chat_history=history, message=mes, **gen_conf)
ans = response.text
if response.finish_reason == "MAX_TOKENS":
ans += (
"...\nFor the content length reason, it stopped, continue?"
if is_english([ans])
else "······\n由于长度的原因,回答被截断了,要继续吗?"
)
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
return (
ans,
response.meta.tokens.input_tokens + response.meta.tokens.output_tokens,
@ -1133,20 +1050,14 @@ class CoHereChat(Base):
ans = ""
total_tokens = 0
try:
response = self.client.chat_stream(
model=self.model_name, chat_history=history, message=mes, **gen_conf
)
response = self.client.chat_stream(model=self.model_name, chat_history=history, message=mes, **gen_conf)
for resp in response:
if resp.event_type == "text-generation":
ans += resp.text
ans = resp.text
total_tokens += num_tokens_from_string(resp.text)
elif resp.event_type == "stream-end":
if resp.finish_reason == "MAX_TOKENS":
ans += (
"...\nFor the content length reason, it stopped, continue?"
if is_english([ans])
else "······\n由于长度的原因,回答被截断了,要继续吗?"
)
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
yield ans
except Exception as e:
@ -1217,9 +1128,7 @@ class ReplicateChat(Base):
del gen_conf["max_tokens"]
if system:
self.system = system
prompt = "\n".join(
[item["role"] + ":" + item["content"] for item in history[-5:]]
)
prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]])
ans = ""
try:
response = self.client.run(
@ -1236,9 +1145,7 @@ class ReplicateChat(Base):
del gen_conf["max_tokens"]
if system:
self.system = system
prompt = "\n".join(
[item["role"] + ":" + item["content"] for item in history[-5:]]
)
prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]])
ans = ""
try:
response = self.client.run(
@ -1246,7 +1153,7 @@ class ReplicateChat(Base):
input={"system_prompt": self.system, "prompt": prompt, **gen_conf},
)
for resp in response:
ans += resp
ans = resp
yield ans
except Exception as e:
@ -1268,10 +1175,10 @@ class HunyuanChat(Base):
self.client = hunyuan_client.HunyuanClient(cred, "")
def chat(self, system, history, gen_conf):
from tencentcloud.hunyuan.v20230901 import models
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
TencentCloudSDKException,
)
from tencentcloud.hunyuan.v20230901 import models
_gen_conf = {}
_history = [{k.capitalize(): v for k, v in item.items()} for item in history]
@ -1296,10 +1203,10 @@ class HunyuanChat(Base):
return ans + "\n**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf):
from tencentcloud.hunyuan.v20230901 import models
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
TencentCloudSDKException,
)
from tencentcloud.hunyuan.v20230901 import models
_gen_conf = {}
_history = [{k.capitalize(): v for k, v in item.items()} for item in history]
@ -1327,7 +1234,7 @@ class HunyuanChat(Base):
resp = json.loads(resp["data"])
if not resp["Choices"] or not resp["Choices"][0]["Delta"]["Content"]:
continue
ans += resp["Choices"][0]["Delta"]["Content"]
ans = resp["Choices"][0]["Delta"]["Content"]
total_tokens += 1
yield ans
@ -1339,9 +1246,7 @@ class HunyuanChat(Base):
class SparkChat(Base):
def __init__(
self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
):
def __init__(self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"):
if not base_url:
base_url = "https://spark-api-open.xf-yun.com/v1"
model2version = {
@ -1374,22 +1279,14 @@ class BaiduYiyanChat(Base):
def chat(self, system, history, gen_conf):
if system:
self.system = system
gen_conf["penalty_score"] = (
(gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
0)) / 2
) + 1
gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
ans = ""
try:
response = self.client.do(
model=self.model_name,
messages=history,
system=self.system,
**gen_conf
).body
ans = response['result']
response = self.client.do(model=self.model_name, messages=history, system=self.system, **gen_conf).body
ans = response["result"]
return ans, self.total_token_count(response)
except Exception as e:
@ -1398,26 +1295,17 @@ class BaiduYiyanChat(Base):
def chat_streamly(self, system, history, gen_conf):
if system:
self.system = system
gen_conf["penalty_score"] = (
(gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
0)) / 2
) + 1
gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
ans = ""
total_tokens = 0
try:
response = self.client.do(
model=self.model_name,
messages=history,
system=self.system,
stream=True,
**gen_conf
)
response = self.client.do(model=self.model_name, messages=history, system=self.system, stream=True, **gen_conf)
for resp in response:
resp = resp.body
ans += resp['result']
ans = resp["result"]
total_tokens = self.total_token_count(resp)
yield ans
@ -1458,11 +1346,7 @@ class AnthropicChat(Base):
).to_dict()
ans = response["content"][0]["text"]
if response["stop_reason"] == "max_tokens":
ans += (
"...\nFor the content length reason, it stopped, continue?"
if is_english([ans])
else "······\n由于长度的原因,回答被截断了,要继续吗?"
)
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
return (
ans,
response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
@ -1483,6 +1367,7 @@ class AnthropicChat(Base):
ans = ""
total_tokens = 0
reasoning_start = False
try:
response = self.client.messages.create(
model=self.model_name,
@ -1492,15 +1377,17 @@ class AnthropicChat(Base):
**gen_conf,
)
for res in response:
if res.type == 'content_block_delta':
if res.type == "content_block_delta":
if res.delta.type == "thinking_delta" and res.delta.thinking:
if ans.find("<think>") < 0:
ans += "<think>"
ans = ans.replace("</think>", "")
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += res.delta.thinking + "</think>"
else:
reasoning_start = False
text = res.delta.text
ans += text
ans = text
total_tokens += num_tokens_from_string(text)
yield ans
except Exception as e:
@ -1511,13 +1398,12 @@ class AnthropicChat(Base):
class GoogleChat(Base):
def __init__(self, key, model_name, base_url=None):
from google.oauth2 import service_account
import base64
from google.oauth2 import service_account
key = json.loads(key)
access_token = json.loads(
base64.b64decode(key.get("google_service_account_key", ""))
)
access_token = json.loads(base64.b64decode(key.get("google_service_account_key", "")))
project_id = key.get("google_project_id", "")
region = key.get("google_region", "")
@ -1530,28 +1416,20 @@ class GoogleChat(Base):
from google.auth.transport.requests import Request
if access_token:
credits = service_account.Credentials.from_service_account_info(
access_token, scopes=scopes
)
credits = service_account.Credentials.from_service_account_info(access_token, scopes=scopes)
request = Request()
credits.refresh(request)
token = credits.token
self.client = AnthropicVertex(
region=region, project_id=project_id, access_token=token
)
self.client = AnthropicVertex(region=region, project_id=project_id, access_token=token)
else:
self.client = AnthropicVertex(region=region, project_id=project_id)
else:
from google.cloud import aiplatform
import vertexai.generative_models as glm
from google.cloud import aiplatform
if access_token:
credits = service_account.Credentials.from_service_account_info(
access_token
)
aiplatform.init(
credentials=credits, project=project_id, location=region
)
credits = service_account.Credentials.from_service_account_info(access_token)
aiplatform.init(credentials=credits, project=project_id, location=region)
else:
aiplatform.init(project=project_id, location=region)
self.client = glm.GenerativeModel(model_name=self.model_name)
@ -1573,15 +1451,10 @@ class GoogleChat(Base):
).json()
ans = response["content"][0]["text"]
if response["stop_reason"] == "max_tokens":
ans += (
"...\nFor the content length reason, it stopped, continue?"
if is_english([ans])
else "······\n由于长度的原因,回答被截断了,要继续吗?"
)
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
return (
ans,
response["usage"]["input_tokens"]
+ response["usage"]["output_tokens"],
response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
)
except Exception as e:
return "\n**ERROR**: " + str(e), 0
@ -1598,9 +1471,7 @@ class GoogleChat(Base):
if "content" in item:
item["parts"] = item.pop("content")
try:
response = self.client.generate_content(
history, generation_config=gen_conf
)
response = self.client.generate_content(history, generation_config=gen_conf)
ans = response.text
return ans, response.usage_metadata.total_token_count
except Exception as e:
@ -1627,7 +1498,7 @@ class GoogleChat(Base):
res = res.decode("utf-8")
if "content_block_delta" in res and "data" in res:
text = json.loads(res[6:])["delta"]["text"]
ans += text
ans = text
total_tokens += num_tokens_from_string(text)
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@ -1647,11 +1518,9 @@ class GoogleChat(Base):
item["parts"] = item.pop("content")
ans = ""
try:
response = self.model.generate_content(
history, generation_config=gen_conf, stream=True
)
response = self.model.generate_content(history, generation_config=gen_conf, stream=True)
for resp in response:
ans += resp.text
ans = resp.text
yield ans
except Exception as e: