Feat: add primitive support for function calls (#6840)

### What problem does this PR solve?

This PR introduces **primitive support for function calls**,
enabling the system to handle basic function-call capabilities.
The feature is still experimental and **not yet enabled for
general use**, as it is currently supported by only a subset of
models, namely the Qwen and OpenAI families.
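
For context, the `tools` argument follows the OpenAI function-calling schema, and a "toolcall session" is any object exposing a `tool_call(name, arguments)` method that executes the named tool and returns a stringifiable result. A minimal sketch, assuming a hypothetical weather tool (only the `bind_tools`/`tool_call` shapes come from this PR):

```python
# Hypothetical example: the weather tool and WeatherSession are illustrative.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]


class WeatherSession:
    def tool_call(self, name, arguments):
        # Called back by the model layer whenever the LLM emits a tool call.
        if name == "get_current_weather":
            return f"Sunny in {arguments['location']}"
        return f"Unknown tool: {name}"
```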

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Yongteng Lei 2025-04-08 16:09:03 +08:00 committed by GitHub
parent a20439bf81
commit dc2c74b249
5 changed files with 574 additions and 130 deletions


@@ -243,6 +243,11 @@ def chat_completion_openai_like(tenant_id, chat_id):
     msg = None
     msg = [m for m in messages if m["role"] != "system" and (m["role"] != "assistant" or msg)]
+    # tools = get_tools()
+    # toolcall_session = SimpleFunctionCallServer()
+    tools = None
+    toolcall_session = None
+
     if req.get("stream", True):
         # The value for the usage field on all chunks except for the last one will be null.
         # The usage field on the last chunk contains token usage statistics for the entire request.
@@ -262,7 +267,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
             }
         try:
-            for ans in chat(dia, msg, True):
+            for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools):
                 answer = ans["answer"]
                 reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
@@ -325,7 +330,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
         return resp
     else:
         answer = None
-        for ans in chat(dia, msg, False):
+        for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools):
             # focus answer content only
             answer = ans
             break
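
As the commented-out lines above suggest, enabling the feature at this endpoint would mean replacing the `None` placeholders; a sketch, assuming `get_tools()` and `SimpleFunctionCallServer` (the names from the comments) were implemented, since neither ships with this PR:

```python
# Hypothetical wiring; both names are taken from the commented-out lines above.
tools = get_tools()  # OpenAI-style tool definitions for this chat
toolcall_session = SimpleFunctionCallServer()  # must expose tool_call(name, args)
```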


@@ -145,6 +145,9 @@ def chat(dialog, messages, stream=True, **kwargs):
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
+    toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
+    if toolcall_session and tools:
+        chat_mdl.bind_tools(toolcall_session, tools)

     bind_llm_ts = timer()
@@ -338,7 +341,7 @@ def chat(dialog, messages, stream=True, **kwargs):
     langfuse_output = {"time_elapsed:": re.sub(r"\n", " \n", langfuse_output), "created_at": time.time()}

     # Add a condition check to call the end method only if langfuse_tracer exists
-    if langfuse_tracer and 'langfuse_generation' in locals():
+    if langfuse_tracer and "langfuse_generation" in locals():
         langfuse_generation.end(output=langfuse_output)

     return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
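
Callers opt in purely through `**kwargs`, so existing call sites are untouched; a usage sketch (`dia`, `msg`, `my_session`, and `my_tools` are placeholders for objects from the calling endpoint):

```python
# Tool-enabled path: bind_tools() is invoked before the first completion.
for ans in chat(dia, msg, True, toolcall_session=my_session, tools=my_tools):
    print(ans["answer"])

# Legacy path: without the kwargs, bind_tools() is never reached.
for ans in chat(dia, msg, True):
    print(ans["answer"])
```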


@@ -102,6 +102,9 @@ class TenantLLMService(CommonService):
         mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
         if model_config:
             model_config = model_config.to_dict()
+            llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
+            if llm:
+                model_config["is_tools"] = llm[0].is_tools
         if not model_config:
             if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
                 llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
@@ -206,6 +209,8 @@ class LLMBundle:
         model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
         self.max_length = model_config.get("max_tokens", 8192)

+        self.is_tools = model_config.get("is_tools", False)
+
         langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
         if langfuse_keys:
             langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
@@ -215,6 +220,11 @@ class LLMBundle:
         else:
             self.langfuse = None

+    def bind_tools(self, toolcall_session, tools):
+        if not self.is_tools:
+            return
+        self.mdl.bind_tools(toolcall_session, tools)
+
     def encode(self, texts: list):
         if self.langfuse:
             generation = self.trace.generation(name="encode", model=self.llm_name, input={"texts": texts})
@@ -307,11 +317,31 @@ class LLMBundle:
         if self.langfuse:
             span.end()

+    def _remove_reasoning_content(self, txt: str) -> str:
+        first_think_start = txt.find("<think>")
+        if first_think_start == -1:
+            return txt
+        last_think_end = txt.rfind("</think>")
+        if last_think_end == -1:
+            return txt
+        if last_think_end < first_think_start:
+            return txt
+        return txt[last_think_end + len("</think>") :]
+
     def chat(self, system, history, gen_conf):
         if self.langfuse:
             generation = self.trace.generation(name="chat", model=self.llm_name, input={"system": system, "history": history})

-        txt, used_tokens = self.mdl.chat(system, history, gen_conf)
+        chat = self.mdl.chat
+        if self.is_tools and self.mdl.is_tools:
+            chat = self.mdl.chat_with_tools
+
+        txt, used_tokens = chat(system, history, gen_conf)
+        txt = self._remove_reasoning_content(txt)
+
         if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
             logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
@@ -325,7 +355,12 @@ class LLMBundle:
             generation = self.trace.generation(name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})

         ans = ""
-        for txt in self.mdl.chat_streamly(system, history, gen_conf):
+        chat_streamly = self.mdl.chat_streamly
+        if self.is_tools and self.mdl.is_tools:
+            chat_streamly = self.mdl.chat_streamly_with_tools
+
+        for txt in chat_streamly(system, history, gen_conf):
             if isinstance(txt, int):
                 if self.langfuse:
                     generation.end(output={"output": ans})
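
`_remove_reasoning_content` drops everything up to and including the *last* `</think>` tag, so any interleaved reasoning blocks vanish from the final answer; a quick behavioral sketch (shown as a free function for brevity):

```python
# Mirrors the method defined above.
def _remove_reasoning_content(txt: str) -> str:
    first_think_start = txt.find("<think>")
    if first_think_start == -1:
        return txt
    last_think_end = txt.rfind("</think>")
    if last_think_end == -1:
        return txt
    if last_think_end < first_think_start:
        return txt
    return txt[last_think_end + len("</think>") :]


assert _remove_reasoning_content("<think>plan</think>Paris.") == "Paris."
assert _remove_reasoning_content("no tags here") == "no tags here"
# A closer that precedes any opener is left untouched:
assert _remove_reasoning_content("</think>stray<think>") == "</think>stray<think>"
```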


@@ -13,9 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import logging
 import functools
 import json
+import logging
 import random
 import time
 from base64 import b64encode
@@ -27,59 +27,60 @@ from uuid import uuid1
 import requests
 from flask import (
-    Response, jsonify, send_file, make_response,
+    Response,
+    jsonify,
+    make_response,
+    send_file,
+)
+from flask import (
     request as flask_request,
 )
 from itsdangerous import URLSafeTimedSerializer
 from werkzeug.http import HTTP_STATUS_CODES

-from api.db.db_models import APIToken
 from api import settings
-from api.utils import CustomJSONEncoder, get_uuid
-from api.utils import json_dumps
-from api.constants import REQUEST_WAIT_SEC, REQUEST_MAX_WAIT_SEC
+from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
+from api.db.db_models import APIToken
+from api.utils import CustomJSONEncoder, get_uuid, json_dumps

-requests.models.complexjson.dumps = functools.partial(
-    json.dumps, cls=CustomJSONEncoder)
+requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)


 def request(**kwargs):
     sess = requests.Session()
-    stream = kwargs.pop('stream', sess.stream)
-    timeout = kwargs.pop('timeout', None)
-    kwargs['headers'] = {
-        k.replace(
-            '_',
-            '-').upper(): v for k,
-        v in kwargs.get(
-            'headers',
-            {}).items()}
+    stream = kwargs.pop("stream", sess.stream)
+    timeout = kwargs.pop("timeout", None)
+    kwargs["headers"] = {k.replace("_", "-").upper(): v for k, v in kwargs.get("headers", {}).items()}
     prepped = requests.Request(**kwargs).prepare()

     if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY:
         timestamp = str(round(time() * 1000))
         nonce = str(uuid1())
-        signature = b64encode(HMAC(settings.SECRET_KEY.encode('ascii'), b'\n'.join([
-            timestamp.encode('ascii'),
-            nonce.encode('ascii'),
-            settings.HTTP_APP_KEY.encode('ascii'),
-            prepped.path_url.encode('ascii'),
-            prepped.body if kwargs.get('json') else b'',
-            urlencode(
-                sorted(
-                    kwargs['data'].items()),
-                quote_via=quote,
-                safe='-._~').encode('ascii')
-            if kwargs.get('data') and isinstance(kwargs['data'], dict) else b'',
-        ]), 'sha1').digest()).decode('ascii')
+        signature = b64encode(
+            HMAC(
+                settings.SECRET_KEY.encode("ascii"),
+                b"\n".join(
+                    [
+                        timestamp.encode("ascii"),
+                        nonce.encode("ascii"),
+                        settings.HTTP_APP_KEY.encode("ascii"),
+                        prepped.path_url.encode("ascii"),
+                        prepped.body if kwargs.get("json") else b"",
+                        urlencode(sorted(kwargs["data"].items()), quote_via=quote, safe="-._~").encode("ascii") if kwargs.get("data") and isinstance(kwargs["data"], dict) else b"",
+                    ]
+                ),
+                "sha1",
+            ).digest()
+        ).decode("ascii")

-        prepped.headers.update({
-            'TIMESTAMP': timestamp,
-            'NONCE': nonce,
-            'APP-KEY': settings.HTTP_APP_KEY,
-            'SIGNATURE': signature,
-        })
+        prepped.headers.update(
+            {
+                "TIMESTAMP": timestamp,
+                "NONCE": nonce,
+                "APP-KEY": settings.HTTP_APP_KEY,
+                "SIGNATURE": signature,
+            }
+        )

     return sess.send(prepped, stream=stream, timeout=timeout)
@@ -87,7 +88,7 @@ def request(**kwargs):
 def get_exponential_backoff_interval(retries, full_jitter=False):
     """Calculate the exponential backoff wait time."""
     # Will be zero if factor equals 0
-    countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
+    countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2**retries))
     # Full jitter according to
     # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
     if full_jitter:
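
The change above is formatting-only, but the formula deserves a worked example. A sketch assuming illustrative constants (the real values live in `api.constants`, and the jitter branch, elided from this hunk, is assumed to follow the AWS full-jitter recipe cited in the comments):

```python
import random

REQUEST_WAIT_SEC = 2        # assumed base delay, for illustration only
REQUEST_MAX_WAIT_SEC = 300  # assumed cap, for illustration only


def get_exponential_backoff_interval(retries, full_jitter=False):
    countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2**retries))
    if full_jitter:
        countdown = random.randrange(countdown + 1)  # assumed jitter implementation
    return max(0, countdown)


print([get_exponential_backoff_interval(r) for r in range(9)])
# [2, 4, 8, 16, 32, 64, 128, 256, 300] -- doubles each retry, capped at the max
```
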
@@ -96,12 +97,9 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
     return max(0, countdown)


-def get_data_error_result(code=settings.RetCode.DATA_ERROR,
-                          message='Sorry! Data missing!'):
+def get_data_error_result(code=settings.RetCode.DATA_ERROR, message="Sorry! Data missing!"):
     logging.exception(Exception(message))
-    result_dict = {
-        "code": code,
-        "message": message}
+    result_dict = {"code": code, "message": message}
     response = {}
     for key, value in result_dict.items():
         if value is None and key != "code":
@@ -119,23 +117,27 @@ def server_error_response(e):
     except BaseException:
         pass
     if len(e.args) > 1:
-        return get_json_result(
-            code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
+        return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
     if repr(e).find("index_not_found_exception") >= 0:
-        return get_json_result(code=settings.RetCode.EXCEPTION_ERROR,
-                               message="No chunk found, please upload file and parse it.")
+        return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")

     return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))


 def error_response(response_code, message=None):
     if message is None:
-        message = HTTP_STATUS_CODES.get(response_code, 'Unknown Error')
+        message = HTTP_STATUS_CODES.get(response_code, "Unknown Error")

-    return Response(json.dumps({
-        'message': message,
-        'code': response_code,
-    }), status=response_code, mimetype='application/json')
+    return Response(
+        json.dumps(
+            {
+                "message": message,
+                "code": response_code,
+            }
+        ),
+        status=response_code,
+        mimetype="application/json",
+    )


 def validate_request(*args, **kwargs):
@@ -160,13 +162,10 @@ def validate_request(*args, **kwargs):
             if no_arguments or error_arguments:
                 error_string = ""
                 if no_arguments:
-                    error_string += "required argument are missing: {}; ".format(
-                        ",".join(no_arguments))
+                    error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
                 if error_arguments:
-                    error_string += "required argument values: {}".format(
-                        ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
-                return get_json_result(
-                    code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
+                    error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
+                return get_json_result(code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
             return func(*_args, **_kwargs)

         return decorated_function
@@ -180,8 +179,7 @@ def not_allowed_parameters(*params):
             input_arguments = flask_request.json or flask_request.form.to_dict()
             for param in params:
                 if param in input_arguments:
-                    return get_json_result(
-                        code=settings.RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
+                    return get_json_result(code=settings.RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
             return f(*args, **kwargs)

         return wrapper
@@ -190,14 +188,14 @@ def not_allowed_parameters(*params):
 def is_localhost(ip):
-    return ip in {'127.0.0.1', '::1', '[::1]', 'localhost'}
+    return ip in {"127.0.0.1", "::1", "[::1]", "localhost"}


 def send_file_in_mem(data, filename):
     if not isinstance(data, (str, bytes)):
         data = json_dumps(data)
     if isinstance(data, str):
-        data = data.encode('utf-8')
+        data = data.encode("utf-8")

     f = BytesIO()
     f.write(data)
@@ -206,7 +204,7 @@ def send_file_in_mem(data, filename):
     return send_file(f, as_attachment=True, attachment_filename=filename)


-def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
+def get_json_result(code=settings.RetCode.SUCCESS, message="success", data=None):
     response = {"code": code, "message": message, "data": data}
     return jsonify(response)
@@ -214,27 +212,24 @@ def get_json_result(code=settings.RetCode.SUCCESS, message="success", data=None)
 def apikey_required(func):
     @wraps(func)
     def decorated_function(*args, **kwargs):
-        token = flask_request.headers.get('Authorization').split()[1]
+        token = flask_request.headers.get("Authorization").split()[1]
         objs = APIToken.query(token=token)
         if not objs:
-            return build_error_result(
-                message='API-KEY is invalid!', code=settings.RetCode.FORBIDDEN
-            )
-        kwargs['tenant_id'] = objs[0].tenant_id
+            return build_error_result(message="API-KEY is invalid!", code=settings.RetCode.FORBIDDEN)
+        kwargs["tenant_id"] = objs[0].tenant_id
         return func(*args, **kwargs)

     return decorated_function


-def build_error_result(code=settings.RetCode.FORBIDDEN, message='success'):
+def build_error_result(code=settings.RetCode.FORBIDDEN, message="success"):
     response = {"code": code, "message": message}
     response = jsonify(response)
     response.status_code = code
     return response


-def construct_response(code=settings.RetCode.SUCCESS,
-                       message='success', data=None, auth=None):
+def construct_response(code=settings.RetCode.SUCCESS, message="success", data=None, auth=None):
     result_dict = {"code": code, "message": message, "data": data}
     response_dict = {}
     for key, value in result_dict.items():
@@ -253,7 +248,7 @@ def construct_response(code=settings.RetCode.SUCCESS,
     return response


-def construct_result(code=settings.RetCode.DATA_ERROR, message='data is missing'):
+def construct_result(code=settings.RetCode.DATA_ERROR, message="data is missing"):
     result_dict = {"code": code, "message": message}
     response = {}
     for key, value in result_dict.items():
@@ -264,7 +259,7 @@ def construct_result(code=settings.RetCode.DATA_ERROR, message='data is missing'
     return jsonify(response)


-def construct_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
+def construct_json_result(code=settings.RetCode.SUCCESS, message="success", data=None):
     if data is None:
         return jsonify({"code": code, "message": message})
     else:
@@ -286,7 +281,7 @@ def construct_error_response(e):
 def token_required(func):
     @wraps(func)
     def decorated_function(*args, **kwargs):
-        authorization_str = flask_request.headers.get('Authorization')
+        authorization_str = flask_request.headers.get("Authorization")
         if not authorization_str:
             return get_json_result(data=False, message="`Authorization` can't be empty")
         authorization_list = authorization_str.split()
@@ -295,11 +290,8 @@ def token_required(func):
         token = authorization_list[1]
         objs = APIToken.query(token=token)
         if not objs:
-            return get_json_result(
-                data=False, message='Authentication error: API key is invalid!',
-                code=settings.RetCode.AUTHENTICATION_ERROR
-            )
-        kwargs['tenant_id'] = objs[0].tenant_id
+            return get_json_result(data=False, message="Authentication error: API key is invalid!", code=settings.RetCode.AUTHENTICATION_ERROR)
+        kwargs["tenant_id"] = objs[0].tenant_id
         return func(*args, **kwargs)

     return decorated_function
@@ -316,11 +308,11 @@ def get_result(code=settings.RetCode.SUCCESS, message="", data=None):
     return jsonify(response)


-def get_error_data_result(message='Sorry! Data missing!', code=settings.RetCode.DATA_ERROR,
-                          ):
-    result_dict = {
-        "code": code,
-        "message": message}
+def get_error_data_result(
+    message="Sorry! Data missing!",
+    code=settings.RetCode.DATA_ERROR,
+):
+    result_dict = {"code": code, "message": message}
     response = {}
     for key, value in result_dict.items():
         if value is None and key != "code":
@@ -348,8 +340,7 @@ def valid_parameter(parameter, valid_values):
 def dataset_readonly_fields(field_name):
-    return field_name in ["chunk_count", "create_date", "create_time", "update_date", "update_time",
-                          "created_by", "document_count", "token_num", "status", "tenant_id", "id"]
+    return field_name in ["chunk_count", "create_date", "create_time", "update_date", "update_time", "created_by", "document_count", "token_num", "status", "tenant_id", "id"]


 def get_parser_config(chunk_method, parser_config):
@@ -358,8 +349,7 @@ def get_parser_config(chunk_method, parser_config):
     if not chunk_method:
         chunk_method = "naive"
     key_mapping = {
-        "naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize": "DeepDOC",
-                  "raptor": {"use_raptor": False}},
+        "naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize": "DeepDOC", "raptor": {"use_raptor": False}},
         "qa": {"raptor": {"use_raptor": False}},
         "tag": None,
         "resume": None,
@@ -370,10 +360,10 @@ def get_parser_config(chunk_method, parser_config):
         "laws": {"raptor": {"use_raptor": False}},
         "presentation": {"raptor": {"use_raptor": False}},
         "one": None,
-        "knowledge_graph": {"chunk_token_num": 8192, "delimiter": "\\n!?;。;!?",
-                            "entity_types": ["organization", "person", "location", "event", "time"]},
+        "knowledge_graph": {"chunk_token_num": 8192, "delimiter": "\\n!?;。;!?", "entity_types": ["organization", "person", "location", "event", "time"]},
         "email": None,
-        "picture": None}
+        "picture": None,
+    }
     parser_config = key_mapping[chunk_method]
     return parser_config
@@ -421,21 +411,23 @@ def get_data_openai(id=None,
 def valid_parser_config(parser_config):
     if not parser_config:
         return
-    scopes = set([
-        "chunk_token_num",
-        "delimiter",
-        "raptor",
-        "graphrag",
-        "layout_recognize",
-        "task_page_size",
-        "pages",
-        "html4excel",
-        "auto_keywords",
-        "auto_questions",
-        "tag_kb_ids",
-        "topn_tags",
-        "filename_embd_weight"
-    ])
+    scopes = set(
+        [
+            "chunk_token_num",
+            "delimiter",
+            "raptor",
+            "graphrag",
+            "layout_recognize",
+            "task_page_size",
+            "pages",
+            "html4excel",
+            "auto_keywords",
+            "auto_questions",
+            "tag_kb_ids",
+            "topn_tags",
+            "filename_embd_weight",
+        ]
+    )
     for k in parser_config.keys():
         assert k in scopes, f"Abnormal 'parser_config'. Invalid key: {k}"
@@ -480,5 +472,3 @@ def check_duplicate_ids(ids, id_type="item"):
     # Return unique IDs and error messages
     return list(set(ids)), duplicate_messages
-
-


@@ -59,6 +59,7 @@ class Base(ABC):
         # Configure retry parameters
         self.max_retries = int(os.environ.get("LLM_MAX_RETRIES", 5))
         self.base_delay = float(os.environ.get("LLM_BASE_DELAY", 2.0))
+        self.is_tools = False

     def _get_delay(self, attempt):
         """Calculate retry delay time"""
@@ -89,6 +90,91 @@ class Base(ABC):
         else:
             return ERROR_GENERIC

+    def bind_tools(self, toolcall_session, tools):
+        if not (toolcall_session and tools):
+            return
+        self.is_tools = True
+        self.toolcall_session = toolcall_session
+        self.tools = tools
+
+    def chat_with_tools(self, system: str, history: list, gen_conf: dict):
+        if "max_tokens" in gen_conf:
+            del gen_conf["max_tokens"]
+        tools = self.tools
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        ans = ""
+        tk_count = 0
+        # Implement exponential backoff retry strategy
+        for attempt in range(self.max_retries):
+            try:
+                response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf)
+                assistant_output = response.choices[0].message
+                if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
+                    ans += "<think>" + assistant_output.reasoning_content + "</think>"
+                ans += response.choices[0].message.content
+                if not response.choices[0].message.tool_calls:
+                    tk_count += self.total_token_count(response)
+                    if response.choices[0].finish_reason == "length":
+                        if is_chinese([ans]):
+                            ans += LENGTH_NOTIFICATION_CN
+                        else:
+                            ans += LENGTH_NOTIFICATION_EN
+                    return ans, tk_count
+                tk_count += self.total_token_count(response)
+                history.append(assistant_output)
+                for tool_call in response.choices[0].message.tool_calls:
+                    name = tool_call.function.name
+                    args = json.loads(tool_call.function.arguments)
+                    tool_response = self.toolcall_session.tool_call(name, args)
+                    # if tool_response.choices[0].finish_reason == "length":
+                    #     if is_chinese(ans):
+                    #         ans += LENGTH_NOTIFICATION_CN
+                    #     else:
+                    #         ans += LENGTH_NOTIFICATION_EN
+                    #     return ans, tk_count + self.total_token_count(tool_response)
+                    history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+                final_response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf)
+                assistant_output = final_response.choices[0].message
+                if "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
+                    ans += "<think>" + assistant_output.reasoning_content + "</think>"
+                ans += final_response.choices[0].message.content
+                if final_response.choices[0].finish_reason == "length":
+                    tk_count += self.total_token_count(response)
+                    if is_chinese([ans]):
+                        ans += LENGTH_NOTIFICATION_CN
+                    else:
+                        ans += LENGTH_NOTIFICATION_EN
+                    return ans, tk_count
+                return ans, tk_count
+            except Exception as e:
+                logging.exception("OpenAI chat_with_tools")
+                # Classify the error
+                error_code = self._classify_error(e)
+                # Check if it's a rate limit error or server error and not the last attempt
+                should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1
+                if should_retry:
+                    delay = self._get_delay(attempt)
+                    logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+                    time.sleep(delay)
+                else:
+                    # For non-rate limit errors or the last attempt, return an error message
+                    if attempt == self.max_retries - 1:
+                        error_code = ERROR_MAX_RETRIES
+                    return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0
+
     def chat(self, system, history, gen_conf):
         if system:
             history.insert(0, {"role": "system", "content": system})
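
The loop above follows the standard OpenAI function-calling protocol: the assistant message carrying `tool_calls` is appended to `history`, each tool result goes back in as a `role: "tool"` message keyed by `tool_call_id`, and a second completion produces the user-facing answer, returned as `(ans, tk_count)`. Roughly, `history` evolves like this (ids and contents are invented for illustration):

```python
history = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What's the weather in Shanghai?"},
    # 1) first completion -> assistant message with tool_calls,
    #    appended verbatim via history.append(assistant_output)
    # 2) each tool result is fed back for the model to read:
    {"role": "tool", "tool_call_id": "call_0", "content": "Sunny, 22C"},
    # 3) the follow-up completion over this history yields the final answer
]
```
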
@@ -127,6 +213,127 @@ class Base(ABC):
             error_code = ERROR_MAX_RETRIES
             return f"{ERROR_PREFIX}: {error_code} - {str(e)}. response: {response}", 0

+    def _wrap_toolcall_message(self, stream):
+        final_tool_calls = {}
+        for chunk in stream:
+            for tool_call in chunk.choices[0].delta.tool_calls or []:
+                index = tool_call.index
+                if index not in final_tool_calls:
+                    final_tool_calls[index] = tool_call
+                else:
+                    final_tool_calls[index].function.arguments += tool_call.function.arguments
+        return final_tool_calls
+
+    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
+        if "max_tokens" in gen_conf:
+            del gen_conf["max_tokens"]
+        tools = self.tools
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        ans = ""
+        total_tokens = 0
+        reasoning_start = False
+        finish_completion = False
+        final_tool_calls = {}
+        try:
+            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
+            while not finish_completion:
+                for resp in response:
+                    if resp.choices[0].delta.tool_calls:
+                        for tool_call in resp.choices[0].delta.tool_calls or []:
+                            index = tool_call.index
+                            if index not in final_tool_calls:
+                                final_tool_calls[index] = tool_call
+                            else:
+                                final_tool_calls[index].function.arguments += tool_call.function.arguments
+                        if resp.choices[0].finish_reason != "stop":
+                            continue
+                    else:
+                        if not resp.choices:
+                            continue
+                        if not resp.choices[0].delta.content:
+                            resp.choices[0].delta.content = ""
+                        if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
+                            ans = ""
+                            if not reasoning_start:
+                                reasoning_start = True
+                                ans = "<think>"
+                            ans += resp.choices[0].delta.reasoning_content + "</think>"
+                        else:
+                            reasoning_start = False
+                            ans = resp.choices[0].delta.content
+                        tol = self.total_token_count(resp)
+                        if not tol:
+                            total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
+                        else:
+                            total_tokens += tol
+                    finish_reason = resp.choices[0].finish_reason
+                    if finish_reason == "tool_calls" and final_tool_calls:
+                        for tool_call in final_tool_calls.values():
+                            name = tool_call.function.name
+                            try:
+                                args = json.loads(tool_call.function.arguments)
+                            except Exception:
+                                continue
+                            tool_response = self.toolcall_session.tool_call(name, args)
+                            history.append(
+                                {
+                                    "role": "assistant",
+                                    "refusal": "",
+                                    "content": "",
+                                    "audio": "",
+                                    "function_call": "",
+                                    "tool_calls": [
+                                        {
+                                            "index": tool_call.index,
+                                            "id": tool_call.id,
+                                            "function": tool_call.function,
+                                            "type": "function",
+                                        },
+                                    ],
+                                }
+                            )
+                            # if tool_response.choices[0].finish_reason == "length":
+                            #     if is_chinese(ans):
+                            #         ans += LENGTH_NOTIFICATION_CN
+                            #     else:
+                            #         ans += LENGTH_NOTIFICATION_EN
+                            #     return ans, total_tokens + self.total_token_count(tool_response)
+                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+                        final_tool_calls = {}
+                        response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
+                        continue
+                    if finish_reason == "length":
+                        if is_chinese(ans):
+                            ans += LENGTH_NOTIFICATION_CN
+                        else:
+                            ans += LENGTH_NOTIFICATION_EN
+                        total_tokens += self.total_token_count(resp)
+                        finish_completion = True
+                        yield ans
+                        break
+                    if finish_reason == "stop":
+                        finish_completion = True
+                        yield ans
+                        break
+                    yield ans
+                    continue
+        except openai.APIError as e:
+            yield ans + "\n**ERROR**: " + str(e)
+        yield total_tokens
+
     def chat_streamly(self, system, history, gen_conf):
         if system:
             history.insert(0, {"role": "system", "content": system})
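
Like `chat_streamly`, the tool-aware generator yields answer fragments as `str` and finishes by yielding the total token count as an `int`, which is exactly what `LLMBundle.chat_streamly` filters on with `isinstance(txt, int)`. A consumption sketch (`mdl` stands for a chat model instance with tools already bound):

```python
total_tokens = 0
for chunk in mdl.chat_streamly_with_tools(system, history, gen_conf):
    if isinstance(chunk, int):
        total_tokens = chunk  # final yield: token usage for the whole request
        break
    print(chunk, end="")      # string yields: incremental answer text
```
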
@@ -156,7 +363,7 @@ class Base(ABC):
                 if not tol:
                     total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                 else:
-                    total_tokens = tol
+                    total_tokens += tol

                 if resp.choices[0].finish_reason == "length":
                     if is_chinese(ans):
@@ -183,6 +390,7 @@ class Base(ABC):
     def _calculate_dynamic_ctx(self, history):
         """Calculate dynamic context window size"""

+
         def count_tokens(text):
             """Calculate token count for text"""
             # Simple calculation: 1 token per ASCII character
@@ -216,6 +424,7 @@ class Base(ABC):
         return ctx_size


+
 class GptTurbo(Base):
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
         if not base_url:
@@ -350,6 +559,8 @@ class BaiChuanChat(Base):

 class QWenChat(Base):
     def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
+        super().__init__(key, model_name, base_url=None)
+
         import dashscope

         dashscope.api_key = key
@@ -357,6 +568,78 @@ class QWenChat(Base):
         if self.is_reasoning_model(self.model_name):
             super().__init__(key, model_name, "https://dashscope.aliyuncs.com/compatible-mode/v1")

+    def chat_with_tools(self, system: str, history: list, gen_conf: dict) -> tuple[str, int]:
+        if "max_tokens" in gen_conf:
+            del gen_conf["max_tokens"]
+        # if self.is_reasoning_model(self.model_name):
+        #     return super().chat(system, history, gen_conf)
+        stream_flag = str(os.environ.get("QWEN_CHAT_BY_STREAM", "true")).lower() == "true"
+        if not stream_flag:
+            from http import HTTPStatus
+
+            tools = self.tools
+            if system:
+                history.insert(0, {"role": "system", "content": system})
+
+            response = Generation.call(self.model_name, messages=history, result_format="message", tools=tools, **gen_conf)
+            ans = ""
+            tk_count = 0
+            if response.status_code == HTTPStatus.OK:
+                assistant_output = response.output.choices[0].message
+                if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
+                    ans += "<think>" + assistant_output.reasoning_content + "</think>"
+                ans += response.output.choices[0].message.content
+                if "tool_calls" not in assistant_output:
+                    tk_count += self.total_token_count(response)
+                    if response.output.choices[0].get("finish_reason", "") == "length":
+                        if is_chinese([ans]):
+                            ans += LENGTH_NOTIFICATION_CN
+                        else:
+                            ans += LENGTH_NOTIFICATION_EN
+                    return ans, tk_count
+                tk_count += self.total_token_count(response)
+                history.append(assistant_output)
+                while "tool_calls" in assistant_output:
+                    tool_info = {"content": "", "role": "tool", "tool_call_id": assistant_output.tool_calls[0]["id"]}
+                    tool_name = assistant_output.tool_calls[0]["function"]["name"]
+                    if tool_name:
+                        arguments = json.loads(assistant_output.tool_calls[0]["function"]["arguments"])
+                        tool_info["content"] = self.toolcall_session.tool_call(name=tool_name, arguments=arguments)
+                    history.append(tool_info)
+                    response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, **gen_conf)
+                    if response.output.choices[0].get("finish_reason", "") == "length":
+                        tk_count += self.total_token_count(response)
+                        if is_chinese([ans]):
+                            ans += LENGTH_NOTIFICATION_CN
+                        else:
+                            ans += LENGTH_NOTIFICATION_EN
+                        return ans, tk_count
+                    tk_count += self.total_token_count(response)
+                    assistant_output = response.output.choices[0].message
+                    if assistant_output.content is None:
+                        assistant_output.content = ""
+                    history.append(assistant_output)
+                ans += assistant_output["content"]
+                return ans, tk_count
+            else:
+                return "**ERROR**: " + response.message, tk_count
+        else:
+            result_list = []
+            for result in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=True):
+                result_list.append(result)
+            error_msg_list = [result for result in result_list if str(result).find("**ERROR**") >= 0]
+            if len(error_msg_list) > 0:
+                return "**ERROR**: " + "".join(error_msg_list), 0
+            else:
+                return "".join(result_list[:-1]), result_list[-1]
+
     def chat(self, system, history, gen_conf):
         if "max_tokens" in gen_conf:
             del gen_conf["max_tokens"]
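
Two details of the Qwen path are easy to miss: the `QWEN_CHAT_BY_STREAM` environment variable (default `"true"`) routes even the non-streaming `chat_with_tools` through the internal streaming generator, and the reassembly relies on the generator's contract of yielding text first and the token count last. A sketch:

```python
import os

# Assumed deployment toggle, read on each call; any value other than "true"
# (case-insensitive) selects the blocking Generation.call path.
os.environ["QWEN_CHAT_BY_STREAM"] = "false"

# Reassembly used above: all yields but the last are text, the last is
# the token count.
result_list = ["Hello", ", world", 42]
ans, tk_count = "".join(result_list[:-1]), result_list[-1]
assert (ans, tk_count) == ("Hello, world", 42)
```
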
@@ -393,6 +676,99 @@ class QWenChat(Base):
         else:
             return "".join(result_list[:-1]), result_list[-1]

+    def _wrap_toolcall_message(self, old_message, message):
+        if not old_message:
+            return message
+        tool_call_id = message["tool_calls"][0].get("id")
+        if tool_call_id:
+            old_message.tool_calls[0]["id"] = tool_call_id
+        function = message.tool_calls[0]["function"]
+        if function:
+            if function.get("name"):
+                old_message.tool_calls[0]["function"]["name"] = function["name"]
+            if function.get("arguments"):
+                old_message.tool_calls[0]["function"]["arguments"] += function["arguments"]
+        return old_message
+
+    def _chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
+        from http import HTTPStatus
+
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        if "max_tokens" in gen_conf:
+            del gen_conf["max_tokens"]
+        ans = ""
+        tk_count = 0
+        try:
+            response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
+            tool_info = {"content": "", "role": "tool"}
+            toolcall_message = None
+            tool_name = ""
+            tool_arguments = ""
+            finish_completion = False
+            reasoning_start = False
+            while not finish_completion:
+                for resp in response:
+                    if resp.status_code == HTTPStatus.OK:
+                        assistant_output = resp.output.choices[0].message
+                        ans = resp.output.choices[0].message.content
+                        if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
+                            ans = resp.output.choices[0].message.reasoning_content
+                            if not reasoning_start:
+                                reasoning_start = True
+                                ans = "<think>" + ans
+                            else:
+                                ans = ans + "</think>"
+                        if "tool_calls" not in assistant_output:
+                            reasoning_start = False
+                            tk_count += self.total_token_count(resp)
+                            if resp.output.choices[0].get("finish_reason", "") == "length":
+                                if is_chinese([ans]):
+                                    ans += LENGTH_NOTIFICATION_CN
+                                else:
+                                    ans += LENGTH_NOTIFICATION_EN
+                            finish_reason = resp.output.choices[0]["finish_reason"]
+                            if finish_reason == "stop":
+                                finish_completion = True
+                                yield ans
+                                break
+                            yield ans
+                            continue
+                        tk_count += self.total_token_count(resp)
+                        toolcall_message = self._wrap_toolcall_message(toolcall_message, assistant_output)
+                        if "tool_calls" in assistant_output:
+                            tool_call_finish_reason = resp.output.choices[0]["finish_reason"]
+                            if tool_call_finish_reason == "tool_calls":
+                                try:
+                                    tool_arguments = json.loads(toolcall_message.tool_calls[0]["function"]["arguments"])
+                                except Exception as e:
+                                    logging.exception(msg="_chat_streamly_with_tool tool call error")
+                                    yield ans + "\n**ERROR**: " + str(e)
+                                    finish_completion = True
+                                    break
+                                tool_name = toolcall_message.tool_calls[0]["function"]["name"]
+                                history.append(toolcall_message)
+                                tool_info["content"] = self.toolcall_session.tool_call(name=tool_name, arguments=tool_arguments)
+                                history.append(tool_info)
+                                tool_info = {"content": "", "role": "tool"}
+                                tool_name = ""
+                                tool_arguments = ""
+                                toolcall_message = None
+                                response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
+                    else:
+                        yield (
+                            ans + "\n**ERROR**: " + resp.output.choices[0].message
+                            if not re.search(r" (key|quota)", str(resp.message).lower())
+                            else "Out of credit. Please set the API key in **settings > Model providers.**"
+                        )
+        except Exception as e:
+            logging.exception(msg="_chat_streamly_with_tool")
+            yield ans + "\n**ERROR**: " + str(e)
+        yield tk_count
+
     def _chat_streamly(self, system, history, gen_conf, incremental_output=True):
         from http import HTTPStatus
@@ -425,6 +801,13 @@ class QWenChat(Base):

         yield tk_count

+    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
+        if "max_tokens" in gen_conf:
+            del gen_conf["max_tokens"]
+
+        for txt in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=incremental_output):
+            yield txt
+
     def chat_streamly(self, system, history, gen_conf):
         if "max_tokens" in gen_conf:
             del gen_conf["max_tokens"]
@@ -445,6 +828,8 @@ class QWenChat(Base):

 class ZhipuChat(Base):
     def __init__(self, key, model_name="glm-3-turbo", **kwargs):
+        super().__init__(key, model_name, base_url=None)
+
         self.client = ZhipuAI(api_key=key)
         self.model_name = model_name
@@ -504,6 +889,8 @@ class ZhipuChat(Base):

 class OllamaChat(Base):
     def __init__(self, key, model_name, **kwargs):
+        super().__init__(key, model_name, base_url=None)
+
         self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"})
         self.model_name = model_name
@@ -516,9 +903,7 @@ class OllamaChat(Base):
             # Calculate context size
             ctx_size = self._calculate_dynamic_ctx(history)

-            options = {
-                "num_ctx": ctx_size
-            }
+            options = {"num_ctx": ctx_size}
             if "temperature" in gen_conf:
                 options["temperature"] = gen_conf["temperature"]
             if "max_tokens" in gen_conf:
@@ -545,9 +930,7 @@ class OllamaChat(Base):
         try:
             # Calculate context size
             ctx_size = self._calculate_dynamic_ctx(history)
-            options = {
-                "num_ctx": ctx_size
-            }
+            options = {"num_ctx": ctx_size}
             if "temperature" in gen_conf:
                 options["temperature"] = gen_conf["temperature"]
             if "max_tokens" in gen_conf:
@@ -561,7 +944,7 @@ class OllamaChat(Base):
             ans = ""
             try:
-                response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=10 )
+                response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=10)
                 for resp in response:
                     if resp["done"]:
                         token_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
@@ -578,6 +961,8 @@ class OllamaChat(Base):

 class LocalAIChat(Base):
     def __init__(self, key, model_name, base_url):
+        super().__init__(key, model_name, base_url=None)
+
         if not base_url:
             raise ValueError("Local llm url cannot be None")
         if base_url.split("/")[-1] != "v1":
@@ -613,6 +998,8 @@ class LocalLLM(Base):
         return do_rpc

     def __init__(self, key, model_name):
+        super().__init__(key, model_name, base_url=None)
+
         from jina import Client

         self.client = Client(port=12345, protocol="grpc", asyncio=True)
@@ -659,6 +1046,8 @@ class LocalLLM(Base):

 class VolcEngineChat(Base):
     def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
+        super().__init__(key, model_name, base_url=None)
+
         """
         Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
         Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use
@@ -677,6 +1066,8 @@ class MiniMaxChat(Base):
         model_name,
         base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
     ):
+        super().__init__(key, model_name, base_url=None)
+
         if not base_url:
             base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
         self.base_url = base_url
@@ -755,6 +1146,8 @@ class MiniMaxChat(Base):

 class MistralChat(Base):
     def __init__(self, key, model_name, base_url=None):
+        super().__init__(key, model_name, base_url=None)
+
         from mistralai.client import MistralClient

         self.client = MistralClient(api_key=key)
@@ -808,6 +1201,8 @@ class MistralChat(Base):

 class BedrockChat(Base):
     def __init__(self, key, model_name, **kwargs):
+        super().__init__(key, model_name, base_url=None)
+
         import boto3

         self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
@@ -887,6 +1282,8 @@ class BedrockChat(Base):

 class GeminiChat(Base):
     def __init__(self, key, model_name, base_url=None):
+        super().__init__(key, model_name, base_url=None)
+
         from google.generativeai import GenerativeModel, client

         client.configure(api_key=key)
@@ -947,6 +1344,8 @@ class GeminiChat(Base):

 class GroqChat(Base):
     def __init__(self, key, model_name, base_url=""):
+        super().__init__(key, model_name, base_url=None)
+
         from groq import Groq

         self.client = Groq(api_key=key)
@@ -1049,6 +1448,8 @@ class PPIOChat(Base):

 class CoHereChat(Base):
     def __init__(self, key, model_name, base_url=""):
+        super().__init__(key, model_name, base_url=None)
+
         from cohere import Client

         self.client = Client(api_key=key)
@@ -1171,6 +1572,8 @@ class YiChat(Base):

 class ReplicateChat(Base):
     def __init__(self, key, model_name, base_url=None):
+        super().__init__(key, model_name, base_url=None)
+
         from replicate.client import Client

         self.model_name = model_name
@@ -1218,6 +1621,8 @@ class ReplicateChat(Base):

 class HunyuanChat(Base):
     def __init__(self, key, model_name, base_url=None):
+        super().__init__(key, model_name, base_url=None)
+
         from tencentcloud.common import credential
         from tencentcloud.hunyuan.v20230901 import hunyuan_client
@@ -1321,6 +1726,8 @@ class SparkChat(Base):

 class BaiduYiyanChat(Base):
     def __init__(self, key, model_name, base_url=None):
+        super().__init__(key, model_name, base_url=None)
+
         import qianfan

         key = json.loads(key)
@@ -1372,6 +1779,8 @@ class BaiduYiyanChat(Base):

 class AnthropicChat(Base):
     def __init__(self, key, model_name, base_url=None):
+        super().__init__(key, model_name, base_url=None)
+
         import anthropic

         self.client = anthropic.Anthropic(api_key=key)
@@ -1452,6 +1861,8 @@ class AnthropicChat(Base):

 class GoogleChat(Base):
     def __init__(self, key, model_name, base_url=None):
+        super().__init__(key, model_name, base_url=None)
+
         import base64
         from google.oauth2 import service_account