Mirror of https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git (synced 2025-04-21 05:29:57 +08:00)
Feat: add primitive support for function calls (#6840)
### What problem does this PR solve?

This PR introduces **primitive support for function calls**, enabling the system to handle basic function call capabilities. However, this feature is currently experimental and **not yet enabled for general use**, as it is only supported by a subset of models, namely Qwen and OpenAI models.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:

- parent: a20439bf81
- commit: dc2c74b249
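The change threads an optional `toolcall_session` / `tools` pair from the API layer down to the model classes. The session object only needs to expose a `tool_call(name, arguments)` method (that is how the model classes invoke it in the diff below), and `tools` follows the OpenAI function-calling schema. A minimal sketch, assuming a hypothetical weather tool; `SimpleToolCallSession` is an illustrative stand-in for the `SimpleFunctionCallServer` that appears commented out in the diff, and the `get_current_weather` name mirrors the demo tool referenced in the streaming code:

```python
# Hedged sketch, not part of this commit: the two objects the new code path expects.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]


class SimpleToolCallSession:
    """Illustrative stand-in for the commented-out SimpleFunctionCallServer."""

    def tool_call(self, name: str, arguments: dict) -> str:
        # The model classes call this with the parsed JSON arguments of a tool call.
        if name == "get_current_weather":
            return f"Sunny in {arguments.get('location', 'unknown')}"
        return f"Unknown tool: {name}"
```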
@@ -243,6 +243,11 @@ def chat_completion_openai_like(tenant_id, chat_id):
 msg = None
 msg = [m for m in messages if m["role"] != "system" and (m["role"] != "assistant" or msg)]

+# tools = get_tools()
+# toolcall_session = SimpleFunctionCallServer()
+tools = None
+toolcall_session = None
+
 if req.get("stream", True):
 # The value for the usage field on all chunks except for the last one will be null.
 # The usage field on the last chunk contains token usage statistics for the entire request.

@@ -262,7 +267,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
 }

 try:
-for ans in chat(dia, msg, True):
+for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools):
 answer = ans["answer"]

 reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)

@@ -325,7 +330,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
 return resp
 else:
 answer = None
-for ans in chat(dia, msg, False):
+for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools):
 # focus answer content only
 answer = ans
 break
@@ -145,6 +145,9 @@ def chat(dialog, messages, stream=True, **kwargs):
 chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
 else:
 chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
+toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
+if toolcall_session and tools:
+    chat_mdl.bind_tools(toolcall_session, tools)

 bind_llm_ts = timer()

@@ -338,7 +341,7 @@ def chat(dialog, messages, stream=True, **kwargs):
 langfuse_output = {"time_elapsed:": re.sub(r"\n", " \n", langfuse_output), "created_at": time.time()}

 # Add a condition check to call the end method only if langfuse_tracer exists
-if langfuse_tracer and 'langfuse_generation' in locals():
+if langfuse_tracer and "langfuse_generation" in locals():
 langfuse_generation.end(output=langfuse_output)

 return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
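A hedged usage sketch of the new keyword arguments accepted by the dialog-level `chat()` above. `dia` and `msg` stand for the dialog object and message list that the API layer already builds, and the session/schema objects are the illustrative ones from the sketch near the top of this page:

```python
# Illustrative only: wiring the experimental tool support through chat().
session = SimpleToolCallSession()

for ans in chat(dia, msg, True, toolcall_session=session, tools=tools):
    print(ans["answer"])
```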
@@ -102,6 +102,9 @@ class TenantLLMService(CommonService):
 mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
 if model_config:
 model_config = model_config.to_dict()
+llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
+if llm:
+    model_config["is_tools"] = llm[0].is_tools
 if not model_config:
 if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
 llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)

@@ -206,6 +209,8 @@ class LLMBundle:
 model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
 self.max_length = model_config.get("max_tokens", 8192)

+self.is_tools = model_config.get("is_tools", False)
+
 langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
 if langfuse_keys:
 langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)

@@ -215,6 +220,11 @@ class LLMBundle:
 else:
 self.langfuse = None

+def bind_tools(self, toolcall_session, tools):
+    if not self.is_tools:
+        return
+    self.mdl.bind_tools(toolcall_session, tools)
+
 def encode(self, texts: list):
 if self.langfuse:
 generation = self.trace.generation(name="encode", model=self.llm_name, input={"texts": texts})

@@ -307,11 +317,31 @@ class LLMBundle:
 if self.langfuse:
 span.end()

+def _remove_reasoning_content(self, txt: str) -> str:
+    first_think_start = txt.find("<think>")
+    if first_think_start == -1:
+        return txt
+
+    last_think_end = txt.rfind("</think>")
+    if last_think_end == -1:
+        return txt
+
+    if last_think_end < first_think_start:
+        return txt
+
+    return txt[last_think_end + len("</think>") :]
+
 def chat(self, system, history, gen_conf):
 if self.langfuse:
 generation = self.trace.generation(name="chat", model=self.llm_name, input={"system": system, "history": history})

-txt, used_tokens = self.mdl.chat(system, history, gen_conf)
+chat = self.mdl.chat
+if self.is_tools and self.mdl.is_tools:
+    chat = self.mdl.chat_with_tools
+
+txt, used_tokens = chat(system, history, gen_conf)
+txt = self._remove_reasoning_content(txt)

 if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
 logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))

@@ -325,7 +355,12 @@ class LLMBundle:
 generation = self.trace.generation(name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})

 ans = ""
-for txt in self.mdl.chat_streamly(system, history, gen_conf):
+chat_streamly = self.mdl.chat_streamly
+
+if self.is_tools and self.mdl.is_tools:
+    chat_streamly = self.mdl.chat_streamly_with_tools
+
+for txt in chat_streamly(system, history, gen_conf):
 if isinstance(txt, int):
 if self.langfuse:
 generation.end(output={"output": ans})
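For reference, a small self-contained illustration (not part of the commit) of what `_remove_reasoning_content` does, based on the implementation above: it keeps only the text after the last `</think>` tag when a matching `<think>` precedes it, and returns the input unchanged otherwise.

```python
# Illustration only; mirrors the logic of LLMBundle._remove_reasoning_content above.
def remove_reasoning_content(txt: str) -> str:
    first = txt.find("<think>")
    last = txt.rfind("</think>")
    if first == -1 or last == -1 or last < first:
        return txt
    return txt[last + len("</think>"):]

print(remove_reasoning_content("<think>weighing options...</think>Paris"))  # Paris
```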
@ -13,9 +13,9 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
import logging
|
|
||||||
import functools
|
import functools
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
@ -27,59 +27,60 @@ from uuid import uuid1
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
from flask import (
|
from flask import (
|
||||||
Response, jsonify, send_file, make_response,
|
Response,
|
||||||
|
jsonify,
|
||||||
|
make_response,
|
||||||
|
send_file,
|
||||||
|
)
|
||||||
|
from flask import (
|
||||||
request as flask_request,
|
request as flask_request,
|
||||||
)
|
)
|
||||||
from itsdangerous import URLSafeTimedSerializer
|
from itsdangerous import URLSafeTimedSerializer
|
||||||
from werkzeug.http import HTTP_STATUS_CODES
|
from werkzeug.http import HTTP_STATUS_CODES
|
||||||
|
|
||||||
from api.db.db_models import APIToken
|
|
||||||
from api import settings
|
from api import settings
|
||||||
|
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
|
||||||
|
from api.db.db_models import APIToken
|
||||||
|
from api.utils import CustomJSONEncoder, get_uuid, json_dumps
|
||||||
|
|
||||||
from api.utils import CustomJSONEncoder, get_uuid
|
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
|
||||||
from api.utils import json_dumps
|
|
||||||
from api.constants import REQUEST_WAIT_SEC, REQUEST_MAX_WAIT_SEC
|
|
||||||
|
|
||||||
requests.models.complexjson.dumps = functools.partial(
|
|
||||||
json.dumps, cls=CustomJSONEncoder)
|
|
||||||
|
|
||||||
|
|
||||||
def request(**kwargs):
|
def request(**kwargs):
|
||||||
sess = requests.Session()
|
sess = requests.Session()
|
||||||
stream = kwargs.pop('stream', sess.stream)
|
stream = kwargs.pop("stream", sess.stream)
|
||||||
timeout = kwargs.pop('timeout', None)
|
timeout = kwargs.pop("timeout", None)
|
||||||
kwargs['headers'] = {
|
kwargs["headers"] = {k.replace("_", "-").upper(): v for k, v in kwargs.get("headers", {}).items()}
|
||||||
k.replace(
|
|
||||||
'_',
|
|
||||||
'-').upper(): v for k,
|
|
||||||
v in kwargs.get(
|
|
||||||
'headers',
|
|
||||||
{}).items()}
|
|
||||||
prepped = requests.Request(**kwargs).prepare()
|
prepped = requests.Request(**kwargs).prepare()
|
||||||
|
|
||||||
if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY:
|
if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY:
|
||||||
timestamp = str(round(time() * 1000))
|
timestamp = str(round(time() * 1000))
|
||||||
nonce = str(uuid1())
|
nonce = str(uuid1())
|
||||||
signature = b64encode(HMAC(settings.SECRET_KEY.encode('ascii'), b'\n'.join([
|
signature = b64encode(
|
||||||
timestamp.encode('ascii'),
|
HMAC(
|
||||||
nonce.encode('ascii'),
|
settings.SECRET_KEY.encode("ascii"),
|
||||||
settings.HTTP_APP_KEY.encode('ascii'),
|
b"\n".join(
|
||||||
prepped.path_url.encode('ascii'),
|
[
|
||||||
prepped.body if kwargs.get('json') else b'',
|
timestamp.encode("ascii"),
|
||||||
urlencode(
|
nonce.encode("ascii"),
|
||||||
sorted(
|
settings.HTTP_APP_KEY.encode("ascii"),
|
||||||
kwargs['data'].items()),
|
prepped.path_url.encode("ascii"),
|
||||||
quote_via=quote,
|
prepped.body if kwargs.get("json") else b"",
|
||||||
safe='-._~').encode('ascii')
|
urlencode(sorted(kwargs["data"].items()), quote_via=quote, safe="-._~").encode("ascii") if kwargs.get("data") and isinstance(kwargs["data"], dict) else b"",
|
||||||
if kwargs.get('data') and isinstance(kwargs['data'], dict) else b'',
|
]
|
||||||
]), 'sha1').digest()).decode('ascii')
|
),
|
||||||
|
"sha1",
|
||||||
|
).digest()
|
||||||
|
).decode("ascii")
|
||||||
|
|
||||||
prepped.headers.update({
|
prepped.headers.update(
|
||||||
'TIMESTAMP': timestamp,
|
{
|
||||||
'NONCE': nonce,
|
"TIMESTAMP": timestamp,
|
||||||
'APP-KEY': settings.HTTP_APP_KEY,
|
"NONCE": nonce,
|
||||||
'SIGNATURE': signature,
|
"APP-KEY": settings.HTTP_APP_KEY,
|
||||||
})
|
"SIGNATURE": signature,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return sess.send(prepped, stream=stream, timeout=timeout)
|
return sess.send(prepped, stream=stream, timeout=timeout)
|
||||||
|
|
||||||
@ -87,7 +88,7 @@ def request(**kwargs):
|
|||||||
def get_exponential_backoff_interval(retries, full_jitter=False):
|
def get_exponential_backoff_interval(retries, full_jitter=False):
|
||||||
"""Calculate the exponential backoff wait time."""
|
"""Calculate the exponential backoff wait time."""
|
||||||
# Will be zero if factor equals 0
|
# Will be zero if factor equals 0
|
||||||
countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
|
countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2**retries))
|
||||||
# Full jitter according to
|
# Full jitter according to
|
||||||
# https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
|
# https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
|
||||||
if full_jitter:
|
if full_jitter:
|
||||||
@ -96,12 +97,9 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
|
|||||||
return max(0, countdown)
|
return max(0, countdown)
|
||||||
|
|
||||||
|
|
||||||
def get_data_error_result(code=settings.RetCode.DATA_ERROR,
|
def get_data_error_result(code=settings.RetCode.DATA_ERROR, message="Sorry! Data missing!"):
|
||||||
message='Sorry! Data missing!'):
|
|
||||||
logging.exception(Exception(message))
|
logging.exception(Exception(message))
|
||||||
result_dict = {
|
result_dict = {"code": code, "message": message}
|
||||||
"code": code,
|
|
||||||
"message": message}
|
|
||||||
response = {}
|
response = {}
|
||||||
for key, value in result_dict.items():
|
for key, value in result_dict.items():
|
||||||
if value is None and key != "code":
|
if value is None and key != "code":
|
||||||
@ -119,23 +117,27 @@ def server_error_response(e):
|
|||||||
except BaseException:
|
except BaseException:
|
||||||
pass
|
pass
|
||||||
if len(e.args) > 1:
|
if len(e.args) > 1:
|
||||||
return get_json_result(
|
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
||||||
code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
|
||||||
if repr(e).find("index_not_found_exception") >= 0:
|
if repr(e).find("index_not_found_exception") >= 0:
|
||||||
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR,
|
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
|
||||||
message="No chunk found, please upload file and parse it.")
|
|
||||||
|
|
||||||
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
|
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
|
||||||
|
|
||||||
|
|
||||||
def error_response(response_code, message=None):
|
def error_response(response_code, message=None):
|
||||||
if message is None:
|
if message is None:
|
||||||
message = HTTP_STATUS_CODES.get(response_code, 'Unknown Error')
|
message = HTTP_STATUS_CODES.get(response_code, "Unknown Error")
|
||||||
|
|
||||||
return Response(json.dumps({
|
return Response(
|
||||||
'message': message,
|
json.dumps(
|
||||||
'code': response_code,
|
{
|
||||||
}), status=response_code, mimetype='application/json')
|
"message": message,
|
||||||
|
"code": response_code,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
status=response_code,
|
||||||
|
mimetype="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def validate_request(*args, **kwargs):
|
def validate_request(*args, **kwargs):
|
||||||
@ -160,13 +162,10 @@ def validate_request(*args, **kwargs):
|
|||||||
if no_arguments or error_arguments:
|
if no_arguments or error_arguments:
|
||||||
error_string = ""
|
error_string = ""
|
||||||
if no_arguments:
|
if no_arguments:
|
||||||
error_string += "required argument are missing: {}; ".format(
|
error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
|
||||||
",".join(no_arguments))
|
|
||||||
if error_arguments:
|
if error_arguments:
|
||||||
error_string += "required argument values: {}".format(
|
error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
|
||||||
",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
|
return get_json_result(code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
|
||||||
return get_json_result(
|
|
||||||
code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
|
|
||||||
return func(*_args, **_kwargs)
|
return func(*_args, **_kwargs)
|
||||||
|
|
||||||
return decorated_function
|
return decorated_function
|
||||||
@ -180,8 +179,7 @@ def not_allowed_parameters(*params):
|
|||||||
input_arguments = flask_request.json or flask_request.form.to_dict()
|
input_arguments = flask_request.json or flask_request.form.to_dict()
|
||||||
for param in params:
|
for param in params:
|
||||||
if param in input_arguments:
|
if param in input_arguments:
|
||||||
return get_json_result(
|
return get_json_result(code=settings.RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
|
||||||
code=settings.RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
|
|
||||||
return f(*args, **kwargs)
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
@ -190,14 +188,14 @@ def not_allowed_parameters(*params):
|
|||||||
|
|
||||||
|
|
||||||
def is_localhost(ip):
|
def is_localhost(ip):
|
||||||
return ip in {'127.0.0.1', '::1', '[::1]', 'localhost'}
|
return ip in {"127.0.0.1", "::1", "[::1]", "localhost"}
|
||||||
|
|
||||||
|
|
||||||
def send_file_in_mem(data, filename):
|
def send_file_in_mem(data, filename):
|
||||||
if not isinstance(data, (str, bytes)):
|
if not isinstance(data, (str, bytes)):
|
||||||
data = json_dumps(data)
|
data = json_dumps(data)
|
||||||
if isinstance(data, str):
|
if isinstance(data, str):
|
||||||
data = data.encode('utf-8')
|
data = data.encode("utf-8")
|
||||||
|
|
||||||
f = BytesIO()
|
f = BytesIO()
|
||||||
f.write(data)
|
f.write(data)
|
||||||
@ -206,7 +204,7 @@ def send_file_in_mem(data, filename):
|
|||||||
return send_file(f, as_attachment=True, attachment_filename=filename)
|
return send_file(f, as_attachment=True, attachment_filename=filename)
|
||||||
|
|
||||||
|
|
||||||
def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
|
def get_json_result(code=settings.RetCode.SUCCESS, message="success", data=None):
|
||||||
response = {"code": code, "message": message, "data": data}
|
response = {"code": code, "message": message, "data": data}
|
||||||
return jsonify(response)
|
return jsonify(response)
|
||||||
|
|
||||||
@ -214,27 +212,24 @@ def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None)
|
|||||||
def apikey_required(func):
|
def apikey_required(func):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def decorated_function(*args, **kwargs):
|
def decorated_function(*args, **kwargs):
|
||||||
token = flask_request.headers.get('Authorization').split()[1]
|
token = flask_request.headers.get("Authorization").split()[1]
|
||||||
objs = APIToken.query(token=token)
|
objs = APIToken.query(token=token)
|
||||||
if not objs:
|
if not objs:
|
||||||
return build_error_result(
|
return build_error_result(message="API-KEY is invalid!", code=settings.RetCode.FORBIDDEN)
|
||||||
message='API-KEY is invalid!', code=settings.RetCode.FORBIDDEN
|
kwargs["tenant_id"] = objs[0].tenant_id
|
||||||
)
|
|
||||||
kwargs['tenant_id'] = objs[0].tenant_id
|
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
return decorated_function
|
return decorated_function
|
||||||
|
|
||||||
|
|
||||||
def build_error_result(code=settings.RetCode.FORBIDDEN, message='success'):
|
def build_error_result(code=settings.RetCode.FORBIDDEN, message="success"):
|
||||||
response = {"code": code, "message": message}
|
response = {"code": code, "message": message}
|
||||||
response = jsonify(response)
|
response = jsonify(response)
|
||||||
response.status_code = code
|
response.status_code = code
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def construct_response(code=settings.RetCode.SUCCESS,
|
def construct_response(code=settings.RetCode.SUCCESS, message="success", data=None, auth=None):
|
||||||
message='success', data=None, auth=None):
|
|
||||||
result_dict = {"code": code, "message": message, "data": data}
|
result_dict = {"code": code, "message": message, "data": data}
|
||||||
response_dict = {}
|
response_dict = {}
|
||||||
for key, value in result_dict.items():
|
for key, value in result_dict.items():
|
||||||
@ -253,7 +248,7 @@ def construct_response(code=settings.RetCode.SUCCESS,
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def construct_result(code=settings.RetCode.DATA_ERROR, message='data is missing'):
|
def construct_result(code=settings.RetCode.DATA_ERROR, message="data is missing"):
|
||||||
result_dict = {"code": code, "message": message}
|
result_dict = {"code": code, "message": message}
|
||||||
response = {}
|
response = {}
|
||||||
for key, value in result_dict.items():
|
for key, value in result_dict.items():
|
||||||
@ -264,7 +259,7 @@ def construct_result(code=settings.RetCode.DATA_ERROR, message='data is missing'
|
|||||||
return jsonify(response)
|
return jsonify(response)
|
||||||
|
|
||||||
|
|
||||||
def construct_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
|
def construct_json_result(code=settings.RetCode.SUCCESS, message="success", data=None):
|
||||||
if data is None:
|
if data is None:
|
||||||
return jsonify({"code": code, "message": message})
|
return jsonify({"code": code, "message": message})
|
||||||
else:
|
else:
|
||||||
@ -286,7 +281,7 @@ def construct_error_response(e):
|
|||||||
def token_required(func):
|
def token_required(func):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def decorated_function(*args, **kwargs):
|
def decorated_function(*args, **kwargs):
|
||||||
authorization_str = flask_request.headers.get('Authorization')
|
authorization_str = flask_request.headers.get("Authorization")
|
||||||
if not authorization_str:
|
if not authorization_str:
|
||||||
return get_json_result(data=False, message="`Authorization` can't be empty")
|
return get_json_result(data=False, message="`Authorization` can't be empty")
|
||||||
authorization_list = authorization_str.split()
|
authorization_list = authorization_str.split()
|
||||||
@ -295,11 +290,8 @@ def token_required(func):
|
|||||||
token = authorization_list[1]
|
token = authorization_list[1]
|
||||||
objs = APIToken.query(token=token)
|
objs = APIToken.query(token=token)
|
||||||
if not objs:
|
if not objs:
|
||||||
return get_json_result(
|
return get_json_result(data=False, message="Authentication error: API key is invalid!", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
data=False, message='Authentication error: API key is invalid!',
|
kwargs["tenant_id"] = objs[0].tenant_id
|
||||||
code=settings.RetCode.AUTHENTICATION_ERROR
|
|
||||||
)
|
|
||||||
kwargs['tenant_id'] = objs[0].tenant_id
|
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
return decorated_function
|
return decorated_function
|
||||||
@ -316,11 +308,11 @@ def get_result(code=settings.RetCode.SUCCESS, message="", data=None):
|
|||||||
return jsonify(response)
|
return jsonify(response)
|
||||||
|
|
||||||
|
|
||||||
def get_error_data_result(message='Sorry! Data missing!', code=settings.RetCode.DATA_ERROR,
|
def get_error_data_result(
|
||||||
):
|
message="Sorry! Data missing!",
|
||||||
result_dict = {
|
code=settings.RetCode.DATA_ERROR,
|
||||||
"code": code,
|
):
|
||||||
"message": message}
|
result_dict = {"code": code, "message": message}
|
||||||
response = {}
|
response = {}
|
||||||
for key, value in result_dict.items():
|
for key, value in result_dict.items():
|
||||||
if value is None and key != "code":
|
if value is None and key != "code":
|
||||||
@ -348,8 +340,7 @@ def valid_parameter(parameter, valid_values):
|
|||||||
|
|
||||||
|
|
||||||
def dataset_readonly_fields(field_name):
|
def dataset_readonly_fields(field_name):
|
||||||
return field_name in ["chunk_count", "create_date", "create_time", "update_date", "update_time",
|
return field_name in ["chunk_count", "create_date", "create_time", "update_date", "update_time", "created_by", "document_count", "token_num", "status", "tenant_id", "id"]
|
||||||
"created_by", "document_count", "token_num", "status", "tenant_id", "id"]
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser_config(chunk_method, parser_config):
|
def get_parser_config(chunk_method, parser_config):
|
||||||
@ -358,8 +349,7 @@ def get_parser_config(chunk_method, parser_config):
|
|||||||
if not chunk_method:
|
if not chunk_method:
|
||||||
chunk_method = "naive"
|
chunk_method = "naive"
|
||||||
key_mapping = {
|
key_mapping = {
|
||||||
"naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize": "DeepDOC",
|
"naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize": "DeepDOC", "raptor": {"use_raptor": False}},
|
||||||
"raptor": {"use_raptor": False}},
|
|
||||||
"qa": {"raptor": {"use_raptor": False}},
|
"qa": {"raptor": {"use_raptor": False}},
|
||||||
"tag": None,
|
"tag": None,
|
||||||
"resume": None,
|
"resume": None,
|
||||||
@ -370,10 +360,10 @@ def get_parser_config(chunk_method, parser_config):
|
|||||||
"laws": {"raptor": {"use_raptor": False}},
|
"laws": {"raptor": {"use_raptor": False}},
|
||||||
"presentation": {"raptor": {"use_raptor": False}},
|
"presentation": {"raptor": {"use_raptor": False}},
|
||||||
"one": None,
|
"one": None,
|
||||||
"knowledge_graph": {"chunk_token_num": 8192, "delimiter": "\\n!?;。;!?",
|
"knowledge_graph": {"chunk_token_num": 8192, "delimiter": "\\n!?;。;!?", "entity_types": ["organization", "person", "location", "event", "time"]},
|
||||||
"entity_types": ["organization", "person", "location", "event", "time"]},
|
|
||||||
"email": None,
|
"email": None,
|
||||||
"picture": None}
|
"picture": None,
|
||||||
|
}
|
||||||
parser_config = key_mapping[chunk_method]
|
parser_config = key_mapping[chunk_method]
|
||||||
return parser_config
|
return parser_config
|
||||||
|
|
||||||
@ -421,21 +411,23 @@ def get_data_openai(id=None,
|
|||||||
def valid_parser_config(parser_config):
|
def valid_parser_config(parser_config):
|
||||||
if not parser_config:
|
if not parser_config:
|
||||||
return
|
return
|
||||||
scopes = set([
|
scopes = set(
|
||||||
"chunk_token_num",
|
[
|
||||||
"delimiter",
|
"chunk_token_num",
|
||||||
"raptor",
|
"delimiter",
|
||||||
"graphrag",
|
"raptor",
|
||||||
"layout_recognize",
|
"graphrag",
|
||||||
"task_page_size",
|
"layout_recognize",
|
||||||
"pages",
|
"task_page_size",
|
||||||
"html4excel",
|
"pages",
|
||||||
"auto_keywords",
|
"html4excel",
|
||||||
"auto_questions",
|
"auto_keywords",
|
||||||
"tag_kb_ids",
|
"auto_questions",
|
||||||
"topn_tags",
|
"tag_kb_ids",
|
||||||
"filename_embd_weight"
|
"topn_tags",
|
||||||
])
|
"filename_embd_weight",
|
||||||
|
]
|
||||||
|
)
|
||||||
for k in parser_config.keys():
|
for k in parser_config.keys():
|
||||||
assert k in scopes, f"Abnormal 'parser_config'. Invalid key: {k}"
|
assert k in scopes, f"Abnormal 'parser_config'. Invalid key: {k}"
|
||||||
|
|
||||||
@ -480,5 +472,3 @@ def check_duplicate_ids(ids, id_type="item"):
|
|||||||
|
|
||||||
# Return unique IDs and error messages
|
# Return unique IDs and error messages
|
||||||
return list(set(ids)), duplicate_messages
|
return list(set(ids)), duplicate_messages
|
||||||
|
|
||||||
|
|
||||||
|
@@ -59,6 +59,7 @@ class Base(ABC):
 # Configure retry parameters
 self.max_retries = int(os.environ.get("LLM_MAX_RETRIES", 5))
 self.base_delay = float(os.environ.get("LLM_BASE_DELAY", 2.0))
+self.is_tools = False

 def _get_delay(self, attempt):
 """Calculate retry delay time"""
@@ -89,6 +90,91 @@ class Base(ABC):
 else:
 return ERROR_GENERIC

+def bind_tools(self, toolcall_session, tools):
+    if not (toolcall_session and tools):
+        return
+    self.is_tools = True
+    self.toolcall_session = toolcall_session
+    self.tools = tools
+
+def chat_with_tools(self, system: str, history: list, gen_conf: dict):
+    if "max_tokens" in gen_conf:
+        del gen_conf["max_tokens"]
+
+    tools = self.tools
+
+    if system:
+        history.insert(0, {"role": "system", "content": system})
+
+    ans = ""
+    tk_count = 0
+    # Implement exponential backoff retry strategy
+    for attempt in range(self.max_retries):
+        try:
+            response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf)
+
+            assistant_output = response.choices[0].message
+            if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
+                ans += "<think>" + ans + "</think>"
+            ans += response.choices[0].message.content
+
+            if not response.choices[0].message.tool_calls:
+                tk_count += self.total_token_count(response)
+                if response.choices[0].finish_reason == "length":
+                    if is_chinese([ans]):
+                        ans += LENGTH_NOTIFICATION_CN
+                    else:
+                        ans += LENGTH_NOTIFICATION_EN
+                return ans, tk_count
+
+            tk_count += self.total_token_count(response)
+            history.append(assistant_output)
+
+            for tool_call in response.choices[0].message.tool_calls:
+                name = tool_call.function.name
+                args = json.loads(tool_call.function.arguments)
+
+                tool_response = self.toolcall_session.tool_call(name, args)
+                # if tool_response.choices[0].finish_reason == "length":
+                #     if is_chinese(ans):
+                #         ans += LENGTH_NOTIFICATION_CN
+                #     else:
+                #         ans += LENGTH_NOTIFICATION_EN
+                #     return ans, tk_count + self.total_token_count(tool_response)
+                history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+
+            final_response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf)
+            assistant_output = final_response.choices[0].message
+            if "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
+                ans += "<think>" + ans + "</think>"
+            ans += final_response.choices[0].message.content
+            if final_response.choices[0].finish_reason == "length":
+                tk_count += self.total_token_count(response)
+                if is_chinese([ans]):
+                    ans += LENGTH_NOTIFICATION_CN
+                else:
+                    ans += LENGTH_NOTIFICATION_EN
+                return ans, tk_count
+            return ans, tk_count
+
+        except Exception as e:
+            logging.exception("OpenAI cat_with_tools")
+            # Classify the error
+            error_code = self._classify_error(e)
+
+            # Check if it's a rate limit error or server error and not the last attempt
+            should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1
+
+            if should_retry:
+                delay = self._get_delay(attempt)
+                logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+                time.sleep(delay)
+            else:
+                # For non-rate limit errors or the last attempt, return an error message
+                if attempt == self.max_retries - 1:
+                    error_code = ERROR_MAX_RETRIES
+                return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0
+
 def chat(self, system, history, gen_conf):
 if system:
 history.insert(0, {"role": "system", "content": system})
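The non-streaming loop above follows the standard OpenAI tool-calling handshake: the assistant message that carries `tool_calls` is appended to the history, one `role: "tool"` message is appended per call with the session's result, and the model is queried again for the final answer. An illustrative (made-up) example of the history shape it builds after one round-trip:

```python
# Made-up example of the conversation history after one tool round-trip;
# the IDs, arguments, and tool output are invented for illustration.
history = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What's the weather in Shanghai?"},
    # the assistant message returned with tool_calls is appended verbatim,
    # e.g. {"role": "assistant", "tool_calls": [{"id": "call_0", ...}]},
    {"role": "tool", "tool_call_id": "call_0", "content": "Sunny in Shanghai"},
    # a second completions.create() call then produces the final answer
]
```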
@@ -127,6 +213,127 @@ class Base(ABC):
 error_code = ERROR_MAX_RETRIES
 return f"{ERROR_PREFIX}: {error_code} - {str(e)}. response: {response}", 0

+def _wrap_toolcall_message(self, stream):
+    final_tool_calls = {}
+
+    for chunk in stream:
+        for tool_call in chunk.choices[0].delta.tool_calls or []:
+            index = tool_call.index
+
+            if index not in final_tool_calls:
+                final_tool_calls[index] = tool_call
+
+            final_tool_calls[index].function.arguments += tool_call.function.arguments
+
+    return final_tool_calls
+
+def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
+    if "max_tokens" in gen_conf:
+        del gen_conf["max_tokens"]
+
+    tools = self.tools
+
+    if system:
+        history.insert(0, {"role": "system", "content": system})
+
+    ans = ""
+    total_tokens = 0
+    reasoning_start = False
+    finish_completion = False
+    final_tool_calls = {}
+    try:
+        response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
+        while not finish_completion:
+            for resp in response:
+                if resp.choices[0].delta.tool_calls:
+                    for tool_call in resp.choices[0].delta.tool_calls or []:
+                        index = tool_call.index
+
+                        if index not in final_tool_calls:
+                            final_tool_calls[index] = tool_call
+
+                        final_tool_calls[index].function.arguments += tool_call.function.arguments
+                    if resp.choices[0].finish_reason != "stop":
+                        continue
+                else:
+                    if not resp.choices:
+                        continue
+                    if not resp.choices[0].delta.content:
+                        resp.choices[0].delta.content = ""
+                    if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
+                        ans = ""
+                        if not reasoning_start:
+                            reasoning_start = True
+                            ans = "<think>"
+                        ans += resp.choices[0].delta.reasoning_content + "</think>"
+                    else:
+                        reasoning_start = False
+                        ans = resp.choices[0].delta.content
+
+                    tol = self.total_token_count(resp)
+                    if not tol:
+                        total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
+                    else:
+                        total_tokens += tol
+
+                finish_reason = resp.choices[0].finish_reason
+                if finish_reason == "tool_calls" and final_tool_calls:
+                    for tool_call in final_tool_calls.values():
+                        name = tool_call.function.name
+                        try:
+                            if name == "get_current_weather":
+                                args = json.loads('{"location":"Shanghai"}')
+                            else:
+                                args = json.loads(tool_call.function.arguments)
+                        except Exception:
+                            continue
+                        # args = json.loads(tool_call.function.arguments)
+                        tool_response = self.toolcall_session.tool_call(name, args)
+                        history.append(
+                            {
+                                "role": "assistant",
+                                "refusal": "",
+                                "content": "",
+                                "audio": "",
+                                "function_call": "",
+                                "tool_calls": [
+                                    {
+                                        "index": tool_call.index,
+                                        "id": tool_call.id,
+                                        "function": tool_call.function,
+                                        "type": "function",
+                                    },
+                                ],
+                            }
+                        )
+                        # if tool_response.choices[0].finish_reason == "length":
+                        #     if is_chinese(ans):
+                        #         ans += LENGTH_NOTIFICATION_CN
+                        #     else:
+                        #         ans += LENGTH_NOTIFICATION_EN
+                        #     return ans, total_tokens + self.total_token_count(tool_response)
+                        history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+                    final_tool_calls = {}
+                    response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
+                    continue
+                if finish_reason == "length":
+                    if is_chinese(ans):
+                        ans += LENGTH_NOTIFICATION_CN
+                    else:
+                        ans += LENGTH_NOTIFICATION_EN
+                    return ans, total_tokens + self.total_token_count(resp)
+                if finish_reason == "stop":
+                    finish_completion = True
+                    yield ans
+                    break
+                yield ans
+                continue
+
+    except openai.APIError as e:
+        yield ans + "\n**ERROR**: " + str(e)
+
+    yield total_tokens
+
 def chat_streamly(self, system, history, gen_conf):
 if system:
 history.insert(0, {"role": "system", "content": system})
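The streaming variant has to reassemble tool calls from chunked deltas before dispatching them: fragments that share an `index` are merged and their `arguments` strings concatenated until the stream reports `finish_reason == "tool_calls"`. A tiny standalone illustration of that accumulation idea:

```python
# Illustration only: concatenating streamed argument fragments for one tool call.
fragments = ['{"loca', 'tion": "Sh', 'anghai"}']
arguments = ""
for piece in fragments:
    arguments += piece
print(arguments)  # {"location": "Shanghai"}
```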
@@ -156,7 +363,7 @@ class Base(ABC):
 if not tol:
     total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
 else:
-    total_tokens = tol
+    total_tokens += tol

 if resp.choices[0].finish_reason == "length":
 if is_chinese(ans):
@ -183,6 +390,7 @@ class Base(ABC):
|
|||||||
|
|
||||||
def _calculate_dynamic_ctx(self, history):
|
def _calculate_dynamic_ctx(self, history):
|
||||||
"""Calculate dynamic context window size"""
|
"""Calculate dynamic context window size"""
|
||||||
|
|
||||||
def count_tokens(text):
|
def count_tokens(text):
|
||||||
"""Calculate token count for text"""
|
"""Calculate token count for text"""
|
||||||
# Simple calculation: 1 token per ASCII character
|
# Simple calculation: 1 token per ASCII character
|
||||||
@ -216,6 +424,7 @@ class Base(ABC):
|
|||||||
|
|
||||||
return ctx_size
|
return ctx_size
|
||||||
|
|
||||||
|
|
||||||
class GptTurbo(Base):
|
class GptTurbo(Base):
|
||||||
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
|
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
|
||||||
if not base_url:
|
if not base_url:
|
||||||
@ -350,6 +559,8 @@ class BaiChuanChat(Base):
|
|||||||
|
|
||||||
class QWenChat(Base):
|
class QWenChat(Base):
|
||||||
def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
|
def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
|
||||||
|
super().__init__(key, model_name, base_url=None)
|
||||||
|
|
||||||
import dashscope
|
import dashscope
|
||||||
|
|
||||||
dashscope.api_key = key
|
dashscope.api_key = key
|
||||||
@ -357,6 +568,78 @@ class QWenChat(Base):
|
|||||||
if self.is_reasoning_model(self.model_name):
|
if self.is_reasoning_model(self.model_name):
|
||||||
super().__init__(key, model_name, "https://dashscope.aliyuncs.com/compatible-mode/v1")
|
super().__init__(key, model_name, "https://dashscope.aliyuncs.com/compatible-mode/v1")
|
||||||
|
|
||||||
|
def chat_with_tools(self, system: str, history: list, gen_conf: dict) -> tuple[str, int]:
|
||||||
|
if "max_tokens" in gen_conf:
|
||||||
|
del gen_conf["max_tokens"]
|
||||||
|
# if self.is_reasoning_model(self.model_name):
|
||||||
|
# return super().chat(system, history, gen_conf)
|
||||||
|
|
||||||
|
stream_flag = str(os.environ.get("QWEN_CHAT_BY_STREAM", "true")).lower() == "true"
|
||||||
|
if not stream_flag:
|
||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
|
tools = self.tools
|
||||||
|
|
||||||
|
if system:
|
||||||
|
history.insert(0, {"role": "system", "content": system})
|
||||||
|
|
||||||
|
response = Generation.call(self.model_name, messages=history, result_format="message", tools=tools, **gen_conf)
|
||||||
|
ans = ""
|
||||||
|
tk_count = 0
|
||||||
|
if response.status_code == HTTPStatus.OK:
|
||||||
|
assistant_output = response.output.choices[0].message
|
||||||
|
if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
|
||||||
|
ans += "<think>" + ans + "</think>"
|
||||||
|
ans += response.output.choices[0].message.content
|
||||||
|
|
||||||
|
if "tool_calls" not in assistant_output:
|
||||||
|
tk_count += self.total_token_count(response)
|
||||||
|
if response.output.choices[0].get("finish_reason", "") == "length":
|
||||||
|
if is_chinese([ans]):
|
||||||
|
ans += LENGTH_NOTIFICATION_CN
|
||||||
|
else:
|
||||||
|
ans += LENGTH_NOTIFICATION_EN
|
||||||
|
return ans, tk_count
|
||||||
|
|
||||||
|
tk_count += self.total_token_count(response)
|
||||||
|
history.append(assistant_output)
|
||||||
|
|
||||||
|
while "tool_calls" in assistant_output:
|
||||||
|
tool_info = {"content": "", "role": "tool", "tool_call_id": assistant_output.tool_calls[0]["id"]}
|
||||||
|
tool_name = assistant_output.tool_calls[0]["function"]["name"]
|
||||||
|
if tool_name:
|
||||||
|
arguments = json.loads(assistant_output.tool_calls[0]["function"]["arguments"])
|
||||||
|
tool_info["content"] = self.toolcall_session.tool_call(name=tool_name, arguments=arguments)
|
||||||
|
history.append(tool_info)
|
||||||
|
|
||||||
|
response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, **gen_conf)
|
||||||
|
if response.output.choices[0].get("finish_reason", "") == "length":
|
||||||
|
tk_count += self.total_token_count(response)
|
||||||
|
if is_chinese([ans]):
|
||||||
|
ans += LENGTH_NOTIFICATION_CN
|
||||||
|
else:
|
||||||
|
ans += LENGTH_NOTIFICATION_EN
|
||||||
|
return ans, tk_count
|
||||||
|
|
||||||
|
tk_count += self.total_token_count(response)
|
||||||
|
assistant_output = response.output.choices[0].message
|
||||||
|
if assistant_output.content is None:
|
||||||
|
assistant_output.content = ""
|
||||||
|
history.append(response)
|
||||||
|
ans += assistant_output["content"]
|
||||||
|
return ans, tk_count
|
||||||
|
else:
|
||||||
|
return "**ERROR**: " + response.message, tk_count
|
||||||
|
else:
|
||||||
|
result_list = []
|
||||||
|
for result in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=True):
|
||||||
|
result_list.append(result)
|
||||||
|
error_msg_list = [result for result in result_list if str(result).find("**ERROR**") >= 0]
|
||||||
|
if len(error_msg_list) > 0:
|
||||||
|
return "**ERROR**: " + "".join(error_msg_list), 0
|
||||||
|
else:
|
||||||
|
return "".join(result_list[:-1]), result_list[-1]
|
||||||
|
|
||||||
def chat(self, system, history, gen_conf):
|
def chat(self, system, history, gen_conf):
|
||||||
if "max_tokens" in gen_conf:
|
if "max_tokens" in gen_conf:
|
||||||
del gen_conf["max_tokens"]
|
del gen_conf["max_tokens"]
|
||||||
@ -393,6 +676,99 @@ class QWenChat(Base):
|
|||||||
else:
|
else:
|
||||||
return "".join(result_list[:-1]), result_list[-1]
|
return "".join(result_list[:-1]), result_list[-1]
|
||||||
|
|
||||||
|
def _wrap_toolcall_message(self, old_message, message):
|
||||||
|
if not old_message:
|
||||||
|
return message
|
||||||
|
tool_call_id = message["tool_calls"][0].get("id")
|
||||||
|
if tool_call_id:
|
||||||
|
old_message.tool_calls[0]["id"] = tool_call_id
|
||||||
|
function = message.tool_calls[0]["function"]
|
||||||
|
if function:
|
||||||
|
if function.get("name"):
|
||||||
|
old_message.tool_calls[0]["function"]["name"] = function["name"]
|
||||||
|
if function.get("arguments"):
|
||||||
|
old_message.tool_calls[0]["function"]["arguments"] += function["arguments"]
|
||||||
|
return old_message
|
||||||
|
|
||||||
|
def _chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
|
||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
|
if system:
|
||||||
|
history.insert(0, {"role": "system", "content": system})
|
||||||
|
if "max_tokens" in gen_conf:
|
||||||
|
del gen_conf["max_tokens"]
|
||||||
|
ans = ""
|
||||||
|
tk_count = 0
|
||||||
|
try:
|
||||||
|
response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
|
||||||
|
tool_info = {"content": "", "role": "tool"}
|
||||||
|
toolcall_message = None
|
||||||
|
tool_name = ""
|
||||||
|
tool_arguments = ""
|
||||||
|
finish_completion = False
|
||||||
|
reasoning_start = False
|
||||||
|
while not finish_completion:
|
||||||
|
for resp in response:
|
||||||
|
if resp.status_code == HTTPStatus.OK:
|
||||||
|
assistant_output = resp.output.choices[0].message
|
||||||
|
ans = resp.output.choices[0].message.content
|
||||||
|
if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
|
||||||
|
ans = resp.output.choices[0].message.reasoning_content
|
||||||
|
if not reasoning_start:
|
||||||
|
reasoning_start = True
|
||||||
|
ans = "<think>" + ans
|
||||||
|
else:
|
||||||
|
ans = ans + "</think>"
|
||||||
|
|
||||||
|
if "tool_calls" not in assistant_output:
|
||||||
|
reasoning_start = False
|
||||||
|
tk_count += self.total_token_count(resp)
|
||||||
|
if resp.output.choices[0].get("finish_reason", "") == "length":
|
||||||
|
if is_chinese([ans]):
|
||||||
|
ans += LENGTH_NOTIFICATION_CN
|
||||||
|
else:
|
||||||
|
ans += LENGTH_NOTIFICATION_EN
|
||||||
|
finish_reason = resp.output.choices[0]["finish_reason"]
|
||||||
|
if finish_reason == "stop":
|
||||||
|
finish_completion = True
|
||||||
|
yield ans
|
||||||
|
break
|
||||||
|
yield ans
|
||||||
|
continue
|
||||||
|
|
||||||
|
tk_count += self.total_token_count(resp)
|
||||||
|
toolcall_message = self._wrap_toolcall_message(toolcall_message, assistant_output)
|
||||||
|
if "tool_calls" in assistant_output:
|
||||||
|
tool_call_finish_reason = resp.output.choices[0]["finish_reason"]
|
||||||
|
if tool_call_finish_reason == "tool_calls":
|
||||||
|
try:
|
||||||
|
tool_arguments = json.loads(toolcall_message.tool_calls[0]["function"]["arguments"])
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(msg="_chat_streamly_with_tool tool call error")
|
||||||
|
yield ans + "\n**ERROR**: " + str(e)
|
||||||
|
finish_completion = True
|
||||||
|
break
|
||||||
|
|
||||||
|
tool_name = toolcall_message.tool_calls[0]["function"]["name"]
|
||||||
|
history.append(toolcall_message)
|
||||||
|
tool_info["content"] = self.toolcall_session.tool_call(name=tool_name, arguments=tool_arguments)
|
||||||
|
history.append(tool_info)
|
||||||
|
tool_info = {"content": "", "role": "tool"}
|
||||||
|
tool_name = ""
|
||||||
|
tool_arguments = ""
|
||||||
|
toolcall_message = None
|
||||||
|
response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
|
||||||
|
else:
|
||||||
|
yield (
|
||||||
|
ans + "\n**ERROR**: " + resp.output.choices[0].message
|
||||||
|
if not re.search(r" (key|quota)", str(resp.message).lower())
|
||||||
|
else "Out of credit. Please set the API key in **settings > Model providers.**"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(msg="_chat_streamly_with_tool")
|
||||||
|
yield ans + "\n**ERROR**: " + str(e)
|
||||||
|
yield tk_count
|
||||||
|
|
||||||
def _chat_streamly(self, system, history, gen_conf, incremental_output=True):
|
def _chat_streamly(self, system, history, gen_conf, incremental_output=True):
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
|
||||||
@ -425,6 +801,13 @@ class QWenChat(Base):
|
|||||||
|
|
||||||
yield tk_count
|
yield tk_count
|
||||||
|
|
||||||
|
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
|
||||||
|
if "max_tokens" in gen_conf:
|
||||||
|
del gen_conf["max_tokens"]
|
||||||
|
|
||||||
|
for txt in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=incremental_output):
|
||||||
|
yield txt
|
||||||
|
|
||||||
def chat_streamly(self, system, history, gen_conf):
|
def chat_streamly(self, system, history, gen_conf):
|
||||||
if "max_tokens" in gen_conf:
|
if "max_tokens" in gen_conf:
|
||||||
del gen_conf["max_tokens"]
|
del gen_conf["max_tokens"]
|
||||||
@ -445,6 +828,8 @@ class QWenChat(Base):
|
|||||||
|
|
||||||
class ZhipuChat(Base):
|
class ZhipuChat(Base):
|
||||||
def __init__(self, key, model_name="glm-3-turbo", **kwargs):
|
def __init__(self, key, model_name="glm-3-turbo", **kwargs):
|
||||||
|
super().__init__(key, model_name, base_url=None)
|
||||||
|
|
||||||
self.client = ZhipuAI(api_key=key)
|
self.client = ZhipuAI(api_key=key)
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
@ -504,6 +889,8 @@ class ZhipuChat(Base):
|
|||||||
|
|
||||||
class OllamaChat(Base):
|
class OllamaChat(Base):
|
||||||
def __init__(self, key, model_name, **kwargs):
|
def __init__(self, key, model_name, **kwargs):
|
||||||
|
super().__init__(key, model_name, base_url=None)
|
||||||
|
|
||||||
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"})
|
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"})
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
@ -516,9 +903,7 @@ class OllamaChat(Base):
|
|||||||
# Calculate context size
|
# Calculate context size
|
||||||
ctx_size = self._calculate_dynamic_ctx(history)
|
ctx_size = self._calculate_dynamic_ctx(history)
|
||||||
|
|
||||||
options = {
|
options = {"num_ctx": ctx_size}
|
||||||
"num_ctx": ctx_size
|
|
||||||
}
|
|
||||||
if "temperature" in gen_conf:
|
if "temperature" in gen_conf:
|
||||||
options["temperature"] = gen_conf["temperature"]
|
options["temperature"] = gen_conf["temperature"]
|
||||||
if "max_tokens" in gen_conf:
|
if "max_tokens" in gen_conf:
|
||||||
@ -545,9 +930,7 @@ class OllamaChat(Base):
|
|||||||
try:
|
try:
|
||||||
# Calculate context size
|
# Calculate context size
|
||||||
ctx_size = self._calculate_dynamic_ctx(history)
|
ctx_size = self._calculate_dynamic_ctx(history)
|
||||||
options = {
|
options = {"num_ctx": ctx_size}
|
||||||
"num_ctx": ctx_size
|
|
||||||
}
|
|
||||||
if "temperature" in gen_conf:
|
if "temperature" in gen_conf:
|
||||||
options["temperature"] = gen_conf["temperature"]
|
options["temperature"] = gen_conf["temperature"]
|
||||||
if "max_tokens" in gen_conf:
|
if "max_tokens" in gen_conf:
|
||||||
@ -561,7 +944,7 @@ class OllamaChat(Base):
|
|||||||
|
|
||||||
ans = ""
|
ans = ""
|
||||||
try:
|
try:
|
||||||
response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=10 )
|
response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=10)
|
||||||
for resp in response:
|
for resp in response:
|
||||||
if resp["done"]:
|
if resp["done"]:
|
||||||
token_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
|
token_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
|
||||||
@ -578,6 +961,8 @@ class OllamaChat(Base):
|
|||||||
|
|
||||||
class LocalAIChat(Base):
|
class LocalAIChat(Base):
|
||||||
def __init__(self, key, model_name, base_url):
|
def __init__(self, key, model_name, base_url):
|
||||||
|
super().__init__(key, model_name, base_url=None)
|
||||||
|
|
||||||
if not base_url:
|
if not base_url:
|
||||||
raise ValueError("Local llm url cannot be None")
|
raise ValueError("Local llm url cannot be None")
|
||||||
if base_url.split("/")[-1] != "v1":
|
if base_url.split("/")[-1] != "v1":
|
||||||
@ -613,6 +998,8 @@ class LocalLLM(Base):
|
|||||||
return do_rpc
|
return do_rpc
|
||||||
|
|
||||||
def __init__(self, key, model_name):
|
def __init__(self, key, model_name):
|
||||||
|
super().__init__(key, model_name, base_url=None)
|
||||||
|
|
||||||
from jina import Client
|
from jina import Client
|
||||||
|
|
||||||
self.client = Client(port=12345, protocol="grpc", asyncio=True)
|
self.client = Client(port=12345, protocol="grpc", asyncio=True)
|
||||||
@ -659,6 +1046,8 @@ class LocalLLM(Base):
|
|||||||
|
|
||||||
class VolcEngineChat(Base):
|
class VolcEngineChat(Base):
|
||||||
def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
|
def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
|
||||||
|
super().__init__(key, model_name, base_url=None)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
|
Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
|
||||||
Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use
|
Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use
|
||||||
@ -677,6 +1066,8 @@ class MiniMaxChat(Base):
|
|||||||
model_name,
|
model_name,
|
||||||
base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
|
base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
|
||||||
):
|
):
|
||||||
|
super().__init__(key, model_name, base_url=None)
|
||||||
|
|
||||||
if not base_url:
|
if not base_url:
|
||||||
base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
|
base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
@ -755,6 +1146,8 @@ class MiniMaxChat(Base):
|
|||||||
|
|
||||||
class MistralChat(Base):
|
class MistralChat(Base):
|
||||||
def __init__(self, key, model_name, base_url=None):
|
def __init__(self, key, model_name, base_url=None):
|
||||||
|
super().__init__(key, model_name, base_url=None)
|
||||||
|
|
||||||
from mistralai.client import MistralClient
|
from mistralai.client import MistralClient
|
||||||
|
|
||||||
self.client = MistralClient(api_key=key)
|
self.client = MistralClient(api_key=key)
|
||||||
@ -808,6 +1201,8 @@ class MistralChat(Base):
|
|||||||
|
|
||||||
class BedrockChat(Base):
|
class BedrockChat(Base):
|
||||||
def __init__(self, key, model_name, **kwargs):
|
def __init__(self, key, model_name, **kwargs):
|
||||||
|
super().__init__(key, model_name, base_url=None)
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
|
|
||||||
self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
|
self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
|
||||||
@ -887,6 +1282,8 @@ class BedrockChat(Base):
|
|||||||
|
|
||||||
class GeminiChat(Base):
|
class GeminiChat(Base):
|
||||||
def __init__(self, key, model_name, base_url=None):
|
def __init__(self, key, model_name, base_url=None):
|
||||||
|
super().__init__(key, model_name, base_url=None)
|
||||||
|
|
||||||
from google.generativeai import GenerativeModel, client
|
from google.generativeai import GenerativeModel, client
|
||||||
|
|
||||||
client.configure(api_key=key)
|
client.configure(api_key=key)
|
||||||
@ -947,6 +1344,8 @@ class GeminiChat(Base):
|
|||||||
|
|
||||||
class GroqChat(Base):
|
class GroqChat(Base):
|
||||||
def __init__(self, key, model_name, base_url=""):
|
def __init__(self, key, model_name, base_url=""):
|
||||||
|
super().__init__(key, model_name, base_url=None)
|
||||||
|
|
||||||
from groq import Groq
|
from groq import Groq
|
||||||
|
|
||||||
self.client = Groq(api_key=key)
|
self.client = Groq(api_key=key)
|
||||||
@ -1049,6 +1448,8 @@ class PPIOChat(Base):
|
|||||||
|
|
||||||
class CoHereChat(Base):
|
class CoHereChat(Base):
|
||||||
def __init__(self, key, model_name, base_url=""):
|
def __init__(self, key, model_name, base_url=""):
|
||||||
|
super().__init__(key, model_name, base_url=None)
|
||||||
|
|
||||||
from cohere import Client
|
from cohere import Client
|
||||||
|
|
||||||
self.client = Client(api_key=key)
|
self.client = Client(api_key=key)
|
||||||
@ -1171,6 +1572,8 @@ class YiChat(Base):
|
|||||||
|
|
||||||
class ReplicateChat(Base):
|
class ReplicateChat(Base):
|
||||||
def __init__(self, key, model_name, base_url=None):
|
def __init__(self, key, model_name, base_url=None):
|
||||||
|
super().__init__(key, model_name, base_url=None)
|
||||||
|
|
||||||
from replicate.client import Client
|
from replicate.client import Client
|
||||||
|
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
@ -1218,6 +1621,8 @@ class ReplicateChat(Base):
|
|||||||
|
|
||||||
class HunyuanChat(Base):
|
class HunyuanChat(Base):
|
||||||
def __init__(self, key, model_name, base_url=None):
|
def __init__(self, key, model_name, base_url=None):
|
||||||
|
super().__init__(key, model_name, base_url=None)
|
||||||
|
|
||||||
from tencentcloud.common import credential
|
from tencentcloud.common import credential
|
||||||
from tencentcloud.hunyuan.v20230901 import hunyuan_client
|
from tencentcloud.hunyuan.v20230901 import hunyuan_client
|
||||||
|
|
||||||
@ -1321,6 +1726,8 @@ class SparkChat(Base):
|
|||||||
|
|
||||||
class BaiduYiyanChat(Base):
|
class BaiduYiyanChat(Base):
|
||||||
def __init__(self, key, model_name, base_url=None):
|
def __init__(self, key, model_name, base_url=None):
|
||||||
|
super().__init__(key, model_name, base_url=None)
|
||||||
|
|
||||||
import qianfan
|
import qianfan
|
||||||
|
|
||||||
key = json.loads(key)
|
key = json.loads(key)
|
||||||
@ -1372,6 +1779,8 @@ class BaiduYiyanChat(Base):
|
|||||||
|
|
||||||
class AnthropicChat(Base):
|
class AnthropicChat(Base):
|
||||||
def __init__(self, key, model_name, base_url=None):
|
def __init__(self, key, model_name, base_url=None):
|
||||||
|
super().__init__(key, model_name, base_url=None)
|
||||||
|
|
||||||
import anthropic
|
import anthropic
|
||||||
|
|
||||||
self.client = anthropic.Anthropic(api_key=key)
|
self.client = anthropic.Anthropic(api_key=key)
|
||||||
@ -1452,6 +1861,8 @@ class AnthropicChat(Base):
|
|||||||
|
|
||||||
class GoogleChat(Base):
|
class GoogleChat(Base):
|
||||||
def __init__(self, key, model_name, base_url=None):
|
def __init__(self, key, model_name, base_url=None):
|
||||||
|
super().__init__(key, model_name, base_url=None)
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
from google.oauth2 import service_account
|
from google.oauth2 import service_account
|
||||||
|