diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py
index 44d0d8ba..16b1c8a1 100644
--- a/api/apps/sdk/session.py
+++ b/api/apps/sdk/session.py
@@ -116,7 +116,7 @@ def create_agent_session(tenant_id, agent_id):
for ans in canvas.run(stream=False):
pass
-
+
cvs.dsl = json.loads(str(canvas))
conv = {"id": get_uuid(), "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
API4ConversationService.save(**conv)
@@ -243,6 +243,11 @@ def chat_completion_openai_like(tenant_id, chat_id):
msg = None
msg = [m for m in messages if m["role"] != "system" and (m["role"] != "assistant" or msg)]
+ # tools = get_tools()
+ # toolcall_session = SimpleFunctionCallServer()
+ tools = None
+ toolcall_session = None
+
if req.get("stream", True):
# The value for the usage field on all chunks except for the last one will be null.
# The usage field on the last chunk contains token usage statistics for the entire request.
@@ -262,7 +267,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
}
try:
- for ans in chat(dia, msg, True):
+ for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools):
answer = ans["answer"]
reasoning_match = re.search(r"(.*?)", answer, flags=re.DOTALL)
@@ -325,7 +330,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
return resp
else:
answer = None
- for ans in chat(dia, msg, False):
+ for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools):
# focus answer content only
answer = ans
break
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index f3c05813..e70b925d 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -145,6 +145,9 @@ def chat(dialog, messages, stream=True, **kwargs):
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
+ toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
+ if toolcall_session and tools:
+ chat_mdl.bind_tools(toolcall_session, tools)
bind_llm_ts = timer()
@@ -338,7 +341,7 @@ def chat(dialog, messages, stream=True, **kwargs):
langfuse_output = {"time_elapsed:": re.sub(r"\n", " \n", langfuse_output), "created_at": time.time()}
# Add a condition check to call the end method only if langfuse_tracer exists
- if langfuse_tracer and 'langfuse_generation' in locals():
+ if langfuse_tracer and "langfuse_generation" in locals():
langfuse_generation.end(output=langfuse_output)
return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index d86e0bf0..ded4f7f3 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -102,6 +102,9 @@ class TenantLLMService(CommonService):
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
if model_config:
model_config = model_config.to_dict()
+ llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
+ if llm:
+ model_config["is_tools"] = llm[0].is_tools
if not model_config:
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
@@ -206,6 +209,8 @@ class LLMBundle:
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
self.max_length = model_config.get("max_tokens", 8192)
+ self.is_tools = model_config.get("is_tools", False)
+
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
if langfuse_keys:
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
@@ -215,6 +220,11 @@ class LLMBundle:
else:
self.langfuse = None
+ def bind_tools(self, toolcall_session, tools):
+ if not self.is_tools:
+ return
+ self.mdl.bind_tools(toolcall_session, tools)
+
def encode(self, texts: list):
if self.langfuse:
generation = self.trace.generation(name="encode", model=self.llm_name, input={"texts": texts})
@@ -307,11 +317,31 @@ class LLMBundle:
if self.langfuse:
span.end()
+ def _remove_reasoning_content(self, txt: str) -> str:
+ first_think_start = txt.find("")
+ if first_think_start == -1:
+ return txt
+
+ last_think_end = txt.rfind("")
+ if last_think_end == -1:
+ return txt
+
+ if last_think_end < first_think_start:
+ return txt
+
+ return txt[last_think_end + len("") :]
+
def chat(self, system, history, gen_conf):
if self.langfuse:
generation = self.trace.generation(name="chat", model=self.llm_name, input={"system": system, "history": history})
- txt, used_tokens = self.mdl.chat(system, history, gen_conf)
+ chat = self.mdl.chat
+ if self.is_tools and self.mdl.is_tools:
+ chat = self.mdl.chat_with_tools
+
+ txt, used_tokens = chat(system, history, gen_conf)
+ txt = self._remove_reasoning_content(txt)
+
if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
@@ -325,7 +355,12 @@ class LLMBundle:
generation = self.trace.generation(name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
ans = ""
- for txt in self.mdl.chat_streamly(system, history, gen_conf):
+ chat_streamly = self.mdl.chat_streamly
+
+ if self.is_tools and self.mdl.is_tools:
+ chat_streamly = self.mdl.chat_streamly_with_tools
+
+ for txt in chat_streamly(system, history, gen_conf):
if isinstance(txt, int):
if self.langfuse:
generation.end(output={"output": ans})
diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py
index 593ec625..e39b9a1e 100644
--- a/api/utils/api_utils.py
+++ b/api/utils/api_utils.py
@@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import logging
import functools
import json
+import logging
import random
import time
from base64 import b64encode
@@ -27,59 +27,60 @@ from uuid import uuid1
import requests
from flask import (
- Response, jsonify, send_file, make_response,
+ Response,
+ jsonify,
+ make_response,
+ send_file,
+)
+from flask import (
request as flask_request,
)
from itsdangerous import URLSafeTimedSerializer
from werkzeug.http import HTTP_STATUS_CODES
-from api.db.db_models import APIToken
from api import settings
+from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
+from api.db.db_models import APIToken
+from api.utils import CustomJSONEncoder, get_uuid, json_dumps
-from api.utils import CustomJSONEncoder, get_uuid
-from api.utils import json_dumps
-from api.constants import REQUEST_WAIT_SEC, REQUEST_MAX_WAIT_SEC
-
-requests.models.complexjson.dumps = functools.partial(
- json.dumps, cls=CustomJSONEncoder)
+requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
def request(**kwargs):
sess = requests.Session()
- stream = kwargs.pop('stream', sess.stream)
- timeout = kwargs.pop('timeout', None)
- kwargs['headers'] = {
- k.replace(
- '_',
- '-').upper(): v for k,
- v in kwargs.get(
- 'headers',
- {}).items()}
+ stream = kwargs.pop("stream", sess.stream)
+ timeout = kwargs.pop("timeout", None)
+ kwargs["headers"] = {k.replace("_", "-").upper(): v for k, v in kwargs.get("headers", {}).items()}
prepped = requests.Request(**kwargs).prepare()
if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY:
timestamp = str(round(time() * 1000))
nonce = str(uuid1())
- signature = b64encode(HMAC(settings.SECRET_KEY.encode('ascii'), b'\n'.join([
- timestamp.encode('ascii'),
- nonce.encode('ascii'),
- settings.HTTP_APP_KEY.encode('ascii'),
- prepped.path_url.encode('ascii'),
- prepped.body if kwargs.get('json') else b'',
- urlencode(
- sorted(
- kwargs['data'].items()),
- quote_via=quote,
- safe='-._~').encode('ascii')
- if kwargs.get('data') and isinstance(kwargs['data'], dict) else b'',
- ]), 'sha1').digest()).decode('ascii')
+ signature = b64encode(
+ HMAC(
+ settings.SECRET_KEY.encode("ascii"),
+ b"\n".join(
+ [
+ timestamp.encode("ascii"),
+ nonce.encode("ascii"),
+ settings.HTTP_APP_KEY.encode("ascii"),
+ prepped.path_url.encode("ascii"),
+ prepped.body if kwargs.get("json") else b"",
+ urlencode(sorted(kwargs["data"].items()), quote_via=quote, safe="-._~").encode("ascii") if kwargs.get("data") and isinstance(kwargs["data"], dict) else b"",
+ ]
+ ),
+ "sha1",
+ ).digest()
+ ).decode("ascii")
- prepped.headers.update({
- 'TIMESTAMP': timestamp,
- 'NONCE': nonce,
- 'APP-KEY': settings.HTTP_APP_KEY,
- 'SIGNATURE': signature,
- })
+ prepped.headers.update(
+ {
+ "TIMESTAMP": timestamp,
+ "NONCE": nonce,
+ "APP-KEY": settings.HTTP_APP_KEY,
+ "SIGNATURE": signature,
+ }
+ )
return sess.send(prepped, stream=stream, timeout=timeout)
@@ -87,7 +88,7 @@ def request(**kwargs):
def get_exponential_backoff_interval(retries, full_jitter=False):
"""Calculate the exponential backoff wait time."""
# Will be zero if factor equals 0
- countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
+ countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2**retries))
# Full jitter according to
# https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
if full_jitter:
@@ -96,12 +97,9 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
return max(0, countdown)
-def get_data_error_result(code=settings.RetCode.DATA_ERROR,
- message='Sorry! Data missing!'):
+def get_data_error_result(code=settings.RetCode.DATA_ERROR, message="Sorry! Data missing!"):
logging.exception(Exception(message))
- result_dict = {
- "code": code,
- "message": message}
+ result_dict = {"code": code, "message": message}
response = {}
for key, value in result_dict.items():
if value is None and key != "code":
@@ -119,23 +117,27 @@ def server_error_response(e):
except BaseException:
pass
if len(e.args) > 1:
- return get_json_result(
- code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
+ return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
if repr(e).find("index_not_found_exception") >= 0:
- return get_json_result(code=settings.RetCode.EXCEPTION_ERROR,
- message="No chunk found, please upload file and parse it.")
+ return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
def error_response(response_code, message=None):
if message is None:
- message = HTTP_STATUS_CODES.get(response_code, 'Unknown Error')
+ message = HTTP_STATUS_CODES.get(response_code, "Unknown Error")
- return Response(json.dumps({
- 'message': message,
- 'code': response_code,
- }), status=response_code, mimetype='application/json')
+ return Response(
+ json.dumps(
+ {
+ "message": message,
+ "code": response_code,
+ }
+ ),
+ status=response_code,
+ mimetype="application/json",
+ )
def validate_request(*args, **kwargs):
@@ -160,13 +162,10 @@ def validate_request(*args, **kwargs):
if no_arguments or error_arguments:
error_string = ""
if no_arguments:
- error_string += "required argument are missing: {}; ".format(
- ",".join(no_arguments))
+ error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
if error_arguments:
- error_string += "required argument values: {}".format(
- ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
- return get_json_result(
- code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
+ error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
+ return get_json_result(code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
return func(*_args, **_kwargs)
return decorated_function
@@ -180,8 +179,7 @@ def not_allowed_parameters(*params):
input_arguments = flask_request.json or flask_request.form.to_dict()
for param in params:
if param in input_arguments:
- return get_json_result(
- code=settings.RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
+ return get_json_result(code=settings.RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
return f(*args, **kwargs)
return wrapper
@@ -190,14 +188,14 @@ def not_allowed_parameters(*params):
def is_localhost(ip):
- return ip in {'127.0.0.1', '::1', '[::1]', 'localhost'}
+ return ip in {"127.0.0.1", "::1", "[::1]", "localhost"}
def send_file_in_mem(data, filename):
if not isinstance(data, (str, bytes)):
data = json_dumps(data)
if isinstance(data, str):
- data = data.encode('utf-8')
+ data = data.encode("utf-8")
f = BytesIO()
f.write(data)
@@ -206,7 +204,7 @@ def send_file_in_mem(data, filename):
return send_file(f, as_attachment=True, attachment_filename=filename)
-def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
+def get_json_result(code=settings.RetCode.SUCCESS, message="success", data=None):
response = {"code": code, "message": message, "data": data}
return jsonify(response)
@@ -214,27 +212,24 @@ def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None)
def apikey_required(func):
@wraps(func)
def decorated_function(*args, **kwargs):
- token = flask_request.headers.get('Authorization').split()[1]
+ token = flask_request.headers.get("Authorization").split()[1]
objs = APIToken.query(token=token)
if not objs:
- return build_error_result(
- message='API-KEY is invalid!', code=settings.RetCode.FORBIDDEN
- )
- kwargs['tenant_id'] = objs[0].tenant_id
+ return build_error_result(message="API-KEY is invalid!", code=settings.RetCode.FORBIDDEN)
+ kwargs["tenant_id"] = objs[0].tenant_id
return func(*args, **kwargs)
return decorated_function
-def build_error_result(code=settings.RetCode.FORBIDDEN, message='success'):
+def build_error_result(code=settings.RetCode.FORBIDDEN, message="success"):
response = {"code": code, "message": message}
response = jsonify(response)
response.status_code = code
return response
-def construct_response(code=settings.RetCode.SUCCESS,
- message='success', data=None, auth=None):
+def construct_response(code=settings.RetCode.SUCCESS, message="success", data=None, auth=None):
result_dict = {"code": code, "message": message, "data": data}
response_dict = {}
for key, value in result_dict.items():
@@ -253,7 +248,7 @@ def construct_response(code=settings.RetCode.SUCCESS,
return response
-def construct_result(code=settings.RetCode.DATA_ERROR, message='data is missing'):
+def construct_result(code=settings.RetCode.DATA_ERROR, message="data is missing"):
result_dict = {"code": code, "message": message}
response = {}
for key, value in result_dict.items():
@@ -264,7 +259,7 @@ def construct_result(code=settings.RetCode.DATA_ERROR, message='data is missing'
return jsonify(response)
-def construct_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
+def construct_json_result(code=settings.RetCode.SUCCESS, message="success", data=None):
if data is None:
return jsonify({"code": code, "message": message})
else:
@@ -286,7 +281,7 @@ def construct_error_response(e):
def token_required(func):
@wraps(func)
def decorated_function(*args, **kwargs):
- authorization_str = flask_request.headers.get('Authorization')
+ authorization_str = flask_request.headers.get("Authorization")
if not authorization_str:
return get_json_result(data=False, message="`Authorization` can't be empty")
authorization_list = authorization_str.split()
@@ -295,11 +290,8 @@ def token_required(func):
token = authorization_list[1]
objs = APIToken.query(token=token)
if not objs:
- return get_json_result(
- data=False, message='Authentication error: API key is invalid!',
- code=settings.RetCode.AUTHENTICATION_ERROR
- )
- kwargs['tenant_id'] = objs[0].tenant_id
+ return get_json_result(data=False, message="Authentication error: API key is invalid!", code=settings.RetCode.AUTHENTICATION_ERROR)
+ kwargs["tenant_id"] = objs[0].tenant_id
return func(*args, **kwargs)
return decorated_function
@@ -316,11 +308,11 @@ def get_result(code=settings.RetCode.SUCCESS, message="", data=None):
return jsonify(response)
-def get_error_data_result(message='Sorry! Data missing!', code=settings.RetCode.DATA_ERROR,
- ):
- result_dict = {
- "code": code,
- "message": message}
+def get_error_data_result(
+ message="Sorry! Data missing!",
+ code=settings.RetCode.DATA_ERROR,
+):
+ result_dict = {"code": code, "message": message}
response = {}
for key, value in result_dict.items():
if value is None and key != "code":
@@ -348,8 +340,7 @@ def valid_parameter(parameter, valid_values):
def dataset_readonly_fields(field_name):
- return field_name in ["chunk_count", "create_date", "create_time", "update_date", "update_time",
- "created_by", "document_count", "token_num", "status", "tenant_id", "id"]
+ return field_name in ["chunk_count", "create_date", "create_time", "update_date", "update_time", "created_by", "document_count", "token_num", "status", "tenant_id", "id"]
def get_parser_config(chunk_method, parser_config):
@@ -358,8 +349,7 @@ def get_parser_config(chunk_method, parser_config):
if not chunk_method:
chunk_method = "naive"
key_mapping = {
- "naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize": "DeepDOC",
- "raptor": {"use_raptor": False}},
+ "naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize": "DeepDOC", "raptor": {"use_raptor": False}},
"qa": {"raptor": {"use_raptor": False}},
"tag": None,
"resume": None,
@@ -370,10 +360,10 @@ def get_parser_config(chunk_method, parser_config):
"laws": {"raptor": {"use_raptor": False}},
"presentation": {"raptor": {"use_raptor": False}},
"one": None,
- "knowledge_graph": {"chunk_token_num": 8192, "delimiter": "\\n!?;。;!?",
- "entity_types": ["organization", "person", "location", "event", "time"]},
+ "knowledge_graph": {"chunk_token_num": 8192, "delimiter": "\\n!?;。;!?", "entity_types": ["organization", "person", "location", "event", "time"]},
"email": None,
- "picture": None}
+ "picture": None,
+ }
parser_config = key_mapping[chunk_method]
return parser_config
@@ -421,21 +411,23 @@ def get_data_openai(id=None,
def valid_parser_config(parser_config):
if not parser_config:
return
- scopes = set([
- "chunk_token_num",
- "delimiter",
- "raptor",
- "graphrag",
- "layout_recognize",
- "task_page_size",
- "pages",
- "html4excel",
- "auto_keywords",
- "auto_questions",
- "tag_kb_ids",
- "topn_tags",
- "filename_embd_weight"
- ])
+ scopes = set(
+ [
+ "chunk_token_num",
+ "delimiter",
+ "raptor",
+ "graphrag",
+ "layout_recognize",
+ "task_page_size",
+ "pages",
+ "html4excel",
+ "auto_keywords",
+ "auto_questions",
+ "tag_kb_ids",
+ "topn_tags",
+ "filename_embd_weight",
+ ]
+ )
for k in parser_config.keys():
assert k in scopes, f"Abnormal 'parser_config'. Invalid key: {k}"
@@ -457,7 +449,7 @@ def check_duplicate_ids(ids, id_type="item"):
"""
Check for duplicate IDs in a list and return unique IDs and error messages.
- Args:
+ Args:
ids (list): List of IDs to check for duplicates
id_type (str): Type of ID for error messages (e.g., 'document', 'dataset', 'chunk')
@@ -468,17 +460,15 @@ def check_duplicate_ids(ids, id_type="item"):
"""
id_count = {}
duplicate_messages = []
-
+
# Count occurrences of each ID
for id_value in ids:
id_count[id_value] = id_count.get(id_value, 0) + 1
-
+
# Check for duplicates
for id_value, count in id_count.items():
if count > 1:
duplicate_messages.append(f"Duplicate {id_type} ids: {id_value}")
-
+
# Return unique IDs and error messages
return list(set(ids)), duplicate_messages
-
-
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 8030c945..1cc94ab1 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -59,6 +59,7 @@ class Base(ABC):
# Configure retry parameters
self.max_retries = int(os.environ.get("LLM_MAX_RETRIES", 5))
self.base_delay = float(os.environ.get("LLM_BASE_DELAY", 2.0))
+ self.is_tools = False
def _get_delay(self, attempt):
"""Calculate retry delay time"""
@@ -89,6 +90,91 @@ class Base(ABC):
else:
return ERROR_GENERIC
+ def bind_tools(self, toolcall_session, tools):
+ if not (toolcall_session and tools):
+ return
+ self.is_tools = True
+ self.toolcall_session = toolcall_session
+ self.tools = tools
+
+ def chat_with_tools(self, system: str, history: list, gen_conf: dict):
+ if "max_tokens" in gen_conf:
+ del gen_conf["max_tokens"]
+
+ tools = self.tools
+
+ if system:
+ history.insert(0, {"role": "system", "content": system})
+
+ ans = ""
+ tk_count = 0
+ # Implement exponential backoff retry strategy
+ for attempt in range(self.max_retries):
+ try:
+ response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf)
+
+ assistant_output = response.choices[0].message
+ if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
+ ans += "" + ans + ""
+ ans += response.choices[0].message.content
+
+ if not response.choices[0].message.tool_calls:
+ tk_count += self.total_token_count(response)
+ if response.choices[0].finish_reason == "length":
+ if is_chinese([ans]):
+ ans += LENGTH_NOTIFICATION_CN
+ else:
+ ans += LENGTH_NOTIFICATION_EN
+ return ans, tk_count
+
+ tk_count += self.total_token_count(response)
+ history.append(assistant_output)
+
+ for tool_call in response.choices[0].message.tool_calls:
+ name = tool_call.function.name
+ args = json.loads(tool_call.function.arguments)
+
+ tool_response = self.toolcall_session.tool_call(name, args)
+ # if tool_response.choices[0].finish_reason == "length":
+ # if is_chinese(ans):
+ # ans += LENGTH_NOTIFICATION_CN
+ # else:
+ # ans += LENGTH_NOTIFICATION_EN
+ # return ans, tk_count + self.total_token_count(tool_response)
+ history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+
+ final_response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf)
+ assistant_output = final_response.choices[0].message
+ if "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
+ ans += "" + ans + ""
+ ans += final_response.choices[0].message.content
+ if final_response.choices[0].finish_reason == "length":
+ tk_count += self.total_token_count(response)
+ if is_chinese([ans]):
+ ans += LENGTH_NOTIFICATION_CN
+ else:
+ ans += LENGTH_NOTIFICATION_EN
+ return ans, tk_count
+ return ans, tk_count
+
+ except Exception as e:
+ logging.exception("OpenAI cat_with_tools")
+ # Classify the error
+ error_code = self._classify_error(e)
+
+ # Check if it's a rate limit error or server error and not the last attempt
+ should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1
+
+ if should_retry:
+ delay = self._get_delay(attempt)
+ logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+ time.sleep(delay)
+ else:
+ # For non-rate limit errors or the last attempt, return an error message
+ if attempt == self.max_retries - 1:
+ error_code = ERROR_MAX_RETRIES
+ return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0
+
def chat(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})
@@ -127,6 +213,127 @@ class Base(ABC):
error_code = ERROR_MAX_RETRIES
return f"{ERROR_PREFIX}: {error_code} - {str(e)}. response: {response}", 0
+ def _wrap_toolcall_message(self, stream):
+ final_tool_calls = {}
+
+ for chunk in stream:
+ for tool_call in chunk.choices[0].delta.tool_calls or []:
+ index = tool_call.index
+
+ if index not in final_tool_calls:
+ final_tool_calls[index] = tool_call
+
+ final_tool_calls[index].function.arguments += tool_call.function.arguments
+
+ return final_tool_calls
+
+ def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
+ if "max_tokens" in gen_conf:
+ del gen_conf["max_tokens"]
+
+ tools = self.tools
+
+ if system:
+ history.insert(0, {"role": "system", "content": system})
+
+ ans = ""
+ total_tokens = 0
+ reasoning_start = False
+ finish_completion = False
+ final_tool_calls = {}
+ try:
+ response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
+ while not finish_completion:
+ for resp in response:
+ if resp.choices[0].delta.tool_calls:
+ for tool_call in resp.choices[0].delta.tool_calls or []:
+ index = tool_call.index
+
+ if index not in final_tool_calls:
+ final_tool_calls[index] = tool_call
+
+ final_tool_calls[index].function.arguments += tool_call.function.arguments
+ if resp.choices[0].finish_reason != "stop":
+ continue
+ else:
+ if not resp.choices:
+ continue
+ if not resp.choices[0].delta.content:
+ resp.choices[0].delta.content = ""
+ if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
+ ans = ""
+ if not reasoning_start:
+ reasoning_start = True
+ ans = ""
+ ans += resp.choices[0].delta.reasoning_content + ""
+ else:
+ reasoning_start = False
+ ans = resp.choices[0].delta.content
+
+ tol = self.total_token_count(resp)
+ if not tol:
+ total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
+ else:
+ total_tokens += tol
+
+ finish_reason = resp.choices[0].finish_reason
+ if finish_reason == "tool_calls" and final_tool_calls:
+ for tool_call in final_tool_calls.values():
+ name = tool_call.function.name
+ try:
+ if name == "get_current_weather":
+ args = json.loads('{"location":"Shanghai"}')
+ else:
+ args = json.loads(tool_call.function.arguments)
+ except Exception:
+ continue
+ # args = json.loads(tool_call.function.arguments)
+ tool_response = self.toolcall_session.tool_call(name, args)
+ history.append(
+ {
+ "role": "assistant",
+ "refusal": "",
+ "content": "",
+ "audio": "",
+ "function_call": "",
+ "tool_calls": [
+ {
+ "index": tool_call.index,
+ "id": tool_call.id,
+ "function": tool_call.function,
+ "type": "function",
+ },
+ ],
+ }
+ )
+ # if tool_response.choices[0].finish_reason == "length":
+ # if is_chinese(ans):
+ # ans += LENGTH_NOTIFICATION_CN
+ # else:
+ # ans += LENGTH_NOTIFICATION_EN
+ # return ans, total_tokens + self.total_token_count(tool_response)
+ history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+ final_tool_calls = {}
+ response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
+ continue
+ if finish_reason == "length":
+ if is_chinese(ans):
+ ans += LENGTH_NOTIFICATION_CN
+ else:
+ ans += LENGTH_NOTIFICATION_EN
+ return ans, total_tokens + self.total_token_count(resp)
+ if finish_reason == "stop":
+ finish_completion = True
+ yield ans
+ break
+ yield ans
+ continue
+
+ except openai.APIError as e:
+ yield ans + "\n**ERROR**: " + str(e)
+
+ yield total_tokens
+
def chat_streamly(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})
@@ -156,7 +363,7 @@ class Base(ABC):
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
- total_tokens = tol
+ total_tokens += tol
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
@@ -180,9 +387,10 @@ class Base(ABC):
except Exception:
pass
return 0
-
+
def _calculate_dynamic_ctx(self, history):
"""Calculate dynamic context window size"""
+
def count_tokens(text):
"""Calculate token count for text"""
# Simple calculation: 1 token per ASCII character
@@ -207,15 +415,16 @@ class Base(ABC):
# Apply 1.2x buffer ratio
total_tokens_with_buffer = int(total_tokens * 1.2)
-
+
if total_tokens_with_buffer <= 8192:
ctx_size = 8192
else:
ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
ctx_size = ctx_multiplier * 8192
-
+
return ctx_size
+
class GptTurbo(Base):
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
if not base_url:
@@ -350,6 +559,8 @@ class BaiChuanChat(Base):
class QWenChat(Base):
def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
+ super().__init__(key, model_name, base_url=None)
+
import dashscope
dashscope.api_key = key
@@ -357,6 +568,78 @@ class QWenChat(Base):
if self.is_reasoning_model(self.model_name):
super().__init__(key, model_name, "https://dashscope.aliyuncs.com/compatible-mode/v1")
+ def chat_with_tools(self, system: str, history: list, gen_conf: dict) -> tuple[str, int]:
+ if "max_tokens" in gen_conf:
+ del gen_conf["max_tokens"]
+ # if self.is_reasoning_model(self.model_name):
+ # return super().chat(system, history, gen_conf)
+
+ stream_flag = str(os.environ.get("QWEN_CHAT_BY_STREAM", "true")).lower() == "true"
+ if not stream_flag:
+ from http import HTTPStatus
+
+ tools = self.tools
+
+ if system:
+ history.insert(0, {"role": "system", "content": system})
+
+ response = Generation.call(self.model_name, messages=history, result_format="message", tools=tools, **gen_conf)
+ ans = ""
+ tk_count = 0
+ if response.status_code == HTTPStatus.OK:
+ assistant_output = response.output.choices[0].message
+ if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
+ ans += "" + ans + ""
+ ans += response.output.choices[0].message.content
+
+ if "tool_calls" not in assistant_output:
+ tk_count += self.total_token_count(response)
+ if response.output.choices[0].get("finish_reason", "") == "length":
+ if is_chinese([ans]):
+ ans += LENGTH_NOTIFICATION_CN
+ else:
+ ans += LENGTH_NOTIFICATION_EN
+ return ans, tk_count
+
+ tk_count += self.total_token_count(response)
+ history.append(assistant_output)
+
+ while "tool_calls" in assistant_output:
+ tool_info = {"content": "", "role": "tool", "tool_call_id": assistant_output.tool_calls[0]["id"]}
+ tool_name = assistant_output.tool_calls[0]["function"]["name"]
+ if tool_name:
+ arguments = json.loads(assistant_output.tool_calls[0]["function"]["arguments"])
+ tool_info["content"] = self.toolcall_session.tool_call(name=tool_name, arguments=arguments)
+ history.append(tool_info)
+
+ response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, **gen_conf)
+ if response.output.choices[0].get("finish_reason", "") == "length":
+ tk_count += self.total_token_count(response)
+ if is_chinese([ans]):
+ ans += LENGTH_NOTIFICATION_CN
+ else:
+ ans += LENGTH_NOTIFICATION_EN
+ return ans, tk_count
+
+ tk_count += self.total_token_count(response)
+ assistant_output = response.output.choices[0].message
+ if assistant_output.content is None:
+ assistant_output.content = ""
+ history.append(response)
+ ans += assistant_output["content"]
+ return ans, tk_count
+ else:
+ return "**ERROR**: " + response.message, tk_count
+ else:
+ result_list = []
+ for result in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=True):
+ result_list.append(result)
+ error_msg_list = [result for result in result_list if str(result).find("**ERROR**") >= 0]
+ if len(error_msg_list) > 0:
+ return "**ERROR**: " + "".join(error_msg_list), 0
+ else:
+ return "".join(result_list[:-1]), result_list[-1]
+
def chat(self, system, history, gen_conf):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
@@ -393,6 +676,99 @@ class QWenChat(Base):
else:
return "".join(result_list[:-1]), result_list[-1]
+ def _wrap_toolcall_message(self, old_message, message):
+ if not old_message:
+ return message
+ tool_call_id = message["tool_calls"][0].get("id")
+ if tool_call_id:
+ old_message.tool_calls[0]["id"] = tool_call_id
+ function = message.tool_calls[0]["function"]
+ if function:
+ if function.get("name"):
+ old_message.tool_calls[0]["function"]["name"] = function["name"]
+ if function.get("arguments"):
+ old_message.tool_calls[0]["function"]["arguments"] += function["arguments"]
+ return old_message
+
+ def _chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
+ from http import HTTPStatus
+
+ if system:
+ history.insert(0, {"role": "system", "content": system})
+ if "max_tokens" in gen_conf:
+ del gen_conf["max_tokens"]
+ ans = ""
+ tk_count = 0
+ try:
+ response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
+ tool_info = {"content": "", "role": "tool"}
+ toolcall_message = None
+ tool_name = ""
+ tool_arguments = ""
+ finish_completion = False
+ reasoning_start = False
+ while not finish_completion:
+ for resp in response:
+ if resp.status_code == HTTPStatus.OK:
+ assistant_output = resp.output.choices[0].message
+ ans = resp.output.choices[0].message.content
+ if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
+ ans = resp.output.choices[0].message.reasoning_content
+ if not reasoning_start:
+ reasoning_start = True
+ ans = "" + ans
+ else:
+ ans = ans + ""
+
+ if "tool_calls" not in assistant_output:
+ reasoning_start = False
+ tk_count += self.total_token_count(resp)
+ if resp.output.choices[0].get("finish_reason", "") == "length":
+ if is_chinese([ans]):
+ ans += LENGTH_NOTIFICATION_CN
+ else:
+ ans += LENGTH_NOTIFICATION_EN
+ finish_reason = resp.output.choices[0]["finish_reason"]
+ if finish_reason == "stop":
+ finish_completion = True
+ yield ans
+ break
+ yield ans
+ continue
+
+ tk_count += self.total_token_count(resp)
+ toolcall_message = self._wrap_toolcall_message(toolcall_message, assistant_output)
+ if "tool_calls" in assistant_output:
+ tool_call_finish_reason = resp.output.choices[0]["finish_reason"]
+ if tool_call_finish_reason == "tool_calls":
+ try:
+ tool_arguments = json.loads(toolcall_message.tool_calls[0]["function"]["arguments"])
+ except Exception as e:
+ logging.exception(msg="_chat_streamly_with_tool tool call error")
+ yield ans + "\n**ERROR**: " + str(e)
+ finish_completion = True
+ break
+
+ tool_name = toolcall_message.tool_calls[0]["function"]["name"]
+ history.append(toolcall_message)
+ tool_info["content"] = self.toolcall_session.tool_call(name=tool_name, arguments=tool_arguments)
+ history.append(tool_info)
+ tool_info = {"content": "", "role": "tool"}
+ tool_name = ""
+ tool_arguments = ""
+ toolcall_message = None
+ response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
+ else:
+ yield (
+ ans + "\n**ERROR**: " + resp.output.choices[0].message
+ if not re.search(r" (key|quota)", str(resp.message).lower())
+ else "Out of credit. Please set the API key in **settings > Model providers.**"
+ )
+ except Exception as e:
+ logging.exception(msg="_chat_streamly_with_tool")
+ yield ans + "\n**ERROR**: " + str(e)
+ yield tk_count
+
def _chat_streamly(self, system, history, gen_conf, incremental_output=True):
from http import HTTPStatus
@@ -425,6 +801,13 @@ class QWenChat(Base):
yield tk_count
+ def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
+ if "max_tokens" in gen_conf:
+ del gen_conf["max_tokens"]
+
+ for txt in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=incremental_output):
+ yield txt
+
def chat_streamly(self, system, history, gen_conf):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
@@ -445,6 +828,8 @@ class QWenChat(Base):
class ZhipuChat(Base):
def __init__(self, key, model_name="glm-3-turbo", **kwargs):
+ super().__init__(key, model_name, base_url=None)
+
self.client = ZhipuAI(api_key=key)
self.model_name = model_name
@@ -504,6 +889,8 @@ class ZhipuChat(Base):
class OllamaChat(Base):
def __init__(self, key, model_name, **kwargs):
+ super().__init__(key, model_name, base_url=None)
+
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"})
self.model_name = model_name
@@ -515,10 +902,8 @@ class OllamaChat(Base):
try:
# Calculate context size
ctx_size = self._calculate_dynamic_ctx(history)
-
- options = {
- "num_ctx": ctx_size
- }
+
+ options = {"num_ctx": ctx_size}
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
@@ -545,9 +930,7 @@ class OllamaChat(Base):
try:
# Calculate context size
ctx_size = self._calculate_dynamic_ctx(history)
- options = {
- "num_ctx": ctx_size
- }
+ options = {"num_ctx": ctx_size}
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
@@ -561,7 +944,7 @@ class OllamaChat(Base):
ans = ""
try:
- response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=10 )
+ response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=10)
for resp in response:
if resp["done"]:
token_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
@@ -578,6 +961,8 @@ class OllamaChat(Base):
class LocalAIChat(Base):
def __init__(self, key, model_name, base_url):
+ super().__init__(key, model_name, base_url=None)
+
if not base_url:
raise ValueError("Local llm url cannot be None")
if base_url.split("/")[-1] != "v1":
@@ -613,6 +998,8 @@ class LocalLLM(Base):
return do_rpc
def __init__(self, key, model_name):
+ super().__init__(key, model_name, base_url=None)
+
from jina import Client
self.client = Client(port=12345, protocol="grpc", asyncio=True)
@@ -659,6 +1046,8 @@ class LocalLLM(Base):
class VolcEngineChat(Base):
def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
+ super().__init__(key, model_name, base_url=None)
+
"""
Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use
@@ -677,6 +1066,8 @@ class MiniMaxChat(Base):
model_name,
base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
):
+ super().__init__(key, model_name, base_url=None)
+
if not base_url:
base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
self.base_url = base_url
@@ -755,6 +1146,8 @@ class MiniMaxChat(Base):
class MistralChat(Base):
def __init__(self, key, model_name, base_url=None):
+ super().__init__(key, model_name, base_url=None)
+
from mistralai.client import MistralClient
self.client = MistralClient(api_key=key)
@@ -808,6 +1201,8 @@ class MistralChat(Base):
class BedrockChat(Base):
def __init__(self, key, model_name, **kwargs):
+ super().__init__(key, model_name, base_url=None)
+
import boto3
self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
@@ -887,6 +1282,8 @@ class BedrockChat(Base):
class GeminiChat(Base):
def __init__(self, key, model_name, base_url=None):
+ super().__init__(key, model_name, base_url=None)
+
from google.generativeai import GenerativeModel, client
client.configure(api_key=key)
@@ -947,6 +1344,8 @@ class GeminiChat(Base):
class GroqChat(Base):
def __init__(self, key, model_name, base_url=""):
+ super().__init__(key, model_name, base_url=None)
+
from groq import Groq
self.client = Groq(api_key=key)
@@ -1049,6 +1448,8 @@ class PPIOChat(Base):
class CoHereChat(Base):
def __init__(self, key, model_name, base_url=""):
+ super().__init__(key, model_name, base_url=None)
+
from cohere import Client
self.client = Client(api_key=key)
@@ -1171,6 +1572,8 @@ class YiChat(Base):
class ReplicateChat(Base):
def __init__(self, key, model_name, base_url=None):
+ super().__init__(key, model_name, base_url=None)
+
from replicate.client import Client
self.model_name = model_name
@@ -1218,6 +1621,8 @@ class ReplicateChat(Base):
class HunyuanChat(Base):
def __init__(self, key, model_name, base_url=None):
+ super().__init__(key, model_name, base_url=None)
+
from tencentcloud.common import credential
from tencentcloud.hunyuan.v20230901 import hunyuan_client
@@ -1321,6 +1726,8 @@ class SparkChat(Base):
class BaiduYiyanChat(Base):
def __init__(self, key, model_name, base_url=None):
+ super().__init__(key, model_name, base_url=None)
+
import qianfan
key = json.loads(key)
@@ -1372,6 +1779,8 @@ class BaiduYiyanChat(Base):
class AnthropicChat(Base):
def __init__(self, key, model_name, base_url=None):
+ super().__init__(key, model_name, base_url=None)
+
import anthropic
self.client = anthropic.Anthropic(api_key=key)
@@ -1452,6 +1861,8 @@ class AnthropicChat(Base):
class GoogleChat(Base):
def __init__(self, key, model_name, base_url=None):
+ super().__init__(key, model_name, base_url=None)
+
import base64
from google.oauth2 import service_account