diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py
index 44d0d8ba..16b1c8a1 100644
--- a/api/apps/sdk/session.py
+++ b/api/apps/sdk/session.py
@@ -116,7 +116,7 @@ def create_agent_session(tenant_id, agent_id):
         for ans in canvas.run(stream=False):
             pass
-    
+
     cvs.dsl = json.loads(str(canvas))
     conv = {"id": get_uuid(), "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
     API4ConversationService.save(**conv)
@@ -243,6 +243,11 @@ def chat_completion_openai_like(tenant_id, chat_id):
     msg = None
     msg = [m for m in messages if m["role"] != "system" and (m["role"] != "assistant" or msg)]
 
+    # tools = get_tools()
+    # toolcall_session = SimpleFunctionCallServer()
+    tools = None
+    toolcall_session = None
+
     if req.get("stream", True):
         # The value for the usage field on all chunks except for the last one will be null.
         # The usage field on the last chunk contains token usage statistics for the entire request.
@@ -262,7 +267,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
         }
 
         try:
-            for ans in chat(dia, msg, True):
+            for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools):
                 answer = ans["answer"]
 
                 reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
@@ -325,7 +330,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
         return resp
     else:
         answer = None
-        for ans in chat(dia, msg, False):
+        for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools):
             # focus answer content only
             answer = ans
             break
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index f3c05813..e70b925d 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -145,6 +145,9 @@ def chat(dialog, messages, stream=True, **kwargs):
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
+    toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
+    if toolcall_session and tools:
+        chat_mdl.bind_tools(toolcall_session, tools)
 
     bind_llm_ts = timer()
 
@@ -338,7 +341,7 @@ def chat(dialog, messages, stream=True, **kwargs):
         langfuse_output = {"time_elapsed:": re.sub(r"\n", " \n", langfuse_output), "created_at": time.time()}
 
         # Add a condition check to call the end method only if langfuse_tracer exists
-        if langfuse_tracer and 'langfuse_generation' in locals():
+        if langfuse_tracer and "langfuse_generation" in locals():
             langfuse_generation.end(output=langfuse_output)
 
     return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
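Note on the new `chat(..., toolcall_session=..., tools=...)` kwargs: the route above still passes `None` placeholders, and the commented-out `get_tools()` / `SimpleFunctionCallServer()` only hint at the intended shape. A minimal sketch of what could be plugged in, assuming only the `tool_call(name, args)` contract that `rag/llm/chat_model.py` relies on (the class and tool names here are hypothetical):

```python
# Hypothetical wiring for the placeholders above; only the tool_call(name, args)
# method is actually required by the chat models' *_with_tools code paths.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",  # illustrative tool
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]


class LocalToolSession:  # hypothetical stand-in for SimpleFunctionCallServer
    def tool_call(self, name: str, arguments: dict) -> str:
        # Dispatch the model-issued call to a local implementation.
        if name == "get_current_weather":
            return f"Sunny in {arguments['location']}"
        raise ValueError(f"unknown tool: {name}")


# for ans in chat(dia, msg, True, toolcall_session=LocalToolSession(), tools=tools): ...
```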
diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index d86e0bf0..ded4f7f3 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -102,6 +102,9 @@ class TenantLLMService(CommonService):
             mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
         if model_config:
             model_config = model_config.to_dict()
+            llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
+            if llm:
+                model_config["is_tools"] = llm[0].is_tools
         if not model_config:
             if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
                 llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
@@ -206,6 +209,8 @@ class LLMBundle:
         model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
         self.max_length = model_config.get("max_tokens", 8192)
+        self.is_tools = model_config.get("is_tools", False)
+
         langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
         if langfuse_keys:
             langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
@@ -215,6 +220,11 @@ class LLMBundle:
         else:
             self.langfuse = None
 
+    def bind_tools(self, toolcall_session, tools):
+        if not self.is_tools:
+            return
+        self.mdl.bind_tools(toolcall_session, tools)
+
     def encode(self, texts: list):
         if self.langfuse:
             generation = self.trace.generation(name="encode", model=self.llm_name, input={"texts": texts})
@@ -307,11 +317,31 @@ class LLMBundle:
         if self.langfuse:
             span.end()
 
+    def _remove_reasoning_content(self, txt: str) -> str:
+        first_think_start = txt.find("<think>")
+        if first_think_start == -1:
+            return txt
+
+        last_think_end = txt.rfind("</think>")
+        if last_think_end == -1:
+            return txt
+
+        if last_think_end < first_think_start:
+            return txt
+
+        return txt[last_think_end + len("</think>") :]
+
     def chat(self, system, history, gen_conf):
         if self.langfuse:
             generation = self.trace.generation(name="chat", model=self.llm_name, input={"system": system, "history": history})
 
-        txt, used_tokens = self.mdl.chat(system, history, gen_conf)
+        chat = self.mdl.chat
+        if self.is_tools and self.mdl.is_tools:
+            chat = self.mdl.chat_with_tools
+
+        txt, used_tokens = chat(system, history, gen_conf)
+        txt = self._remove_reasoning_content(txt)
+
         if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
             logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
@@ -325,7 +355,12 @@ class LLMBundle:
             generation = self.trace.generation(name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
 
         ans = ""
-        for txt in self.mdl.chat_streamly(system, history, gen_conf):
+        chat_streamly = self.mdl.chat_streamly
+
+        if self.is_tools and self.mdl.is_tools:
+            chat_streamly = self.mdl.chat_streamly_with_tools
+
+        for txt in chat_streamly(system, history, gen_conf):
             if isinstance(txt, int):
                 if self.langfuse:
                     generation.end(output={"output": ans})
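With the changes above, `LLMBundle.chat` only switches to the tool-calling path when both the tenant's model config (`is_tools`) and the runtime model class (`self.mdl.is_tools`) agree, and it strips reasoning markup from the final text. Illustrative behavior of `_remove_reasoning_content`, derived directly from the implementation:

```python
# _remove_reasoning_content keeps only the text after the last closing tag
# (self is unused, so the unbound method can be exercised directly):
strip = LLMBundle._remove_reasoning_content
assert strip(None, "<think>plan</think>final answer") == "final answer"
assert strip(None, "no reasoning tags") == "no reasoning tags"        # no <think> at all
assert strip(None, "</think>x<think>") == "</think>x<think>"          # close precedes open: unchanged
```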
diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py
index 593ec625..e39b9a1e 100644
--- a/api/utils/api_utils.py
+++ b/api/utils/api_utils.py
@@ -13,9 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import logging
 import functools
 import json
+import logging
 import random
 import time
 from base64 import b64encode
@@ -27,59 +27,60 @@ from uuid import uuid1
 
 import requests
 from flask import (
-    Response, jsonify, send_file, make_response,
+    Response,
+    jsonify,
+    make_response,
+    send_file,
+)
+from flask import (
     request as flask_request,
 )
 from itsdangerous import URLSafeTimedSerializer
 from werkzeug.http import HTTP_STATUS_CODES
 
-from api.db.db_models import APIToken
 from api import settings
+from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
+from api.db.db_models import APIToken
+from api.utils import CustomJSONEncoder, get_uuid, json_dumps
 
-from api.utils import CustomJSONEncoder, get_uuid
-from api.utils import json_dumps
-from api.constants import REQUEST_WAIT_SEC, REQUEST_MAX_WAIT_SEC
-
-requests.models.complexjson.dumps = functools.partial(
-    json.dumps, cls=CustomJSONEncoder)
+requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
 
 
 def request(**kwargs):
     sess = requests.Session()
-    stream = kwargs.pop('stream', sess.stream)
-    timeout = kwargs.pop('timeout', None)
-    kwargs['headers'] = {
-        k.replace(
-            '_',
-            '-').upper(): v for k,
-        v in kwargs.get(
-            'headers',
-            {}).items()}
+    stream = kwargs.pop("stream", sess.stream)
+    timeout = kwargs.pop("timeout", None)
+    kwargs["headers"] = {k.replace("_", "-").upper(): v for k, v in kwargs.get("headers", {}).items()}
     prepped = requests.Request(**kwargs).prepare()
 
     if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY:
         timestamp = str(round(time() * 1000))
         nonce = str(uuid1())
-        signature = b64encode(HMAC(settings.SECRET_KEY.encode('ascii'), b'\n'.join([
-            timestamp.encode('ascii'),
-            nonce.encode('ascii'),
-            settings.HTTP_APP_KEY.encode('ascii'),
-            prepped.path_url.encode('ascii'),
-            prepped.body if kwargs.get('json') else b'',
-            urlencode(
-                sorted(
-                    kwargs['data'].items()),
-                quote_via=quote,
-                safe='-._~').encode('ascii')
-            if kwargs.get('data') and isinstance(kwargs['data'], dict) else b'',
-        ]), 'sha1').digest()).decode('ascii')
+        signature = b64encode(
+            HMAC(
+                settings.SECRET_KEY.encode("ascii"),
+                b"\n".join(
+                    [
+                        timestamp.encode("ascii"),
+                        nonce.encode("ascii"),
+                        settings.HTTP_APP_KEY.encode("ascii"),
+                        prepped.path_url.encode("ascii"),
+                        prepped.body if kwargs.get("json") else b"",
+                        urlencode(sorted(kwargs["data"].items()), quote_via=quote, safe="-._~").encode("ascii") if kwargs.get("data") and isinstance(kwargs["data"], dict) else b"",
+                    ]
+                ),
+                "sha1",
+            ).digest()
+        ).decode("ascii")
 
-        prepped.headers.update({
-            'TIMESTAMP': timestamp,
-            'NONCE': nonce,
-            'APP-KEY': settings.HTTP_APP_KEY,
-            'SIGNATURE': signature,
-        })
+        prepped.headers.update(
+            {
+                "TIMESTAMP": timestamp,
+                "NONCE": nonce,
+                "APP-KEY": settings.HTTP_APP_KEY,
+                "SIGNATURE": signature,
+            }
+        )
 
     return sess.send(prepped, stream=stream, timeout=timeout)
 
@@ -87,7 +88,7 @@ def request(**kwargs):
 def get_exponential_backoff_interval(retries, full_jitter=False):
     """Calculate the exponential backoff wait time."""
     # Will be zero if factor equals 0
-    countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
+    countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2**retries))
     # Full jitter according to
     # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
     if full_jitter:
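For context, `get_exponential_backoff_interval` computes `min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * 2**retries)`, optionally fully jittered. A sketch of the intended call pattern (the retry budget and `fetch_with_retry` are hypothetical, not part of this PR):

```python
import time

import requests

from api.utils.api_utils import get_exponential_backoff_interval

MAX_RETRIES = 5  # hypothetical budget


def fetch_with_retry(url):
    for attempt in range(MAX_RETRIES):
        try:
            return requests.get(url, timeout=10)
        except requests.RequestException:
            if attempt == MAX_RETRIES - 1:
                raise
            # e.g. REQUEST_WAIT_SEC=5: waits of ~5, 10, 20, ... capped at REQUEST_MAX_WAIT_SEC
            time.sleep(get_exponential_backoff_interval(attempt, full_jitter=True))
```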
@@ -96,12 +97,9 @@
     return max(0, countdown)
 
 
-def get_data_error_result(code=settings.RetCode.DATA_ERROR,
-                          message='Sorry! Data missing!'):
+def get_data_error_result(code=settings.RetCode.DATA_ERROR, message="Sorry! Data missing!"):
     logging.exception(Exception(message))
-    result_dict = {
-        "code": code,
-        "message": message}
+    result_dict = {"code": code, "message": message}
     response = {}
     for key, value in result_dict.items():
         if value is None and key != "code":
@@ -119,23 +117,27 @@ def server_error_response(e):
     except BaseException:
         pass
     if len(e.args) > 1:
-        return get_json_result(
-            code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
+        return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
     if repr(e).find("index_not_found_exception") >= 0:
-        return get_json_result(code=settings.RetCode.EXCEPTION_ERROR,
-                               message="No chunk found, please upload file and parse it.")
+        return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
 
     return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
 
 
 def error_response(response_code, message=None):
     if message is None:
-        message = HTTP_STATUS_CODES.get(response_code, 'Unknown Error')
+        message = HTTP_STATUS_CODES.get(response_code, "Unknown Error")
 
-    return Response(json.dumps({
-        'message': message,
-        'code': response_code,
-    }), status=response_code, mimetype='application/json')
+    return Response(
+        json.dumps(
+            {
+                "message": message,
+                "code": response_code,
+            }
+        ),
+        status=response_code,
+        mimetype="application/json",
+    )
 
 
 def validate_request(*args, **kwargs):
@@ -160,13 +162,10 @@ def validate_request(*args, **kwargs):
         if no_arguments or error_arguments:
             error_string = ""
             if no_arguments:
-                error_string += "required argument are missing: {}; ".format(
-                    ",".join(no_arguments))
+                error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
             if error_arguments:
-                error_string += "required argument values: {}".format(
-                    ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
-            return get_json_result(
-                code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
+                error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
+            return get_json_result(code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
             return func(*_args, **_kwargs)
 
         return decorated_function
@@ -180,8 +179,7 @@ def not_allowed_parameters(*params):
             input_arguments = flask_request.json or flask_request.form.to_dict()
             for param in params:
                 if param in input_arguments:
-                    return get_json_result(
-                        code=settings.RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
+                    return get_json_result(code=settings.RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
             return f(*args, **kwargs)
 
         return wrapper
@@ -190,14 +188,14 @@ def not_allowed_parameters(*params):
 
 
 def is_localhost(ip):
-    return ip in {'127.0.0.1', '::1', '[::1]', 'localhost'}
+    return ip in {"127.0.0.1", "::1", "[::1]", "localhost"}
 
 
 def send_file_in_mem(data, filename):
     if not isinstance(data, (str, bytes)):
         data = json_dumps(data)
     if isinstance(data, str):
-        data = data.encode('utf-8')
+        data = data.encode("utf-8")
 
     f = BytesIO()
     f.write(data)
     f.seek(0)
@@ -206,7 +204,7 @@ def send_file_in_mem(data, filename):
     return send_file(f, as_attachment=True, attachment_filename=filename)
 
 
-def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
+def get_json_result(code=settings.RetCode.SUCCESS, message="success", data=None):
     response = {"code": code, "message": message, "data": data}
     return jsonify(response)
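The decorators reformatted above are meant to be stacked on Flask handlers. A hedged sketch (the `manager` blueprint object mirrors how the app's route modules register handlers; the route itself is hypothetical):

```python
from flask import request as flask_request

from api.utils.api_utils import get_json_result, not_allowed_parameters, validate_request


@manager.route("/example/create", methods=["POST"])  # `manager` is assumed, as in the app's route modules
@validate_request("name")                 # ARGUMENT_ERROR result if "name" is missing from the body
@not_allowed_parameters("tenant_id")      # reject a client-supplied tenant_id
def create_example():
    req = flask_request.json
    return get_json_result(data={"name": req["name"]})
```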
"message": message, "data": data} return jsonify(response) @@ -214,27 +212,24 @@ def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None) def apikey_required(func): @wraps(func) def decorated_function(*args, **kwargs): - token = flask_request.headers.get('Authorization').split()[1] + token = flask_request.headers.get("Authorization").split()[1] objs = APIToken.query(token=token) if not objs: - return build_error_result( - message='API-KEY is invalid!', code=settings.RetCode.FORBIDDEN - ) - kwargs['tenant_id'] = objs[0].tenant_id + return build_error_result(message="API-KEY is invalid!", code=settings.RetCode.FORBIDDEN) + kwargs["tenant_id"] = objs[0].tenant_id return func(*args, **kwargs) return decorated_function -def build_error_result(code=settings.RetCode.FORBIDDEN, message='success'): +def build_error_result(code=settings.RetCode.FORBIDDEN, message="success"): response = {"code": code, "message": message} response = jsonify(response) response.status_code = code return response -def construct_response(code=settings.RetCode.SUCCESS, - message='success', data=None, auth=None): +def construct_response(code=settings.RetCode.SUCCESS, message="success", data=None, auth=None): result_dict = {"code": code, "message": message, "data": data} response_dict = {} for key, value in result_dict.items(): @@ -253,7 +248,7 @@ def construct_response(code=settings.RetCode.SUCCESS, return response -def construct_result(code=settings.RetCode.DATA_ERROR, message='data is missing'): +def construct_result(code=settings.RetCode.DATA_ERROR, message="data is missing"): result_dict = {"code": code, "message": message} response = {} for key, value in result_dict.items(): @@ -264,7 +259,7 @@ def construct_result(code=settings.RetCode.DATA_ERROR, message='data is missing' return jsonify(response) -def construct_json_result(code=settings.RetCode.SUCCESS, message='success', data=None): +def construct_json_result(code=settings.RetCode.SUCCESS, message="success", data=None): if data is None: return jsonify({"code": code, "message": message}) else: @@ -286,7 +281,7 @@ def construct_error_response(e): def token_required(func): @wraps(func) def decorated_function(*args, **kwargs): - authorization_str = flask_request.headers.get('Authorization') + authorization_str = flask_request.headers.get("Authorization") if not authorization_str: return get_json_result(data=False, message="`Authorization` can't be empty") authorization_list = authorization_str.split() @@ -295,11 +290,8 @@ def token_required(func): token = authorization_list[1] objs = APIToken.query(token=token) if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!', - code=settings.RetCode.AUTHENTICATION_ERROR - ) - kwargs['tenant_id'] = objs[0].tenant_id + return get_json_result(data=False, message="Authentication error: API key is invalid!", code=settings.RetCode.AUTHENTICATION_ERROR) + kwargs["tenant_id"] = objs[0].tenant_id return func(*args, **kwargs) return decorated_function @@ -316,11 +308,11 @@ def get_result(code=settings.RetCode.SUCCESS, message="", data=None): return jsonify(response) -def get_error_data_result(message='Sorry! Data missing!', code=settings.RetCode.DATA_ERROR, - ): - result_dict = { - "code": code, - "message": message} +def get_error_data_result( + message="Sorry! 
Data missing!", + code=settings.RetCode.DATA_ERROR, +): + result_dict = {"code": code, "message": message} response = {} for key, value in result_dict.items(): if value is None and key != "code": @@ -348,8 +340,7 @@ def valid_parameter(parameter, valid_values): def dataset_readonly_fields(field_name): - return field_name in ["chunk_count", "create_date", "create_time", "update_date", "update_time", - "created_by", "document_count", "token_num", "status", "tenant_id", "id"] + return field_name in ["chunk_count", "create_date", "create_time", "update_date", "update_time", "created_by", "document_count", "token_num", "status", "tenant_id", "id"] def get_parser_config(chunk_method, parser_config): @@ -358,8 +349,7 @@ def get_parser_config(chunk_method, parser_config): if not chunk_method: chunk_method = "naive" key_mapping = { - "naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize": "DeepDOC", - "raptor": {"use_raptor": False}}, + "naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize": "DeepDOC", "raptor": {"use_raptor": False}}, "qa": {"raptor": {"use_raptor": False}}, "tag": None, "resume": None, @@ -370,10 +360,10 @@ def get_parser_config(chunk_method, parser_config): "laws": {"raptor": {"use_raptor": False}}, "presentation": {"raptor": {"use_raptor": False}}, "one": None, - "knowledge_graph": {"chunk_token_num": 8192, "delimiter": "\\n!?;。;!?", - "entity_types": ["organization", "person", "location", "event", "time"]}, + "knowledge_graph": {"chunk_token_num": 8192, "delimiter": "\\n!?;。;!?", "entity_types": ["organization", "person", "location", "event", "time"]}, "email": None, - "picture": None} + "picture": None, + } parser_config = key_mapping[chunk_method] return parser_config @@ -421,21 +411,23 @@ def get_data_openai(id=None, def valid_parser_config(parser_config): if not parser_config: return - scopes = set([ - "chunk_token_num", - "delimiter", - "raptor", - "graphrag", - "layout_recognize", - "task_page_size", - "pages", - "html4excel", - "auto_keywords", - "auto_questions", - "tag_kb_ids", - "topn_tags", - "filename_embd_weight" - ]) + scopes = set( + [ + "chunk_token_num", + "delimiter", + "raptor", + "graphrag", + "layout_recognize", + "task_page_size", + "pages", + "html4excel", + "auto_keywords", + "auto_questions", + "tag_kb_ids", + "topn_tags", + "filename_embd_weight", + ] + ) for k in parser_config.keys(): assert k in scopes, f"Abnormal 'parser_config'. Invalid key: {k}" @@ -457,7 +449,7 @@ def check_duplicate_ids(ids, id_type="item"): """ Check for duplicate IDs in a list and return unique IDs and error messages. 
@@ -457,7 +449,7 @@
     """
     Check for duplicate IDs in a list and return unique IDs and error messages.
 
-    Args: 
+    Args:
         ids (list): List of IDs to check for duplicates
         id_type (str): Type of ID for error messages (e.g., 'document', 'dataset', 'chunk')
@@ -468,17 +460,15 @@
     """
     id_count = {}
     duplicate_messages = []
-    
+
     # Count occurrences of each ID
     for id_value in ids:
         id_count[id_value] = id_count.get(id_value, 0) + 1
-    
+
     # Check for duplicates
     for id_value, count in id_count.items():
         if count > 1:
             duplicate_messages.append(f"Duplicate {id_type} ids: {id_value}")
-    
+
     # Return unique IDs and error messages
     return list(set(ids)), duplicate_messages
-
-
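Usage of `check_duplicate_ids`, as exercised by the bulk endpoints; note that `list(set(ids))` does not preserve the caller's ordering:

```python
from api.utils.api_utils import check_duplicate_ids

unique_ids, errors = check_duplicate_ids(["d1", "d2", "d1"], id_type="document")
# errors == ["Duplicate document ids: d1"]
# unique_ids contains "d1" and "d2" exactly once, in arbitrary order
```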
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 8030c945..1cc94ab1 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -59,6 +59,7 @@ class Base(ABC):
         # Configure retry parameters
         self.max_retries = int(os.environ.get("LLM_MAX_RETRIES", 5))
         self.base_delay = float(os.environ.get("LLM_BASE_DELAY", 2.0))
+        self.is_tools = False
 
     def _get_delay(self, attempt):
         """Calculate retry delay time"""
@@ -89,6 +90,91 @@ class Base(ABC):
         else:
             return ERROR_GENERIC
 
+    def bind_tools(self, toolcall_session, tools):
+        if not (toolcall_session and tools):
+            return
+        self.is_tools = True
+        self.toolcall_session = toolcall_session
+        self.tools = tools
+
+    def chat_with_tools(self, system: str, history: list, gen_conf: dict):
+        if "max_tokens" in gen_conf:
+            del gen_conf["max_tokens"]
+
+        tools = self.tools
+
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+
+        ans = ""
+        tk_count = 0
+        # Implement exponential backoff retry strategy
+        for attempt in range(self.max_retries):
+            try:
+                response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf)
+
+                assistant_output = response.choices[0].message
+                if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
+                    ans += "<think>" + ans + "</think>"
+                ans += response.choices[0].message.content
+
+                if not response.choices[0].message.tool_calls:
+                    tk_count += self.total_token_count(response)
+                    if response.choices[0].finish_reason == "length":
+                        if is_chinese([ans]):
+                            ans += LENGTH_NOTIFICATION_CN
+                        else:
+                            ans += LENGTH_NOTIFICATION_EN
+                    return ans, tk_count
+
+                tk_count += self.total_token_count(response)
+                history.append(assistant_output)
+
+                for tool_call in response.choices[0].message.tool_calls:
+                    name = tool_call.function.name
+                    args = json.loads(tool_call.function.arguments)
+
+                    tool_response = self.toolcall_session.tool_call(name, args)
+                    # if tool_response.choices[0].finish_reason == "length":
+                    #     if is_chinese(ans):
+                    #         ans += LENGTH_NOTIFICATION_CN
+                    #     else:
+                    #         ans += LENGTH_NOTIFICATION_EN
+                    #     return ans, tk_count + self.total_token_count(tool_response)
+                    history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+
+                final_response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=tools, **gen_conf)
+                assistant_output = final_response.choices[0].message
+                if "tool_calls" not in assistant_output and "reasoning_content" in assistant_output:
+                    ans += "<think>" + ans + "</think>"
+                ans += final_response.choices[0].message.content
+                if final_response.choices[0].finish_reason == "length":
+                    tk_count += self.total_token_count(response)
+                    if is_chinese([ans]):
+                        ans += LENGTH_NOTIFICATION_CN
+                    else:
+                        ans += LENGTH_NOTIFICATION_EN
+                    return ans, tk_count
+                return ans, tk_count
+
+            except Exception as e:
+                logging.exception("OpenAI chat_with_tools")
+                # Classify the error
+                error_code = self._classify_error(e)
+
+                # Check if it's a rate limit error or server error and not the last attempt
+                should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1
+
+                if should_retry:
+                    delay = self._get_delay(attempt)
+                    logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
+                    time.sleep(delay)
+                else:
+                    # For non-rate limit errors or the last attempt, return an error message
+                    if attempt == self.max_retries - 1:
+                        error_code = ERROR_MAX_RETRIES
+                    return f"{ERROR_PREFIX}: {error_code} - {str(e)}", 0
+
     def chat(self, system, history, gen_conf):
         if system:
             history.insert(0, {"role": "system", "content": system})
@@ -127,6 +213,127 @@ class Base(ABC):
                         error_code = ERROR_MAX_RETRIES
                     return f"{ERROR_PREFIX}: {error_code} - {str(e)}. response: {response}", 0
 
+    def _wrap_toolcall_message(self, stream):
+        final_tool_calls = {}
+
+        for chunk in stream:
+            for tool_call in chunk.choices[0].delta.tool_calls or []:
+                index = tool_call.index
+
+                if index not in final_tool_calls:
+                    final_tool_calls[index] = tool_call
+
+                final_tool_calls[index].function.arguments += tool_call.function.arguments
+
+        return final_tool_calls
+
+    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
+        if "max_tokens" in gen_conf:
+            del gen_conf["max_tokens"]
+
+        tools = self.tools
+
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+
+        ans = ""
+        total_tokens = 0
+        reasoning_start = False
+        finish_completion = False
+        final_tool_calls = {}
+        try:
+            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
+            while not finish_completion:
+                for resp in response:
+                    if resp.choices[0].delta.tool_calls:
+                        for tool_call in resp.choices[0].delta.tool_calls or []:
+                            index = tool_call.index
+
+                            if index not in final_tool_calls:
+                                final_tool_calls[index] = tool_call
+
+                            final_tool_calls[index].function.arguments += tool_call.function.arguments
+                        if resp.choices[0].finish_reason != "stop":
+                            continue
+                    else:
+                        if not resp.choices:
+                            continue
+                        if not resp.choices[0].delta.content:
+                            resp.choices[0].delta.content = ""
+                        if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
+                            ans = ""
+                            if not reasoning_start:
+                                reasoning_start = True
+                                ans = "<think>"
+                            ans += resp.choices[0].delta.reasoning_content + "</think>"
+                        else:
+                            reasoning_start = False
+                            ans = resp.choices[0].delta.content
+
+                        tol = self.total_token_count(resp)
+                        if not tol:
+                            total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
+                        else:
+                            total_tokens += tol
+
+                    finish_reason = resp.choices[0].finish_reason
+                    if finish_reason == "tool_calls" and final_tool_calls:
+                        for tool_call in final_tool_calls.values():
+                            name = tool_call.function.name
+                            try:
+                                args = json.loads(tool_call.function.arguments)
+                            except Exception:
+                                continue
+                            tool_response = self.toolcall_session.tool_call(name, args)
+                            history.append(
+                                {
+                                    "role": "assistant",
+                                    "refusal": "",
+                                    "content": "",
+                                    "audio": "",
+                                    "function_call": "",
+                                    "tool_calls": [
+                                        {
+                                            "index": tool_call.index,
+                                            "id": tool_call.id,
+                                            "function": tool_call.function,
+                                            "type": "function",
+                                        },
+                                    ],
+                                }
+                            )
+                            # if tool_response.choices[0].finish_reason == "length":
+                            #     if is_chinese(ans):
+                            #         ans += LENGTH_NOTIFICATION_CN
+                            #     else:
+                            #         ans += LENGTH_NOTIFICATION_EN
+                            #     return ans, total_tokens + self.total_token_count(tool_response)
+                            history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
+                        final_tool_calls = {}
+                        response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
+                        continue
+                    if finish_reason == "length":
+                        if is_chinese(ans):
+                            ans += LENGTH_NOTIFICATION_CN
+                        else:
+                            ans += LENGTH_NOTIFICATION_EN
+                        return ans, total_tokens + self.total_token_count(resp)
+                    if finish_reason == "stop":
+                        finish_completion = True
+                        yield ans
+                        break
+                    yield ans
+                    continue
+
+        except openai.APIError as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
+
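The streaming loop above has to reassemble tool calls that arrive split across deltas: the first chunk for a given `index` carries the call id and function name, and later chunks append argument fragments. A minimal standalone sketch of that accumulation (the same idea as `_wrap_toolcall_message`), assuming an OpenAI-style stream:

```python
# `stream` is assumed to be a chat.completions.create(..., stream=True, tools=...) response.
final_tool_calls = {}
for chunk in stream:
    for tool_call in chunk.choices[0].delta.tool_calls or []:
        if tool_call.index not in final_tool_calls:
            final_tool_calls[tool_call.index] = tool_call  # first delta: id + function name
        else:
            # subsequent deltas only extend the JSON argument string
            final_tool_calls[tool_call.index].function.arguments += tool_call.function.arguments
# once finish_reason == "tool_calls", each entry holds complete JSON arguments
```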
     def chat_streamly(self, system, history, gen_conf):
         if system:
             history.insert(0, {"role": "system", "content": system})
@@ -156,7 +363,7 @@ class Base(ABC):
                 if not tol:
                     total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                 else:
-                    total_tokens = tol
+                    total_tokens += tol
 
                 if resp.choices[0].finish_reason == "length":
                     if is_chinese(ans):
@@ -180,9 +387,10 @@ class Base(ABC):
         except Exception:
             pass
         return 0
-    
+
     def _calculate_dynamic_ctx(self, history):
         """Calculate dynamic context window size"""
+
         def count_tokens(text):
             """Calculate token count for text"""
             # Simple calculation: 1 token per ASCII character
@@ -207,15 +415,16 @@ class Base(ABC):
 
         # Apply 1.2x buffer ratio
         total_tokens_with_buffer = int(total_tokens * 1.2)
-        
+
         if total_tokens_with_buffer <= 8192:
             ctx_size = 8192
         else:
             ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
             ctx_size = ctx_multiplier * 8192
-        
+
         return ctx_size
 
+
 class GptTurbo(Base):
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
         if not base_url:
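Worked example of `_calculate_dynamic_ctx`: the history's token estimate is buffered by 1.2x and rounded up to the next multiple of 8192:

```python
total_tokens = 9000                    # estimated from history (ASCII ~1 token/char, CJK weighted 1.2)
with_buffer = int(total_tokens * 1.2)  # 10800
ctx_size = 8192 if with_buffer <= 8192 else ((with_buffer // 8192) + 1) * 8192
assert ctx_size == 16384
```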
tool_info = {"content": "", "role": "tool", "tool_call_id": assistant_output.tool_calls[0]["id"]} + tool_name = assistant_output.tool_calls[0]["function"]["name"] + if tool_name: + arguments = json.loads(assistant_output.tool_calls[0]["function"]["arguments"]) + tool_info["content"] = self.toolcall_session.tool_call(name=tool_name, arguments=arguments) + history.append(tool_info) + + response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, **gen_conf) + if response.output.choices[0].get("finish_reason", "") == "length": + tk_count += self.total_token_count(response) + if is_chinese([ans]): + ans += LENGTH_NOTIFICATION_CN + else: + ans += LENGTH_NOTIFICATION_EN + return ans, tk_count + + tk_count += self.total_token_count(response) + assistant_output = response.output.choices[0].message + if assistant_output.content is None: + assistant_output.content = "" + history.append(response) + ans += assistant_output["content"] + return ans, tk_count + else: + return "**ERROR**: " + response.message, tk_count + else: + result_list = [] + for result in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=True): + result_list.append(result) + error_msg_list = [result for result in result_list if str(result).find("**ERROR**") >= 0] + if len(error_msg_list) > 0: + return "**ERROR**: " + "".join(error_msg_list), 0 + else: + return "".join(result_list[:-1]), result_list[-1] + def chat(self, system, history, gen_conf): if "max_tokens" in gen_conf: del gen_conf["max_tokens"] @@ -393,6 +676,99 @@ class QWenChat(Base): else: return "".join(result_list[:-1]), result_list[-1] + def _wrap_toolcall_message(self, old_message, message): + if not old_message: + return message + tool_call_id = message["tool_calls"][0].get("id") + if tool_call_id: + old_message.tool_calls[0]["id"] = tool_call_id + function = message.tool_calls[0]["function"] + if function: + if function.get("name"): + old_message.tool_calls[0]["function"]["name"] = function["name"] + if function.get("arguments"): + old_message.tool_calls[0]["function"]["arguments"] += function["arguments"] + return old_message + + def _chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True): + from http import HTTPStatus + + if system: + history.insert(0, {"role": "system", "content": system}) + if "max_tokens" in gen_conf: + del gen_conf["max_tokens"] + ans = "" + tk_count = 0 + try: + response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf) + tool_info = {"content": "", "role": "tool"} + toolcall_message = None + tool_name = "" + tool_arguments = "" + finish_completion = False + reasoning_start = False + while not finish_completion: + for resp in response: + if resp.status_code == HTTPStatus.OK: + assistant_output = resp.output.choices[0].message + ans = resp.output.choices[0].message.content + if not ans and "tool_calls" not in assistant_output and "reasoning_content" in assistant_output: + ans = resp.output.choices[0].message.reasoning_content + if not reasoning_start: + reasoning_start = True + ans = "" + ans + else: + ans = ans + "" + + if "tool_calls" not in assistant_output: + reasoning_start = False + tk_count += self.total_token_count(resp) + if resp.output.choices[0].get("finish_reason", "") == "length": + if is_chinese([ans]): + ans += LENGTH_NOTIFICATION_CN + else: + ans += LENGTH_NOTIFICATION_EN + finish_reason = 
+                            if finish_reason == "stop":
+                                finish_completion = True
+                                yield ans
+                                break
+                            yield ans
+                            continue
+
+                        tk_count += self.total_token_count(resp)
+                        toolcall_message = self._wrap_toolcall_message(toolcall_message, assistant_output)
+                        if "tool_calls" in assistant_output:
+                            tool_call_finish_reason = resp.output.choices[0]["finish_reason"]
+                            if tool_call_finish_reason == "tool_calls":
+                                try:
+                                    tool_arguments = json.loads(toolcall_message.tool_calls[0]["function"]["arguments"])
+                                except Exception as e:
+                                    logging.exception(msg="_chat_streamly_with_tool tool call error")
+                                    yield ans + "\n**ERROR**: " + str(e)
+                                    finish_completion = True
+                                    break
+
+                                tool_name = toolcall_message.tool_calls[0]["function"]["name"]
+                                history.append(toolcall_message)
+                                tool_info["content"] = self.toolcall_session.tool_call(name=tool_name, arguments=tool_arguments)
+                                history.append(tool_info)
+                                tool_info = {"content": "", "role": "tool"}
+                                tool_name = ""
+                                tool_arguments = ""
+                                toolcall_message = None
+                                response = Generation.call(self.model_name, messages=history, result_format="message", tools=self.tools, stream=True, incremental_output=incremental_output, **gen_conf)
+                    else:
+                        yield (
+                            ans + "\n**ERROR**: " + resp.output.choices[0].message
+                            if not re.search(r" (key|quota)", str(resp.message).lower())
+                            else "Out of credit. Please set the API key in **settings > Model providers.**"
+                        )
+        except Exception as e:
+            logging.exception(msg="_chat_streamly_with_tool")
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield tk_count
+
     def _chat_streamly(self, system, history, gen_conf, incremental_output=True):
         from http import HTTPStatus
 
@@ -425,6 +801,13 @@ class QWenChat(Base):
 
         yield tk_count
 
+    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict, incremental_output=True):
+        if "max_tokens" in gen_conf:
+            del gen_conf["max_tokens"]
+
+        for txt in self._chat_streamly_with_tools(system, history, gen_conf, incremental_output=incremental_output):
+            yield txt
+
     def chat_streamly(self, system, history, gen_conf):
         if "max_tokens" in gen_conf:
             del gen_conf["max_tokens"]
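Consumption contract for the `*_streamly*` generators above (both the OpenAI-style base class and this QWen variant): they yield answer fragments as strings and finish with a single `int` token count, which is exactly why `chat_with_tools` splits `result_list[:-1]` from `result_list[-1]`. A sketch, with `mdl` standing in for any bound chat-model instance:

```python
chunks, total_tokens = [], 0
for item in mdl.chat_streamly_with_tools(system, history, {"temperature": 0.7}):
    if isinstance(item, int):   # final yield: cumulative token usage
        total_tokens = item
        break
    chunks.append(item)         # streamed answer text (may carry <think> markup)
answer = "".join(chunks)
```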
@@ -445,6 +828,8 @@ class QWenChat(Base):
 
 class ZhipuChat(Base):
     def __init__(self, key, model_name="glm-3-turbo", **kwargs):
+        super().__init__(key, model_name, base_url=None)
+
         self.client = ZhipuAI(api_key=key)
         self.model_name = model_name
@@ -504,6 +889,8 @@ class ZhipuChat(Base):
 
 class OllamaChat(Base):
     def __init__(self, key, model_name, **kwargs):
+        super().__init__(key, model_name, base_url=None)
+
         self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"})
         self.model_name = model_name
@@ -515,10 +902,8 @@ class OllamaChat(Base):
         try:
             # Calculate context size
             ctx_size = self._calculate_dynamic_ctx(history)
-            
-            options = {
-                "num_ctx": ctx_size
-            }
+
+            options = {"num_ctx": ctx_size}
             if "temperature" in gen_conf:
                 options["temperature"] = gen_conf["temperature"]
             if "max_tokens" in gen_conf:
@@ -545,9 +930,7 @@ class OllamaChat(Base):
         try:
             # Calculate context size
             ctx_size = self._calculate_dynamic_ctx(history)
-            options = {
-                "num_ctx": ctx_size
-            }
+            options = {"num_ctx": ctx_size}
             if "temperature" in gen_conf:
                 options["temperature"] = gen_conf["temperature"]
             if "max_tokens" in gen_conf:
@@ -561,7 +944,7 @@ class OllamaChat(Base):
 
         ans = ""
         try:
-            response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=10 )
+            response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=10)
             for resp in response:
                 if resp["done"]:
                     token_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
@@ -578,6 +961,8 @@ class OllamaChat(Base):
 
 class LocalAIChat(Base):
     def __init__(self, key, model_name, base_url):
+        super().__init__(key, model_name, base_url=None)
+
         if not base_url:
             raise ValueError("Local llm url cannot be None")
         if base_url.split("/")[-1] != "v1":
@@ -613,6 +998,8 @@ class LocalLLM(Base):
         return do_rpc
 
     def __init__(self, key, model_name):
+        super().__init__(key, model_name, base_url=None)
+
         from jina import Client
 
         self.client = Client(port=12345, protocol="grpc", asyncio=True)
@@ -659,6 +1046,8 @@ class LocalLLM(Base):
 
 class VolcEngineChat(Base):
     def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
+        super().__init__(key, model_name, base_url=None)
+
         """
         Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
         Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use
@@ -677,6 +1066,8 @@ class MiniMaxChat(Base):
         model_name,
         base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
     ):
+        super().__init__(key, model_name, base_url=None)
+
         if not base_url:
             base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
         self.base_url = base_url
@@ -755,6 +1146,8 @@ class MiniMaxChat(Base):
 
 class MistralChat(Base):
     def __init__(self, key, model_name, base_url=None):
+        super().__init__(key, model_name, base_url=None)
+
         from mistralai.client import MistralClient
 
         self.client = MistralClient(api_key=key)
@@ -808,6 +1201,8 @@ class MistralChat(Base):
 
 class BedrockChat(Base):
     def __init__(self, key, model_name, **kwargs):
+        super().__init__(key, model_name, base_url=None)
+
         import boto3
 
         self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
@@ -887,6 +1282,8 @@ class BedrockChat(Base):
 
 class GeminiChat(Base):
     def __init__(self, key, model_name, base_url=None):
+        super().__init__(key, model_name, base_url=None)
+
         from google.generativeai import GenerativeModel, client
 
         client.configure(api_key=key)
@@ -947,6 +1344,8 @@ class GeminiChat(Base):
 
 class GroqChat(Base):
     def __init__(self, key, model_name, base_url=""):
+        super().__init__(key, model_name, base_url=None)
+
         from groq import Groq
 
         self.client = Groq(api_key=key)
@@ -1049,6 +1448,8 @@ class PPIOChat(Base):
 
 class CoHereChat(Base):
     def __init__(self, key, model_name, base_url=""):
+        super().__init__(key, model_name, base_url=None)
+
         from cohere import Client
 
         self.client = Client(api_key=key)
@@ -1171,6 +1572,8 @@ class YiChat(Base):
 
 class ReplicateChat(Base):
     def __init__(self, key, model_name, base_url=None):
+        super().__init__(key, model_name, base_url=None)
+
         from replicate.client import Client
 
         self.model_name = model_name
@@ -1218,6 +1621,8 @@ class HunyuanChat(Base):
     def __init__(self, key, model_name, base_url=None):
+        super().__init__(key, model_name, base_url=None)
+
         from tencentcloud.common import credential
         from tencentcloud.hunyuan.v20230901 import hunyuan_client
@@ -1321,6 +1726,8 @@ class SparkChat(Base):
 
 class BaiduYiyanChat(Base):
     def __init__(self, key, model_name, base_url=None):
+        super().__init__(key, model_name, base_url=None)
+
         import qianfan
 
         key = json.loads(key)
@@ -1372,6 +1779,8 @@ class AnthropicChat(Base):
     def __init__(self, key, model_name, base_url=None):
+        super().__init__(key, model_name, base_url=None)
+
         import anthropic
 
         self.client = anthropic.Anthropic(api_key=key)
@@ -1452,6 +1861,8 @@ class GoogleChat(Base):
     def __init__(self, key, model_name, base_url=None):
+        super().__init__(key, model_name, base_url=None)
+
         import base64
 
         from google.oauth2 import service_account
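One consequence of this PR for new providers: every subclass `__init__` must now call `super().__init__` first so the retry settings and the `is_tools` flag exist before any `chat*` method runs. A hypothetical minimal provider (the signature is inferred from the calls added above; the client setup is an assumption, not part of this PR):

```python
class MyProviderChat(Base):  # hypothetical provider
    def __init__(self, key, model_name, base_url=None):
        # Required first: initializes max_retries/base_delay and is_tools = False.
        super().__init__(key, model_name, base_url=None)
        self.client = OpenAI(api_key=key, base_url=base_url)  # assumes an OpenAI-compatible endpoint
        self.model_name = model_name
```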