diff --git a/api/app.py b/api/app.py index 255c1dbc05..bcf3856c13 100644 --- a/api/app.py +++ b/api/app.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import os from werkzeug.exceptions import Unauthorized diff --git a/api/config.py b/api/config.py index 84572ff592..b37a559e02 100644 --- a/api/config.py +++ b/api/config.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import os import dotenv diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index c06193f91a..87cad07462 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import json import logging from datetime import datetime diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 775b3315a8..d95b3d03c2 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import logging from flask import request diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index be8d3bf082..f01d2afa03 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -1,7 +1,7 @@ -# -*- coding:utf-8 -*- import json import logging -from typing import Generator, Union +from collections.abc import Generator +from typing import Union import flask_login from flask import Response, stream_with_context @@ -169,8 +169,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: return Response(response=json.dumps(response), status=200, mimetype='application/json') else: def generate() -> Generator: - for chunk in response: - yield chunk + yield from response return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index d29d826b69..0064dbe663 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -1,6 +1,7 @@ import json import logging -from typing import Generator, Union +from collections.abc import Generator +from typing import Union from flask import Response, stream_with_context from flask_login import current_user @@ -246,8 +247,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: return Response(response=json.dumps(response), status=200, mimetype='application/json') else: def generate() -> Generator: - for chunk in response: - yield chunk + yield from response return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index fd526b393d..f67fff4b06 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from flask import request from flask_login import current_user diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 8d6231cbac..4e9d9ed9b4 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from flask_login import current_user from flask_restful import Resource, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index d6ced934a7..7aed7da404 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from datetime import datetime from decimal import Decimal diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 646f672c72..cec022ed58 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import flask_login from flask import current_app, request from flask_restful import Resource, reqparse diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 5a71ccd6e6..2d26d0ecf4 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import flask_restful from flask import current_app, request from flask_login import current_user diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 612838a316..3fb6f16cd6 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1,6 +1,4 @@ -# -*- coding:utf-8 -*- from datetime import datetime -from typing import List from flask import request from flask_login import current_user @@ -71,7 +69,7 @@ class DocumentResource(Resource): return document - def get_batch_documents(self, dataset_id: str, batch: str) -> List[Document]: + def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]: dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise NotFound('Dataset not found.') diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 319b78b6d1..1395963f1d 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import uuid from datetime import datetime diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 48d58524bb..d6afee0d63 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import logging from flask import request diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 924578f7b4..6406d5b3b0 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,8 +1,8 @@ -# -*- coding:utf-8 -*- import json import logging +from collections.abc import Generator from datetime import datetime -from typing import Generator, Union +from typing import Union from flask import Response, stream_with_context from flask_login import current_user @@ -164,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: return Response(response=json.dumps(response), status=200, mimetype='application/json') else: def generate() -> Generator: - for chunk in response: - yield chunk + yield from response return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 8a3fb3a205..34a5904eca 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from flask_login import current_user from flask_restful import marshal_with, reqparse from flask_restful.inputs import int_range diff --git a/api/controllers/console/explore/error.py b/api/controllers/console/explore/error.py index e3180bf987..89c4d113a3 100644 --- a/api/controllers/console/explore/error.py +++ b/api/controllers/console/explore/error.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from libs.exception import BaseHTTPException diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 6e914ef3a4..920d9141ae 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from datetime import datetime from flask_login import current_user diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 75c3cdd5c4..47af28425f 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -1,7 +1,7 @@ -# -*- coding:utf-8 -*- import json import logging -from typing import Generator, Union +from collections.abc import Generator +from typing import Union from flask import Response, stream_with_context from flask_login import current_user @@ -123,8 +123,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: return Response(response=json.dumps(response), status=200, mimetype='application/json') else: def generate() -> Generator: - for chunk in response: - yield chunk + yield from response return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 4b18be6dc6..c4afb0b923 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import json from flask import current_app diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 3c2c806664..fd90be03b1 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from flask_login import current_user from flask_restful import Resource, fields, marshal_with from sqlalchemy import and_ diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 58c2853470..a8d0dd4344 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from functools import wraps from flask import current_app, request diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 519fa25516..a50e4c41a8 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import json import logging diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index c511c9778b..b7cfba9d04 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from datetime import datetime import pytz diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 1b7d08a879..6ee0188823 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from flask import current_app from flask_login import current_user from flask_restful import Resource, abort, fields, marshal_with, reqparse diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index dbeb712bc2..7b3f08f467 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import logging from flask import request diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 1e20265c4b..d5777a330c 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import json from functools import wraps diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 89d99d66f3..9cd9770c09 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import json from flask import current_app diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index d47bb089dc..5331f796e7 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -1,6 +1,7 @@ import json import logging -from typing import Generator, Union +from collections.abc import Generator +from typing import Union from flask import Response, stream_with_context from flask_restful import reqparse @@ -182,8 +183,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: return Response(response=json.dumps(response), status=200, mimetype='application/json') else: def generate() -> Generator: - for chunk in response: - yield chunk + yield from response return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index d275552d0b..3c157bed99 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from flask import request from flask_restful import marshal_with, reqparse from flask_restful.inputs import int_range diff --git a/api/controllers/service_api/app/error.py b/api/controllers/service_api/app/error.py index 56beb56949..eb953d0950 100644 --- a/api/controllers/service_api/app/error.py +++ b/api/controllers/service_api/app/error.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from libs.exception import BaseHTTPException diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index a0257b3ed5..d90f536a42 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from flask_restful import fields, marshal_with, reqparse from flask_restful.inputs import int_range from werkzeug.exceptions import NotFound diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 0cc63a2ad3..a0d89fe62f 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from datetime import datetime from functools import wraps diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 6e62c042d4..25492b1143 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import json from flask import current_app diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index b3d7280b64..673aa9ad8c 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import logging from flask import request diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index c61995b72c..61d4f8c362 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,7 +1,7 @@ -# -*- coding:utf-8 -*- import json import logging -from typing import Generator, Union +from collections.abc import Generator +from typing import Union from flask import Response, stream_with_context from flask_restful import reqparse @@ -154,8 +154,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: return Response(response=json.dumps(response), status=200, mimetype='application/json') else: def generate() -> Generator: - for chunk in response: - yield chunk + yield from response return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index b0d7747d65..c287f2a879 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from flask_restful import marshal_with, reqparse from flask_restful.inputs import int_range from werkzeug.exceptions import NotFound diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index 4566c323a2..9cb3c8f235 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from libs.exception import BaseHTTPException diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 1a084fe539..e03bdd63bb 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -1,7 +1,7 @@ -# -*- coding:utf-8 -*- import json import logging -from typing import Generator, Union +from collections.abc import Generator +from typing import Union from flask import Response, stream_with_context from flask_restful import fields, marshal_with, reqparse @@ -160,8 +160,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: return Response(response=json.dumps(response), status=200, mimetype='application/json') else: def generate() -> Generator: - for chunk in response: - yield chunk + yield from response return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 188cc41254..92b28d8125 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import uuid from flask import request diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 8ce3a81083..d8e2d59707 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from flask import current_app from flask_restful import fields, marshal_with diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index ebf6611784..bdaa476f34 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from functools import wraps from flask import request diff --git a/api/core/agent/agent/agent_llm_callback.py b/api/core/agent/agent/agent_llm_callback.py index 8331731200..5ec549de8e 100644 --- a/api/core/agent/agent/agent_llm_callback.py +++ b/api/core/agent/agent/agent_llm_callback.py @@ -1,5 +1,5 @@ import logging -from typing import List, Optional +from typing import Optional from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler from core.model_runtime.callbacks.base_callback import Callback @@ -17,7 +17,7 @@ class AgentLLMCallback(Callback): def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ Before invoke callback @@ -38,7 +38,7 @@ class AgentLLMCallback(Callback): def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None): """ On new chunk callback @@ -58,7 +58,7 @@ class AgentLLMCallback(Callback): def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ After invoke callback @@ -80,7 +80,7 @@ class AgentLLMCallback(Callback): def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ Invoke error callback diff --git a/api/core/agent/agent/calc_token_mixin.py b/api/core/agent/agent/calc_token_mixin.py index b25ab2d88a..9c0f9c5b36 100644 --- a/api/core/agent/agent/calc_token_mixin.py +++ b/api/core/agent/agent/calc_token_mixin.py @@ -1,4 +1,4 @@ -from typing import List, cast +from typing import cast from core.entities.application_entities import ModelConfigEntity from core.model_runtime.entities.message_entities import PromptMessage @@ -8,7 +8,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large class CalcTokenMixin: - def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int: + def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int: """ Got the rest tokens available for the model after excluding messages tokens and completion max tokens diff --git a/api/core/agent/agent/multi_dataset_router_agent.py b/api/core/agent/agent/multi_dataset_router_agent.py index 201421910d..eb594c3d21 100644 --- a/api/core/agent/agent/multi_dataset_router_agent.py +++ b/api/core/agent/agent/multi_dataset_router_agent.py @@ -1,4 +1,5 @@ -from typing import Any, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message @@ -42,7 +43,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): def plan( self, - intermediate_steps: List[Tuple[AgentAction, str]], + intermediate_steps: list[tuple[AgentAction, str]], callbacks: Callbacks = None, **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: @@ -85,7 +86,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): def real_plan( self, - intermediate_steps: List[Tuple[AgentAction, str]], + intermediate_steps: list[tuple[AgentAction, str]], callbacks: Callbacks = None, **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: @@ -146,7 +147,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): async def aplan( self, - intermediate_steps: List[Tuple[AgentAction, str]], + intermediate_steps: list[tuple[AgentAction, str]], callbacks: Callbacks = None, **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: @@ -158,7 +159,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): model_config: ModelConfigEntity, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, - extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None, + extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None, system_message: Optional[SystemMessage] = SystemMessage( content="You are a helpful AI assistant." ), diff --git a/api/core/agent/agent/openai_function_call.py b/api/core/agent/agent/openai_function_call.py index 3dafa4517b..1f2d5f24b3 100644 --- a/api/core/agent/agent/openai_function_call.py +++ b/api/core/agent/agent/openai_function_call.py @@ -1,4 +1,5 @@ -from typing import Any, List, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Optional, Union from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message @@ -51,7 +52,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi model_config: ModelConfigEntity, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, - extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None, + extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None, system_message: Optional[SystemMessage] = SystemMessage( content="You are a helpful AI assistant." ), @@ -125,7 +126,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi def plan( self, - intermediate_steps: List[Tuple[AgentAction, str]], + intermediate_steps: list[tuple[AgentAction, str]], callbacks: Callbacks = None, **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: @@ -207,7 +208,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi def return_stopped_response( self, early_stopping_method: str, - intermediate_steps: List[Tuple[AgentAction, str]], + intermediate_steps: list[tuple[AgentAction, str]], **kwargs: Any, ) -> AgentFinish: try: @@ -215,7 +216,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi except ValueError: return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "") - def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]: + def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]: # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0 rest_tokens = self.get_message_rest_tokens( self.model_config, @@ -264,7 +265,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi return new_messages def predict_new_summary( - self, messages: List[BaseMessage], existing_summary: str + self, messages: list[BaseMessage], existing_summary: str ) -> str: new_lines = get_buffer_string( messages, @@ -275,7 +276,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT) return chain.predict(summary=existing_summary, new_lines=new_lines) - def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int: + def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/ diff --git a/api/core/agent/agent/structed_multi_dataset_router_agent.py b/api/core/agent/agent/structed_multi_dataset_router_agent.py index 9d36e01d7c..e104bb01f9 100644 --- a/api/core/agent/agent/structed_multi_dataset_router_agent.py +++ b/api/core/agent/agent/structed_multi_dataset_router_agent.py @@ -1,5 +1,6 @@ import re -from typing import Any, List, Optional, Sequence, Tuple, Union, cast +from collections.abc import Sequence +from typing import Any, Optional, Union, cast from langchain import BasePromptTemplate, PromptTemplate from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent @@ -68,7 +69,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent): def plan( self, - intermediate_steps: List[Tuple[AgentAction, str]], + intermediate_steps: list[tuple[AgentAction, str]], callbacks: Callbacks = None, **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: @@ -125,8 +126,8 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent): suffix: str = SUFFIX, human_message_template: str = HUMAN_MESSAGE_TEMPLATE, format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[List[str]] = None, - memory_prompts: Optional[List[BasePromptTemplate]] = None, + input_variables: Optional[list[str]] = None, + memory_prompts: Optional[list[BasePromptTemplate]] = None, ) -> BasePromptTemplate: tool_strings = [] for tool in tools: @@ -153,7 +154,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent): tools: Sequence[BaseTool], prefix: str = PREFIX, format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[List[str]] = None, + input_variables: Optional[list[str]] = None, ) -> PromptTemplate: """Create prompt in the style of the zero shot agent. @@ -180,7 +181,7 @@ Thought: {agent_scratchpad} return PromptTemplate(template=template, input_variables=input_variables) def _construct_scratchpad( - self, intermediate_steps: List[Tuple[AgentAction, str]] + self, intermediate_steps: list[tuple[AgentAction, str]] ) -> str: agent_scratchpad = "" for action, observation in intermediate_steps: @@ -213,8 +214,8 @@ Thought: {agent_scratchpad} suffix: str = SUFFIX, human_message_template: str = HUMAN_MESSAGE_TEMPLATE, format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[List[str]] = None, - memory_prompts: Optional[List[BasePromptTemplate]] = None, + input_variables: Optional[list[str]] = None, + memory_prompts: Optional[list[BasePromptTemplate]] = None, **kwargs: Any, ) -> Agent: """Construct an agent from an LLM and tools.""" diff --git a/api/core/agent/agent/structured_chat.py b/api/core/agent/agent/structured_chat.py index 03fea8c27d..e1be624204 100644 --- a/api/core/agent/agent/structured_chat.py +++ b/api/core/agent/agent/structured_chat.py @@ -1,5 +1,6 @@ import re -from typing import Any, List, Optional, Sequence, Tuple, Union, cast +from collections.abc import Sequence +from typing import Any, Optional, Union, cast from langchain import BasePromptTemplate, PromptTemplate from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent @@ -82,7 +83,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): def plan( self, - intermediate_steps: List[Tuple[AgentAction, str]], + intermediate_steps: list[tuple[AgentAction, str]], callbacks: Callbacks = None, **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: @@ -127,7 +128,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): return AgentFinish({"output": "I'm sorry, the answer of model is invalid, " "I don't know how to respond to that."}, "") - def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs): + def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs): if len(intermediate_steps) >= 2 and self.summary_model_config: should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1] should_summary_messages = [AIMessage(content=observation) @@ -154,7 +155,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): return self.get_full_inputs([intermediate_steps[-1]], **kwargs) def predict_new_summary( - self, messages: List[BaseMessage], existing_summary: str + self, messages: list[BaseMessage], existing_summary: str ) -> str: new_lines = get_buffer_string( messages, @@ -173,8 +174,8 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): suffix: str = SUFFIX, human_message_template: str = HUMAN_MESSAGE_TEMPLATE, format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[List[str]] = None, - memory_prompts: Optional[List[BasePromptTemplate]] = None, + input_variables: Optional[list[str]] = None, + memory_prompts: Optional[list[BasePromptTemplate]] = None, ) -> BasePromptTemplate: tool_strings = [] for tool in tools: @@ -200,7 +201,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): tools: Sequence[BaseTool], prefix: str = PREFIX, format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[List[str]] = None, + input_variables: Optional[list[str]] = None, ) -> PromptTemplate: """Create prompt in the style of the zero shot agent. @@ -227,7 +228,7 @@ Thought: {agent_scratchpad} return PromptTemplate(template=template, input_variables=input_variables) def _construct_scratchpad( - self, intermediate_steps: List[Tuple[AgentAction, str]] + self, intermediate_steps: list[tuple[AgentAction, str]] ) -> str: agent_scratchpad = "" for action, observation in intermediate_steps: @@ -260,8 +261,8 @@ Thought: {agent_scratchpad} suffix: str = SUFFIX, human_message_template: str = HUMAN_MESSAGE_TEMPLATE, format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[List[str]] = None, - memory_prompts: Optional[List[BasePromptTemplate]] = None, + input_variables: Optional[list[str]] = None, + memory_prompts: Optional[list[BasePromptTemplate]] = None, agent_llm_callback: Optional[AgentLLMCallback] = None, **kwargs: Any, ) -> Agent: diff --git a/api/core/app_runner/app_runner.py b/api/core/app_runner/app_runner.py index 457cae8289..2b8ddc5d4e 100644 --- a/api/core/app_runner/app_runner.py +++ b/api/core/app_runner/app_runner.py @@ -1,5 +1,6 @@ import time -from typing import Generator, List, Optional, Tuple, Union, cast +from collections.abc import Generator +from typing import Optional, Union, cast from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.entities.application_entities import ( @@ -84,7 +85,7 @@ class AppRunner: return rest_tokens def recale_llm_max_tokens(self, model_config: ModelConfigEntity, - prompt_messages: List[PromptMessage]): + prompt_messages: list[PromptMessage]): # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) @@ -126,7 +127,7 @@ class AppRunner: query: Optional[str] = None, context: Optional[str] = None, memory: Optional[TokenBufferMemory] = None) \ - -> Tuple[List[PromptMessage], Optional[List[str]]]: + -> tuple[list[PromptMessage], Optional[list[str]]]: """ Organize prompt messages :param context: @@ -295,7 +296,7 @@ class AppRunner: tenant_id: str, app_orchestration_config_entity: AppOrchestrationConfigEntity, inputs: dict, - query: str) -> Tuple[bool, dict, str]: + query: str) -> tuple[bool, dict, str]: """ Process sensitive_word_avoidance. :param app_id: app id diff --git a/api/core/app_runner/generate_task_pipeline.py b/api/core/app_runner/generate_task_pipeline.py index 39f51ee1b6..20e4bc7992 100644 --- a/api/core/app_runner/generate_task_pipeline.py +++ b/api/core/app_runner/generate_task_pipeline.py @@ -1,7 +1,8 @@ import json import logging import time -from typing import Generator, Optional, Union, cast +from collections.abc import Generator +from typing import Optional, Union, cast from pydantic import BaseModel @@ -118,7 +119,7 @@ class GenerateTaskPipeline: } self._task_state.llm_result.message.content = annotation.content - elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)): + elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): if isinstance(event, QueueMessageEndEvent): self._task_state.llm_result = event.llm_result else: @@ -201,7 +202,7 @@ class GenerateTaskPipeline: data = self._error_to_stream_response_data(self._handle_error(event)) yield self._yield_response(data) break - elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)): + elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): if isinstance(event, QueueMessageEndEvent): self._task_state.llm_result = event.llm_result else: @@ -353,7 +354,7 @@ class GenerateTaskPipeline: yield self._yield_response(response) - elif isinstance(event, (QueueMessageEvent, QueueAgentMessageEvent)): + elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent): chunk = event.chunk delta_text = chunk.delta.message.content if delta_text is None: diff --git a/api/core/app_runner/moderation_handler.py b/api/core/app_runner/moderation_handler.py index 392425ed8e..b2098344c8 100644 --- a/api/core/app_runner/moderation_handler.py +++ b/api/core/app_runner/moderation_handler.py @@ -1,7 +1,7 @@ import logging import threading import time -from typing import Any, Dict, Optional +from typing import Any, Optional from flask import Flask, current_app from pydantic import BaseModel @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) class ModerationRule(BaseModel): type: str - config: Dict[str, Any] + config: dict[str, Any] class OutputModerationHandler(BaseModel): diff --git a/api/core/application_manager.py b/api/core/application_manager.py index b718cefab6..d2f4326b4f 100644 --- a/api/core/application_manager.py +++ b/api/core/application_manager.py @@ -2,7 +2,8 @@ import json import logging import threading import uuid -from typing import Any, Generator, Optional, Tuple, Union, cast +from collections.abc import Generator +from typing import Any, Optional, Union, cast from flask import Flask, current_app from pydantic import ValidationError @@ -585,7 +586,7 @@ class ApplicationManager: return AppOrchestrationConfigEntity(**properties) def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \ - -> Tuple[Conversation, Message]: + -> tuple[Conversation, Message]: """ Initialize generate records :param application_generate_entity: application generate entity diff --git a/api/core/application_queue_manager.py b/api/core/application_queue_manager.py index 75a56d6706..9590a1e726 100644 --- a/api/core/application_queue_manager.py +++ b/api/core/application_queue_manager.py @@ -1,7 +1,8 @@ import queue import time +from collections.abc import Generator from enum import Enum -from typing import Any, Generator +from typing import Any from sqlalchemy.orm import DeclarativeMeta diff --git a/api/core/callback_handler/agent_loop_gather_callback_handler.py b/api/core/callback_handler/agent_loop_gather_callback_handler.py index f9347198dc..1d25b8ab69 100644 --- a/api/core/callback_handler/agent_loop_gather_callback_handler.py +++ b/api/core/callback_handler/agent_loop_gather_callback_handler.py @@ -1,7 +1,7 @@ import json import logging import time -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Optional, Union, cast from langchain.agents import openai_functions_agent, openai_functions_multi_agent from langchain.callbacks.base import BaseCallbackHandler @@ -37,7 +37,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): self._message_agent_thought = None @property - def agent_loops(self) -> List[AgentLoop]: + def agent_loops(self) -> list[AgentLoop]: return self._agent_loops def clear_agent_loops(self) -> None: @@ -95,14 +95,14 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): def on_chat_model_start( self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], + serialized: dict[str, Any], + messages: list[list[BaseMessage]], **kwargs: Any ) -> Any: pass def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any ) -> None: pass @@ -120,7 +120,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): def on_tool_start( self, - serialized: Dict[str, Any], + serialized: dict[str, Any], input_str: str, **kwargs: Any, ) -> None: diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index ae77bf6cd1..3fed7d0ad5 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from langchain.callbacks.base import BaseCallbackHandler from langchain.input import print_text @@ -21,7 +21,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel): def on_tool_start( self, tool_name: str, - tool_inputs: Dict[str, Any], + tool_inputs: dict[str, Any], ) -> None: """Do nothing.""" print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color) @@ -29,7 +29,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel): def on_tool_end( self, tool_name: str, - tool_inputs: Dict[str, Any], + tool_inputs: dict[str, Any], tool_outputs: str, ) -> None: """If not the final action, print out observation.""" diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 63c9bbe416..7c8a3ce478 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,4 +1,3 @@ -from typing import List from langchain.schema import Document @@ -40,7 +39,7 @@ class DatasetIndexToolCallbackHandler: db.session.add(dataset_query) db.session.commit() - def on_tool_end(self, documents: List[Document]) -> None: + def on_tool_end(self, documents: list[Document]) -> None: """Handle tool end.""" for document in documents: doc_id = document.metadata['doc_id'] @@ -55,7 +54,7 @@ class DatasetIndexToolCallbackHandler: db.session.commit() - def return_retriever_resource_info(self, resource: List): + def return_retriever_resource_info(self, resource: list): """Handle return_retriever_resource_info.""" if resource and len(resource) > 0: for item in resource: diff --git a/api/core/callback_handler/std_out_callback_handler.py b/api/core/callback_handler/std_out_callback_handler.py index 9f586d2c9b..1f95471afb 100644 --- a/api/core/callback_handler/std_out_callback_handler.py +++ b/api/core/callback_handler/std_out_callback_handler.py @@ -1,6 +1,6 @@ import os import sys -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from langchain.callbacks.base import BaseCallbackHandler from langchain.input import print_text @@ -16,8 +16,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler): def on_chat_model_start( self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], + serialized: dict[str, Any], + messages: list[list[BaseMessage]], **kwargs: Any ) -> Any: print_text("\n[on_chat_model_start]\n", color='blue') @@ -26,7 +26,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler): print_text(str(sub_message) + "\n", color='blue') def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any ) -> None: """Print out the prompts.""" print_text("\n[on_llm_start]\n", color='blue') @@ -48,13 +48,13 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler): print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue') def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any ) -> None: """Print out that we are entering a chain.""" chain_type = serialized['id'][-1] print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink') - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None: """Print out that we finished a chain.""" print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink') @@ -66,7 +66,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler): def on_tool_start( self, - serialized: Dict[str, Any], + serialized: dict[str, Any], input_str: str, **kwargs: Any, ) -> None: diff --git a/api/core/chain/llm_chain.py b/api/core/chain/llm_chain.py index a5d160c99e..86fb156292 100644 --- a/api/core/chain/llm_chain.py +++ b/api/core/chain/llm_chain.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Optional from langchain import LLMChain as LCLLMChain from langchain.callbacks.manager import CallbackManagerForChainRun @@ -16,12 +16,12 @@ class LLMChain(LCLLMChain): model_config: ModelConfigEntity """The language model instance to use.""" llm: BaseLanguageModel = FakeLLM(response="") - parameters: Dict[str, Any] = {} + parameters: dict[str, Any] = {} agent_llm_callback: Optional[AgentLLMCallback] = None def generate( self, - input_list: List[Dict[str, Any]], + input_list: list[dict[str, Any]], run_manager: Optional[CallbackManagerForChainRun] = None, ) -> LLMResult: """Generate LLM result from inputs.""" diff --git a/api/core/data_loader/file_extractor.py b/api/core/data_loader/file_extractor.py index af0fb1d35a..4a6eb3654d 100644 --- a/api/core/data_loader/file_extractor.py +++ b/api/core/data_loader/file_extractor.py @@ -1,6 +1,6 @@ import tempfile from pathlib import Path -from typing import List, Optional, Union +from typing import Optional, Union import requests from flask import current_app @@ -28,7 +28,7 @@ USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTM class FileExtractor: @classmethod - def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document], str]: + def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[list[Document], str]: with tempfile.TemporaryDirectory() as temp_dir: suffix = Path(upload_file.key).suffix file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" @@ -37,7 +37,7 @@ class FileExtractor: return cls.load_from_file(file_path, return_text, upload_file, is_automatic) @classmethod - def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document], str]: + def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]: response = requests.get(url, headers={ "User-Agent": USER_AGENT }) @@ -53,7 +53,7 @@ class FileExtractor: @classmethod def load_from_file(cls, file_path: str, return_text: bool = False, upload_file: Optional[UploadFile] = None, - is_automatic: bool = False) -> Union[List[Document], str]: + is_automatic: bool = False) -> Union[list[Document], str]: input_file = Path(file_path) delimiter = '\n' file_extension = input_file.suffix.lower() diff --git a/api/core/data_loader/loader/csv_loader.py b/api/core/data_loader/loader/csv_loader.py index a4d4ed2b39..ce252c157e 100644 --- a/api/core/data_loader/loader/csv_loader.py +++ b/api/core/data_loader/loader/csv_loader.py @@ -1,6 +1,6 @@ import csv import logging -from typing import Dict, List, Optional +from typing import Optional from langchain.document_loaders import CSVLoader as LCCSVLoader from langchain.document_loaders.helpers import detect_file_encodings @@ -14,7 +14,7 @@ class CSVLoader(LCCSVLoader): self, file_path: str, source_column: Optional[str] = None, - csv_args: Optional[Dict] = None, + csv_args: Optional[dict] = None, encoding: Optional[str] = None, autodetect_encoding: bool = True, ): @@ -24,7 +24,7 @@ class CSVLoader(LCCSVLoader): self.csv_args = csv_args or {} self.autodetect_encoding = autodetect_encoding - def load(self) -> List[Document]: + def load(self) -> list[Document]: """Load data into document objects.""" try: with open(self.file_path, newline="", encoding=self.encoding) as csvfile: diff --git a/api/core/data_loader/loader/excel.py b/api/core/data_loader/loader/excel.py index f5f6b2d69c..cddb298547 100644 --- a/api/core/data_loader/loader/excel.py +++ b/api/core/data_loader/loader/excel.py @@ -1,5 +1,4 @@ import logging -from typing import List from langchain.document_loaders.base import BaseLoader from langchain.schema import Document @@ -23,7 +22,7 @@ class ExcelLoader(BaseLoader): """Initialize with file path.""" self._file_path = file_path - def load(self) -> List[Document]: + def load(self) -> list[Document]: data = [] keys = [] wb = load_workbook(filename=self._file_path, read_only=True) diff --git a/api/core/data_loader/loader/html.py b/api/core/data_loader/loader/html.py index 414975007b..6a9b48a5b2 100644 --- a/api/core/data_loader/loader/html.py +++ b/api/core/data_loader/loader/html.py @@ -1,5 +1,4 @@ import logging -from typing import List from bs4 import BeautifulSoup from langchain.document_loaders.base import BaseLoader @@ -23,7 +22,7 @@ class HTMLLoader(BaseLoader): """Initialize with file path.""" self._file_path = file_path - def load(self) -> List[Document]: + def load(self) -> list[Document]: return [Document(page_content=self._load_as_text())] def _load_as_text(self) -> str: diff --git a/api/core/data_loader/loader/markdown.py b/api/core/data_loader/loader/markdown.py index 545c6b10ed..ecbc6d548f 100644 --- a/api/core/data_loader/loader/markdown.py +++ b/api/core/data_loader/loader/markdown.py @@ -1,6 +1,6 @@ import logging import re -from typing import List, Optional, Tuple, cast +from typing import Optional, cast from langchain.document_loaders.base import BaseLoader from langchain.document_loaders.helpers import detect_file_encodings @@ -42,7 +42,7 @@ class MarkdownLoader(BaseLoader): self._encoding = encoding self._autodetect_encoding = autodetect_encoding - def load(self) -> List[Document]: + def load(self) -> list[Document]: tups = self.parse_tups(self._file_path) documents = [] for header, value in tups: @@ -54,13 +54,13 @@ class MarkdownLoader(BaseLoader): return documents - def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]: + def markdown_to_tups(self, markdown_text: str) -> list[tuple[Optional[str], str]]: """Convert a markdown file to a dictionary. The keys are the headers and the values are the text under each header. """ - markdown_tups: List[Tuple[Optional[str], str]] = [] + markdown_tups: list[tuple[Optional[str], str]] = [] lines = markdown_text.split("\n") current_header = None @@ -103,11 +103,11 @@ class MarkdownLoader(BaseLoader): content = re.sub(pattern, r"\1", content) return content - def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]: + def parse_tups(self, filepath: str) -> list[tuple[Optional[str], str]]: """Parse file into tuples.""" content = "" try: - with open(filepath, "r", encoding=self._encoding) as f: + with open(filepath, encoding=self._encoding) as f: content = f.read() except UnicodeDecodeError as e: if self._autodetect_encoding: diff --git a/api/core/data_loader/loader/notion.py b/api/core/data_loader/loader/notion.py index 9f9198c3ce..f8d8837683 100644 --- a/api/core/data_loader/loader/notion.py +++ b/api/core/data_loader/loader/notion.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional import requests from flask import current_app @@ -67,7 +67,7 @@ class NotionLoader(BaseLoader): document_model=document_model ) - def load(self) -> List[Document]: + def load(self) -> list[Document]: self.update_last_edited_time( self._document_model ) @@ -78,7 +78,7 @@ class NotionLoader(BaseLoader): def _load_data_as_documents( self, notion_obj_id: str, notion_page_type: str - ) -> List[Document]: + ) -> list[Document]: docs = [] if notion_page_type == 'database': # get all the pages in the database @@ -94,8 +94,8 @@ class NotionLoader(BaseLoader): return docs def _get_notion_database_data( - self, database_id: str, query_dict: Dict[str, Any] = {} - ) -> List[Document]: + self, database_id: str, query_dict: dict[str, Any] = {} + ) -> list[Document]: """Get all the pages from a Notion database.""" res = requests.post( DATABASE_URL_TMPL.format(database_id=database_id), @@ -149,12 +149,12 @@ class NotionLoader(BaseLoader): return database_content_list - def _get_notion_block_data(self, page_id: str) -> List[str]: + def _get_notion_block_data(self, page_id: str) -> list[str]: result_lines_arr = [] cur_block_id = page_id while True: block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) - query_dict: Dict[str, Any] = {} + query_dict: dict[str, Any] = {} res = requests.request( "GET", @@ -216,7 +216,7 @@ class NotionLoader(BaseLoader): cur_block_id = block_id while True: block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) - query_dict: Dict[str, Any] = {} + query_dict: dict[str, Any] = {} res = requests.request( "GET", @@ -280,7 +280,7 @@ class NotionLoader(BaseLoader): cur_block_id = block_id while not done: block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) - query_dict: Dict[str, Any] = {} + query_dict: dict[str, Any] = {} res = requests.request( "GET", @@ -346,7 +346,7 @@ class NotionLoader(BaseLoader): else: retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id) - query_dict: Dict[str, Any] = {} + query_dict: dict[str, Any] = {} res = requests.request( "GET", diff --git a/api/core/data_loader/loader/pdf.py b/api/core/data_loader/loader/pdf.py index 881d0026b5..a3452b367b 100644 --- a/api/core/data_loader/loader/pdf.py +++ b/api/core/data_loader/loader/pdf.py @@ -1,5 +1,5 @@ import logging -from typing import List, Optional +from typing import Optional from langchain.document_loaders import PyPDFium2Loader from langchain.document_loaders.base import BaseLoader @@ -28,7 +28,7 @@ class PdfLoader(BaseLoader): self._file_path = file_path self._upload_file = upload_file - def load(self) -> List[Document]: + def load(self) -> list[Document]: plaintext_file_key = '' plaintext_file_exists = False if self._upload_file: diff --git a/api/core/data_loader/loader/unstructured/unstructured_eml.py b/api/core/data_loader/loader/unstructured/unstructured_eml.py index 26e0ce8cda..2fa3aac133 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_eml.py +++ b/api/core/data_loader/loader/unstructured/unstructured_eml.py @@ -1,6 +1,5 @@ import base64 import logging -from typing import List from bs4 import BeautifulSoup from langchain.document_loaders.base import BaseLoader @@ -24,7 +23,7 @@ class UnstructuredEmailLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.email import partition_email elements = partition_email(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_markdown.py b/api/core/data_loader/loader/unstructured/unstructured_markdown.py index cf6e7c9c8a..036a2afd25 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_markdown.py +++ b/api/core/data_loader/loader/unstructured/unstructured_markdown.py @@ -1,5 +1,4 @@ import logging -from typing import List from langchain.document_loaders.base import BaseLoader from langchain.schema import Document @@ -34,7 +33,7 @@ class UnstructuredMarkdownLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.md import partition_md elements = partition_md(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_msg.py b/api/core/data_loader/loader/unstructured/unstructured_msg.py index 5a9813237e..495be328ed 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_msg.py +++ b/api/core/data_loader/loader/unstructured/unstructured_msg.py @@ -1,5 +1,4 @@ import logging -from typing import List from langchain.document_loaders.base import BaseLoader from langchain.schema import Document @@ -24,7 +23,7 @@ class UnstructuredMsgLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.msg import partition_msg elements = partition_msg(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_ppt.py b/api/core/data_loader/loader/unstructured/unstructured_ppt.py index 9b1e6b5abf..cfac91cc7b 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_ppt.py +++ b/api/core/data_loader/loader/unstructured/unstructured_ppt.py @@ -1,5 +1,4 @@ import logging -from typing import List from langchain.document_loaders.base import BaseLoader from langchain.schema import Document @@ -23,7 +22,7 @@ class UnstructuredPPTLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.ppt import partition_ppt elements = partition_ppt(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_pptx.py b/api/core/data_loader/loader/unstructured/unstructured_pptx.py index 0eecee9ffe..41e3bfcb54 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_pptx.py +++ b/api/core/data_loader/loader/unstructured/unstructured_pptx.py @@ -1,5 +1,4 @@ import logging -from typing import List from langchain.document_loaders.base import BaseLoader from langchain.schema import Document @@ -22,7 +21,7 @@ class UnstructuredPPTXLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.pptx import partition_pptx elements = partition_pptx(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_text.py b/api/core/data_loader/loader/unstructured/unstructured_text.py index dd684b37f2..09d14fdb17 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_text.py +++ b/api/core/data_loader/loader/unstructured/unstructured_text.py @@ -1,5 +1,4 @@ import logging -from typing import List from langchain.document_loaders.base import BaseLoader from langchain.schema import Document @@ -24,7 +23,7 @@ class UnstructuredTextLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.text import partition_text elements = partition_text(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_xml.py b/api/core/data_loader/loader/unstructured/unstructured_xml.py index 0ddbb74b9c..cca6e1b0b7 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_xml.py +++ b/api/core/data_loader/loader/unstructured/unstructured_xml.py @@ -1,5 +1,4 @@ import logging -from typing import List from langchain.document_loaders.base import BaseLoader from langchain.schema import Document @@ -24,7 +23,7 @@ class UnstructuredXmlLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.xml import partition_xml elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url) diff --git a/api/core/docstore/dataset_docstore.py b/api/core/docstore/dataset_docstore.py index 77a5dde9ed..556b3aceda 100644 --- a/api/core/docstore/dataset_docstore.py +++ b/api/core/docstore/dataset_docstore.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, Optional, Sequence, cast +from collections.abc import Sequence +from typing import Any, Optional, cast from langchain.schema import Document from sqlalchemy import func @@ -22,10 +23,10 @@ class DatasetDocumentStore: self._document_id = document_id @classmethod - def from_dict(cls, config_dict: Dict[str, Any]) -> "DatasetDocumentStore": + def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore": return cls(**config_dict) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Serialize to dict.""" return { "dataset_id": self._dataset.id, @@ -40,7 +41,7 @@ class DatasetDocumentStore: return self._user_id @property - def docs(self) -> Dict[str, Document]: + def docs(self) -> dict[str, Document]: document_segments = db.session.query(DocumentSegment).filter( DocumentSegment.dataset_id == self._dataset.id ).all() diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index 4f7b3a1530..a86afd817a 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -1,6 +1,6 @@ import base64 import logging -from typing import List, Optional, cast +from typing import Optional, cast import numpy as np from langchain.embeddings.base import Embeddings @@ -21,7 +21,7 @@ class CacheEmbedding(Embeddings): self._model_instance = model_instance self._user = user - def embed_documents(self, texts: List[str]) -> List[List[float]]: + def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs in batches of 10.""" text_embeddings = [] try: @@ -52,7 +52,7 @@ class CacheEmbedding(Embeddings): return text_embeddings - def embed_query(self, text: str) -> List[float]: + def embed_query(self, text: str) -> list[float]: """Embed query text.""" # use doc embedding cache or store if not exists hash = helper.generate_text_hash(text) diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index fd61647635..b83ae0c8e7 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1,8 +1,9 @@ import datetime import json import logging +from collections.abc import Iterator from json import JSONDecodeError -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Optional from pydantic import BaseModel @@ -135,7 +136,7 @@ class ProviderConfiguration(BaseModel): if self.provider.provider_credential_schema else [] ) - def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]: + def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]: """ Validate custom credentials. :param credentials: provider credentials @@ -282,7 +283,7 @@ class ProviderConfiguration(BaseModel): return None def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \ - -> Tuple[ProviderModel, dict]: + -> tuple[ProviderModel, dict]: """ Validate custom model credentials. @@ -711,7 +712,7 @@ class ProviderConfigurations(BaseModel): Model class for provider configuration dict. """ tenant_id: str - configurations: Dict[str, ProviderConfiguration] = {} + configurations: dict[str, ProviderConfiguration] = {} def __init__(self, tenant_id: str): super().__init__(tenant_id=tenant_id) @@ -759,7 +760,7 @@ class ProviderConfigurations(BaseModel): return all_models - def to_list(self) -> List[ProviderConfiguration]: + def to_list(self) -> list[ProviderConfiguration]: """ Convert to list. diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 6b27062f13..c19aaefe9e 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -61,7 +61,7 @@ class Extensible: builtin_file_path = os.path.join(subdir_path, '__builtin__') if os.path.exists(builtin_file_path): - with open(builtin_file_path, 'r', encoding='utf-8') as f: + with open(builtin_file_path, encoding='utf-8') as f: position = int(f.read().strip()) if (extension_name + '.py') not in file_names: @@ -93,7 +93,7 @@ class Extensible: json_path = os.path.join(subdir_path, 'schema.json') json_data = {} if os.path.exists(json_path): - with open(json_path, 'r', encoding='utf-8') as f: + with open(json_path, encoding='utf-8') as f: json_data = json.load(f) extensions[extension_name] = ModuleExtension( diff --git a/api/core/features/assistant_base_runner.py b/api/core/features/assistant_base_runner.py index 4c0bde989a..c62028eaf0 100644 --- a/api/core/features/assistant_base_runner.py +++ b/api/core/features/assistant_base_runner.py @@ -2,7 +2,7 @@ import json import logging from datetime import datetime from mimetypes import guess_extension -from typing import List, Optional, Tuple, Union, cast +from typing import Optional, Union, cast from core.app_runner.app_runner import AppRunner from core.application_queue_manager import ApplicationQueueManager @@ -50,7 +50,7 @@ class BaseAssistantApplicationRunner(AppRunner): message: Message, user_id: str, memory: Optional[TokenBufferMemory] = None, - prompt_messages: Optional[List[PromptMessage]] = None, + prompt_messages: Optional[list[PromptMessage]] = None, variables_pool: Optional[ToolRuntimeVariablePool] = None, db_variables: Optional[ToolConversationVariables] = None, model_instance: ModelInstance = None @@ -122,7 +122,7 @@ class BaseAssistantApplicationRunner(AppRunner): return app_orchestration_config - def _convert_tool_response_to_str(self, tool_response: List[ToolInvokeMessage]) -> str: + def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str: """ Handle tool response """ @@ -140,7 +140,7 @@ class BaseAssistantApplicationRunner(AppRunner): return result - def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> Tuple[PromptMessageTool, Tool]: + def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]: """ convert tool to prompt message tool """ @@ -325,7 +325,7 @@ class BaseAssistantApplicationRunner(AppRunner): return prompt_tool - def extract_tool_response_binary(self, tool_response: List[ToolInvokeMessage]) -> List[ToolInvokeMessageBinary]: + def extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]: """ Extract tool response binary """ @@ -356,7 +356,7 @@ class BaseAssistantApplicationRunner(AppRunner): return result - def create_message_files(self, messages: List[ToolInvokeMessageBinary]) -> List[Tuple[MessageFile, bool]]: + def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[tuple[MessageFile, bool]]: """ Create message file @@ -404,7 +404,7 @@ class BaseAssistantApplicationRunner(AppRunner): return result def create_agent_thought(self, message_id: str, message: str, - tool_name: str, tool_input: str, messages_ids: List[str] + tool_name: str, tool_input: str, messages_ids: list[str] ) -> MessageAgentThought: """ Create agent thought @@ -449,7 +449,7 @@ class BaseAssistantApplicationRunner(AppRunner): thought: str, observation: str, answer: str, - messages_ids: List[str], + messages_ids: list[str], llm_usage: LLMUsage = None) -> MessageAgentThought: """ Save agent thought @@ -505,7 +505,7 @@ class BaseAssistantApplicationRunner(AppRunner): db.session.commit() - def get_history_prompt_messages(self) -> List[PromptMessage]: + def get_history_prompt_messages(self) -> list[PromptMessage]: """ Get history prompt messages """ @@ -516,7 +516,7 @@ class BaseAssistantApplicationRunner(AppRunner): return self.history_prompt_messages - def transform_tool_invoke_messages(self, messages: List[ToolInvokeMessage]) -> List[ToolInvokeMessage]: + def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]: """ Transform tool message into agent thought """ diff --git a/api/core/features/assistant_cot_runner.py b/api/core/features/assistant_cot_runner.py index 5464069838..b8d08bb5d3 100644 --- a/api/core/features/assistant_cot_runner.py +++ b/api/core/features/assistant_cot_runner.py @@ -1,6 +1,7 @@ import json import re -from typing import Dict, Generator, List, Literal, Union +from collections.abc import Generator +from typing import Literal, Union from core.application_queue_manager import PublishFrom from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit @@ -29,7 +30,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): def run(self, conversation: Conversation, message: Message, query: str, - inputs: Dict[str, str], + inputs: dict[str, str], ) -> Union[Generator, LLMResult]: """ Run Cot agent application @@ -37,7 +38,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): app_orchestration_config = self.app_orchestration_config self._repack_app_orchestration_config(app_orchestration_config) - agent_scratchpad: List[AgentScratchpadUnit] = [] + agent_scratchpad: list[AgentScratchpadUnit] = [] # check model mode if self.app_orchestration_config.model_config.mode == "completion": @@ -56,7 +57,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): prompt_messages = self.history_prompt_messages # convert tools into ModelRuntime Tool format - prompt_messages_tools: List[PromptMessageTool] = [] + prompt_messages_tools: list[PromptMessageTool] = [] tool_instances = {} for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []: try: @@ -83,7 +84,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): } final_answer = '' - def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage): + def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): if not final_llm_usage_dict['usage']: final_llm_usage_dict['usage'] = usage else: @@ -493,7 +494,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): if not next_iteration.find("{{observation}}") >= 0: raise ValueError("{{observation}} is required in next_iteration") - def _convert_scratchpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str: + def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str: """ convert agent scratchpad list to str """ @@ -506,13 +507,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): return result def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"], - prompt_messages: List[PromptMessage], - tools: List[PromptMessageTool], - agent_scratchpad: List[AgentScratchpadUnit], + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + agent_scratchpad: list[AgentScratchpadUnit], agent_prompt_message: AgentPromptEntity, instruction: str, input: str, - ) -> List[PromptMessage]: + ) -> list[PromptMessage]: """ organize chain of thought prompt messages, a standard prompt message is like: Respond to the human as helpfully and accurately as possible. diff --git a/api/core/features/assistant_fc_runner.py b/api/core/features/assistant_fc_runner.py index b0e3d3a7af..7ad9d7bd2a 100644 --- a/api/core/features/assistant_fc_runner.py +++ b/api/core/features/assistant_fc_runner.py @@ -1,6 +1,7 @@ import json import logging -from typing import Any, Dict, Generator, List, Tuple, Union +from collections.abc import Generator +from typing import Any, Union from core.application_queue_manager import PublishFrom from core.features.assistant_base_runner import BaseAssistantApplicationRunner @@ -44,7 +45,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): ) # convert tools into ModelRuntime Tool format - prompt_messages_tools: List[PromptMessageTool] = [] + prompt_messages_tools: list[PromptMessageTool] = [] tool_instances = {} for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []: try: @@ -70,13 +71,13 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): # continue to run until there is not any tool call function_call_state = True - agent_thoughts: List[MessageAgentThought] = [] + agent_thoughts: list[MessageAgentThought] = [] llm_usage = { 'usage': None } final_answer = '' - def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage): + def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): if not final_llm_usage_dict['usage']: final_llm_usage_dict['usage'] = usage else: @@ -117,7 +118,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): callbacks=[], ) - tool_calls: List[Tuple[str, str, Dict[str, Any]]] = [] + tool_calls: list[tuple[str, str, dict[str, Any]]] = [] # save full response response = '' @@ -364,7 +365,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): return True return False - def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]: + def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: """ Extract tool calls from llm result chunk @@ -381,7 +382,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): return tool_calls - def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]: + def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: """ Extract blocking tool calls from llm result diff --git a/api/core/features/dataset_retrieval.py b/api/core/features/dataset_retrieval.py index 159428aad4..488a8ca8d0 100644 --- a/api/core/features/dataset_retrieval.py +++ b/api/core/features/dataset_retrieval.py @@ -1,4 +1,4 @@ -from typing import List, Optional, cast +from typing import Optional, cast from langchain.tools import BaseTool @@ -96,7 +96,7 @@ class DatasetRetrievalFeature: return_resource: bool, invoke_from: InvokeFrom, hit_callback: DatasetIndexToolCallbackHandler) \ - -> Optional[List[BaseTool]]: + -> Optional[list[BaseTool]]: """ A dataset tool is a tool that can be used to retrieve information from a dataset :param tenant_id: tenant id diff --git a/api/core/features/external_data_fetch.py b/api/core/features/external_data_fetch.py index 33154d8389..7f23c8ed72 100644 --- a/api/core/features/external_data_fetch.py +++ b/api/core/features/external_data_fetch.py @@ -2,7 +2,7 @@ import concurrent import json import logging from concurrent.futures import ThreadPoolExecutor -from typing import Optional, Tuple +from typing import Optional from flask import Flask, current_app @@ -62,7 +62,7 @@ class ExternalDataFetchFeature: app_id: str, external_data_tool: ExternalDataVariableEntity, inputs: dict, - query: str) -> Tuple[Optional[str], Optional[str]]: + query: str) -> tuple[Optional[str], Optional[str]]: """ Query external data tool. :param flask_app: flask app diff --git a/api/core/features/moderation.py b/api/core/features/moderation.py index 9735fad0e7..a9d65f56e8 100644 --- a/api/core/features/moderation.py +++ b/api/core/features/moderation.py @@ -1,5 +1,4 @@ import logging -from typing import Tuple from core.entities.application_entities import AppOrchestrationConfigEntity from core.moderation.base import ModerationAction, ModerationException @@ -13,7 +12,7 @@ class ModerationFeature: tenant_id: str, app_orchestration_config_entity: AppOrchestrationConfigEntity, inputs: dict, - query: str) -> Tuple[bool, dict, str]: + query: str) -> tuple[bool, dict, str]: """ Process sensitive_word_avoidance. :param app_id: app id diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index ce783d8fbb..1b7b8b87da 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Union +from typing import Optional, Union import requests @@ -15,8 +15,8 @@ class MessageFileParser: self.tenant_id = tenant_id self.app_id = app_id - def validate_and_transform_files_arg(self, files: List[dict], app_model_config: AppModelConfig, - user: Union[Account, EndUser]) -> List[FileObj]: + def validate_and_transform_files_arg(self, files: list[dict], app_model_config: AppModelConfig, + user: Union[Account, EndUser]) -> list[FileObj]: """ validate and transform files arg @@ -96,7 +96,7 @@ class MessageFileParser: # return all file objs return new_files - def transform_message_files(self, files: List[MessageFile], app_model_config: Optional[AppModelConfig]) -> List[FileObj]: + def transform_message_files(self, files: list[MessageFile], app_model_config: Optional[AppModelConfig]) -> list[FileObj]: """ transform message files @@ -110,8 +110,8 @@ class MessageFileParser: # return all file objs return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] - def _to_file_objs(self, files: List[Union[Dict, MessageFile]], - file_upload_config: dict) -> Dict[FileType, List[FileObj]]: + def _to_file_objs(self, files: list[Union[dict, MessageFile]], + file_upload_config: dict) -> dict[FileType, list[FileObj]]: """ transform files to file objs @@ -119,7 +119,7 @@ class MessageFileParser: :param file_upload_config: :return: """ - type_file_objs: Dict[FileType, List[FileObj]] = { + type_file_objs: dict[FileType, list[FileObj]] = { # Currently only support image FileType.IMAGE: [] } diff --git a/api/core/index/base.py b/api/core/index/base.py index 1dc7cfdcc6..f8eb1a134a 100644 --- a/api/core/index/base.py +++ b/api/core/index/base.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, List +from typing import Any from langchain.schema import BaseRetriever, Document @@ -53,7 +53,7 @@ class BaseIndex(ABC): def search( self, query: str, **kwargs: Any - ) -> List[Document]: + ) -> list[Document]: raise NotImplementedError def delete(self) -> None: diff --git a/api/core/index/keyword_table_index/jieba_keyword_table_handler.py b/api/core/index/keyword_table_index/jieba_keyword_table_handler.py index db9fd027a0..df93a1903a 100644 --- a/api/core/index/keyword_table_index/jieba_keyword_table_handler.py +++ b/api/core/index/keyword_table_index/jieba_keyword_table_handler.py @@ -1,5 +1,4 @@ import re -from typing import Set import jieba from jieba.analyse import default_tfidf @@ -12,7 +11,7 @@ class JiebaKeywordTableHandler: def __init__(self): default_tfidf.stop_words = STOPWORDS - def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]: + def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> set[str]: """Extract keywords with JIEBA tfidf.""" keywords = jieba.analyse.extract_tags( sentence=text, @@ -21,7 +20,7 @@ class JiebaKeywordTableHandler: return set(self._expand_tokens_with_subtokens(keywords)) - def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]: + def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]: """Get subtokens from a list of tokens., filtering for stopwords.""" results = set() for token in tokens: diff --git a/api/core/index/keyword_table_index/keyword_table_index.py b/api/core/index/keyword_table_index/keyword_table_index.py index 9ad8b8d64e..8bf0b13344 100644 --- a/api/core/index/keyword_table_index/keyword_table_index.py +++ b/api/core/index/keyword_table_index/keyword_table_index.py @@ -1,6 +1,6 @@ import json from collections import defaultdict -from typing import Any, Dict, List, Optional +from typing import Any, Optional from langchain.schema import BaseRetriever, Document from pydantic import BaseModel, Extra, Field @@ -116,7 +116,7 @@ class KeywordTableIndex(BaseIndex): def search( self, query: str, **kwargs: Any - ) -> List[Document]: + ) -> list[Document]: keyword_table = self._get_dataset_keyword_table() search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {} @@ -221,7 +221,7 @@ class KeywordTableIndex(BaseIndex): keywords = keyword_table_handler.extract_keywords(query) # go through text chunks in order of most matching keywords - chunk_indices_count: Dict[str, int] = defaultdict(int) + chunk_indices_count: dict[str, int] = defaultdict(int) keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())] for keyword in keywords: for node_id in keyword_table[keyword]: @@ -235,7 +235,7 @@ class KeywordTableIndex(BaseIndex): return sorted_chunk_indices[: k] - def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: List[str]): + def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): document_segment = db.session.query(DocumentSegment).filter( DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id @@ -244,7 +244,7 @@ class KeywordTableIndex(BaseIndex): document_segment.keywords = keywords db.session.commit() - def create_segment_keywords(self, node_id: str, keywords: List[str]): + def create_segment_keywords(self, node_id: str, keywords: list[str]): keyword_table = self._get_dataset_keyword_table() self._update_segment_keywords(self.dataset.id, node_id, keywords) keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) @@ -266,7 +266,7 @@ class KeywordTableIndex(BaseIndex): keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords)) self._save_dataset_keyword_table(keyword_table) - def update_segment_keywords_index(self, node_id: str, keywords: List[str]): + def update_segment_keywords_index(self, node_id: str, keywords: list[str]): keyword_table = self._get_dataset_keyword_table() keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) self._save_dataset_keyword_table(keyword_table) @@ -282,7 +282,7 @@ class KeywordTableRetriever(BaseRetriever, BaseModel): extra = Extra.forbid arbitrary_types_allowed = True - def get_relevant_documents(self, query: str) -> List[Document]: + def get_relevant_documents(self, query: str) -> list[Document]: """Get documents relevant for a query. Args: @@ -293,7 +293,7 @@ class KeywordTableRetriever(BaseRetriever, BaseModel): """ return self.index.search(query, **self.search_kwargs) - async def aget_relevant_documents(self, query: str) -> List[Document]: + async def aget_relevant_documents(self, query: str) -> list[Document]: raise NotImplementedError("KeywordTableRetriever does not support async") diff --git a/api/core/index/vector_index/base.py b/api/core/index/vector_index/base.py index b9b8e6d3dc..36aa1917a6 100644 --- a/api/core/index/vector_index/base.py +++ b/api/core/index/vector_index/base.py @@ -1,7 +1,7 @@ import json import logging from abc import abstractmethod -from typing import Any, List, cast +from typing import Any, cast from langchain.embeddings.base import Embeddings from langchain.schema import BaseRetriever, Document @@ -43,13 +43,13 @@ class BaseVectorIndex(BaseIndex): def search_by_full_text_index( self, query: str, **kwargs: Any - ) -> List[Document]: + ) -> list[Document]: raise NotImplementedError def search( self, query: str, **kwargs: Any - ) -> List[Document]: + ) -> list[Document]: vector_store = self._get_vector_store() vector_store = cast(self._get_vector_store_class(), vector_store) diff --git a/api/core/index/vector_index/milvus_vector_index.py b/api/core/index/vector_index/milvus_vector_index.py index a0b6f5d207..a18cf35a27 100644 --- a/api/core/index/vector_index/milvus_vector_index.py +++ b/api/core/index/vector_index/milvus_vector_index.py @@ -1,4 +1,4 @@ -from typing import Any, List, cast +from typing import Any, cast from langchain.embeddings.base import Embeddings from langchain.schema import Document @@ -160,6 +160,6 @@ class MilvusVectorIndex(BaseVectorIndex): ], )) - def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]: + def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]: # milvus/zilliz doesn't support bm25 search return [] diff --git a/api/core/index/vector_index/qdrant_vector_index.py b/api/core/index/vector_index/qdrant_vector_index.py index f182c4c0e1..046260d2f8 100644 --- a/api/core/index/vector_index/qdrant_vector_index.py +++ b/api/core/index/vector_index/qdrant_vector_index.py @@ -1,5 +1,5 @@ import os -from typing import Any, List, Optional, cast +from typing import Any, Optional, cast import qdrant_client from langchain.embeddings.base import Embeddings @@ -210,7 +210,7 @@ class QdrantVectorIndex(BaseVectorIndex): return False - def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]: + def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]: vector_store = self._get_vector_store() vector_store = cast(self._get_vector_store_class(), vector_store) diff --git a/api/core/index/vector_index/weaviate_vector_index.py b/api/core/index/vector_index/weaviate_vector_index.py index 8af3c5926b..72a74a039f 100644 --- a/api/core/index/vector_index/weaviate_vector_index.py +++ b/api/core/index/vector_index/weaviate_vector_index.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, cast +from typing import Any, Optional, cast import requests import weaviate @@ -172,7 +172,7 @@ class WeaviateVectorIndex(BaseVectorIndex): return False - def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]: + def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]: vector_store = self._get_vector_store() vector_store = cast(self._get_vector_store_class(), vector_store) return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 1f36362a8b..a14001d04e 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -5,7 +5,7 @@ import re import threading import time import uuid -from typing import List, Optional, cast +from typing import Optional, cast from flask import Flask, current_app from flask_login import current_user @@ -40,7 +40,7 @@ class IndexingRunner: self.storage = storage self.model_manager = ModelManager() - def run(self, dataset_documents: List[DatasetDocument]): + def run(self, dataset_documents: list[DatasetDocument]): """Run the indexing process.""" for dataset_document in dataset_documents: try: @@ -238,7 +238,7 @@ class IndexingRunner: dataset_document.stopped_at = datetime.datetime.utcnow() db.session.commit() - def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict, + def file_indexing_estimate(self, tenant_id: str, file_details: list[UploadFile], tmp_processing_rule: dict, doc_form: str = None, doc_language: str = 'English', dataset_id: str = None, indexing_technique: str = 'economy') -> dict: """ @@ -494,7 +494,7 @@ class IndexingRunner: "preview": preview_texts } - def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> List[Document]: + def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> list[Document]: # load file if dataset_document.data_source_type not in ["upload_file", "notion_import"]: return [] @@ -526,7 +526,7 @@ class IndexingRunner: ) # replace doc id to document model id - text_docs = cast(List[Document], text_docs) + text_docs = cast(list[Document], text_docs) for text_doc in text_docs: # remove invalid symbol text_doc.page_content = self.filter_string(text_doc.page_content) @@ -540,7 +540,7 @@ class IndexingRunner: text = re.sub(r'\|>', '>', text) text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) # Unicode U+FFFE - text = re.sub(u'\uFFFE', '', text) + text = re.sub('\uFFFE', '', text) return text def _get_splitter(self, processing_rule: DatasetProcessRule, @@ -577,9 +577,9 @@ class IndexingRunner: return character_splitter - def _step_split(self, text_docs: List[Document], splitter: TextSplitter, + def _step_split(self, text_docs: list[Document], splitter: TextSplitter, dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \ - -> List[Document]: + -> list[Document]: """ Split the text documents into documents and save them to the document segment. """ @@ -624,9 +624,9 @@ class IndexingRunner: return documents - def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter, + def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule, tenant_id: str, - document_form: str, document_language: str) -> List[Document]: + document_form: str, document_language: str) -> list[Document]: """ Split the text documents into nodes. """ @@ -699,8 +699,8 @@ class IndexingRunner: all_qa_documents.extend(format_documents) - def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter, - processing_rule: DatasetProcessRule) -> List[Document]: + def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: TextSplitter, + processing_rule: DatasetProcessRule) -> list[Document]: """ Split the text documents into nodes. """ @@ -770,7 +770,7 @@ class IndexingRunner: for q, a in matches if q and a ] - def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None: + def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None: """ Build the index for the document. """ @@ -877,7 +877,7 @@ class IndexingRunner: DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) db.session.commit() - def batch_add_segments(self, segments: List[DocumentSegment], dataset: Dataset): + def batch_add_segments(self, segments: list[DocumentSegment], dataset: Dataset): """ Batch add segments index processing """ diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 68df0ac31a..8e36ab7ee8 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,4 +1,5 @@ -from typing import IO, Generator, List, Optional, Union, cast +from collections.abc import Generator +from typing import IO, Optional, Union, cast from core.entities.provider_configuration import ProviderModelBundle from core.errors.error import ProviderTokenNotInitError @@ -47,7 +48,7 @@ class ModelInstance: return credentials def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ -> Union[LLMResult, Generator]: """ diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py index 58150ef4da..51af9786fd 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/core/model_runtime/callbacks/base_callback.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import List, Optional +from typing import Optional from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool @@ -23,7 +23,7 @@ class Callback(ABC): def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ Before invoke callback @@ -42,7 +42,7 @@ class Callback(ABC): def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None): """ On new chunk callback @@ -62,7 +62,7 @@ class Callback(ABC): def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ After invoke callback @@ -82,7 +82,7 @@ class Callback(ABC): def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ Invoke error callback diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py index 4864858445..0406853b88 100644 --- a/api/core/model_runtime/callbacks/logging_callback.py +++ b/api/core/model_runtime/callbacks/logging_callback.py @@ -1,7 +1,7 @@ import json import logging import sys -from typing import List, Optional +from typing import Optional from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) class LoggingCallback(Callback): def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ Before invoke callback @@ -60,7 +60,7 @@ class LoggingCallback(Callback): def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None): """ On new chunk callback @@ -81,7 +81,7 @@ class LoggingCallback(Callback): def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ After invoke callback @@ -113,7 +113,7 @@ class LoggingCallback(Callback): def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ Invoke error callback diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py index b39427dccd..856f4ce7d1 100644 --- a/api/core/model_runtime/entities/defaults.py +++ b/api/core/model_runtime/entities/defaults.py @@ -1,8 +1,7 @@ -from typing import Dict from core.model_runtime.entities.model_entities import DefaultParameterName -PARAMETER_RULE_TEMPLATE: Dict[DefaultParameterName, dict] = { +PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { DefaultParameterName.TEMPERATURE: { 'label': { 'en_US': 'Temperature', diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index eb811ab224..a9f7a539e2 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -153,7 +153,7 @@ class AIModel(ABC): # read _position.yaml file position_map = {} if os.path.exists(position_file_path): - with open(position_file_path, 'r', encoding='utf-8') as f: + with open(position_file_path, encoding='utf-8') as f: positions = yaml.safe_load(f) # convert list to dict with key as model provider name, value as index position_map = {position: index for index, position in enumerate(positions)} @@ -161,7 +161,7 @@ class AIModel(ABC): # traverse all model_schema_yaml_paths for model_schema_yaml_path in model_schema_yaml_paths: # read yaml data from yaml file - with open(model_schema_yaml_path, 'r', encoding='utf-8') as f: + with open(model_schema_yaml_path, encoding='utf-8') as f: yaml_data = yaml.safe_load(f) new_parameter_rules = [] diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 173b4dcab7..1f7edd245f 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -3,7 +3,8 @@ import os import re import time from abc import abstractmethod -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.callbacks.logging_callback import LoggingCallback @@ -29,7 +30,7 @@ class LargeLanguageModel(AIModel): def invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ -> Union[LLMResult, Generator]: """ @@ -122,7 +123,7 @@ class LargeLanguageModel(AIModel): def _invoke_result_generator(self, model: str, result: Generator, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator: """ Invoke result generator @@ -186,7 +187,7 @@ class LargeLanguageModel(AIModel): @abstractmethod def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -218,7 +219,7 @@ class LargeLanguageModel(AIModel): """ raise NotImplementedError - def enforce_stop_tokens(self, text: str, stop: List[str]) -> str: + def enforce_stop_tokens(self, text: str, stop: list[str]) -> str: """Cut off the text as soon as any stop words occur.""" return re.split("|".join(stop), text, maxsplit=1)[0] @@ -329,7 +330,7 @@ class LargeLanguageModel(AIModel): def _trigger_before_invoke_callbacks(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) -> None: """ Trigger before invoke callbacks @@ -367,7 +368,7 @@ class LargeLanguageModel(AIModel): def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) -> None: """ Trigger new chunk callbacks @@ -406,7 +407,7 @@ class LargeLanguageModel(AIModel): def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) -> None: """ Trigger after invoke callbacks @@ -446,7 +447,7 @@ class LargeLanguageModel(AIModel): def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) -> None: """ Trigger invoke error callbacks @@ -527,7 +528,7 @@ class LargeLanguageModel(AIModel): raise ValueError( f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.") elif parameter_rule.type == ParameterType.FLOAT: - if not isinstance(parameter_value, (float, int)): + if not isinstance(parameter_value, float | int): raise ValueError(f"Model Parameter {parameter_name} should be float.") # validate parameter value precision diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index f3d71670f1..97ce07d35f 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -1,7 +1,6 @@ import importlib import os from abc import ABC, abstractmethod -from typing import Dict import yaml @@ -12,7 +11,7 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel class ModelProvider(ABC): provider_schema: ProviderEntity = None - model_instance_map: Dict[str, AIModel] = {} + model_instance_map: dict[str, AIModel] = {} @abstractmethod def validate_provider_credentials(self, credentials: dict) -> None: @@ -47,7 +46,7 @@ class ModelProvider(ABC): yaml_path = os.path.join(current_path, f'{provider_name}.yaml') yaml_data = {} if os.path.exists(yaml_path): - with open(yaml_path, 'r', encoding='utf-8') as f: + with open(yaml_path, encoding='utf-8') as f: yaml_data = yaml.safe_load(f) try: diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 3f689a724d..c743708896 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -1,4 +1,5 @@ -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union import anthropic from anthropic import Anthropic, Stream @@ -29,7 +30,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -90,7 +91,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -255,7 +256,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): return message_text - def _convert_messages_to_prompt_anthropic(self, messages: List[PromptMessage]) -> str: + def _convert_messages_to_prompt_anthropic(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Anthropic model diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index 1bab34edd6..4b89adaa49 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -1,6 +1,7 @@ import copy import logging -from typing import Generator, List, Optional, Union, cast +from collections.abc import Generator +from typing import Optional, Union, cast import tiktoken from openai import AzureOpenAI, Stream @@ -34,7 +35,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: @@ -121,7 +122,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return ai_model_entity.entity if ai_model_entity else None def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, + prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) @@ -239,7 +240,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): def _chat_generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) @@ -537,7 +538,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, credentials: dict, messages: List[PromptMessage], + def _num_tokens_from_messages(self, credentials: dict, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. diff --git a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py index 606a898db5..e073bef014 100644 --- a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py @@ -1,7 +1,7 @@ import base64 import copy import time -from typing import Optional, Tuple, Union +from typing import Optional, Union import numpy as np import tiktoken @@ -149,7 +149,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): @staticmethod def _embedding_invoke(model: str, client: AzureOpenAI, texts: Union[list[str], str], - extra_model_kwargs: dict) -> Tuple[list[list[float]], int]: + extra_model_kwargs: dict) -> tuple[list[list[float]], int]: response = client.embeddings.create( input=texts, model=model, diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py index 4562bb2be7..7549b2fb60 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py @@ -1,7 +1,7 @@ import re -class BaichuanTokenizer(object): +class BaichuanTokenizer: @classmethod def count_chinese_characters(cls, text: str) -> int: return len(re.findall(r'[\u4e00-\u9fa5]', text)) diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py index 46ba0cffaf..ae73c1735a 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py @@ -1,7 +1,8 @@ +from collections.abc import Generator from enum import Enum from hashlib import md5 from json import dumps, loads -from typing import Any, Dict, Generator, List, Union +from typing import Any, Union from requests import post @@ -24,10 +25,10 @@ class BaichuanMessage: role: str = Role.USER.value content: str - usage: Dict[str, int] = None + usage: dict[str, int] = None stop_reason: str = '' - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { 'role': self.role, 'content': self.content, @@ -37,7 +38,7 @@ class BaichuanMessage: self.content = content self.role = role -class BaichuanModel(object): +class BaichuanModel: api_key: str secret_key: str @@ -106,9 +107,9 @@ class BaichuanModel(object): message.stop_reason = stop_reason yield message - def _build_parameters(self, model: str, stream: bool, messages: List[BaichuanMessage], - parameters: Dict[str, Any]) \ - -> Dict[str, Any]: + def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage], + parameters: dict[str, Any]) \ + -> dict[str, Any]: if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b': prompt_messages = [] for message in messages: @@ -139,7 +140,7 @@ class BaichuanModel(object): else: raise BadRequestError(f"Unknown model: {model}") - def _build_headers(self, model: str, data: Dict[str, Any]) -> Dict[str, Any]: + def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]: if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b': # there is no secret key for turbo api return { @@ -153,8 +154,8 @@ class BaichuanModel(object): def _calculate_md5(self, input_string): return md5(input_string.encode('utf-8')).hexdigest() - def generate(self, model: str, stream: bool, messages: List[BaichuanMessage], - parameters: Dict[str, Any], timeout: int) \ + def generate(self, model: str, stream: bool, messages: list[BaichuanMessage], + parameters: dict[str, Any], timeout: int) \ -> Union[Generator, BaichuanMessage]: if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b': diff --git a/api/core/model_runtime/model_providers/baichuan/llm/llm.py b/api/core/model_runtime/model_providers/baichuan/llm/llm.py index a7c6119d10..707355fa7a 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/llm.py @@ -1,4 +1,5 @@ -from typing import Generator, List, cast +from collections.abc import Generator +from typing import cast from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -33,7 +34,7 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor class BaichuanLarguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: List[str] | None = None, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, @@ -43,7 +44,7 @@ class BaichuanLarguageModel(LargeLanguageModel): tools: list[PromptMessageTool] | None = None) -> int: return self._num_tokens_from_messages(prompt_messages) - def _num_tokens_from_messages(self, messages: List[PromptMessage],) -> int: + def _num_tokens_from_messages(self, messages: list[PromptMessage],) -> int: """Calculate num tokens for baichuan model""" def tokens(text: str): return BaichuanTokenizer._get_num_tokens(text) @@ -107,7 +108,7 @@ class BaichuanLarguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: List[str] | None = None, stream: bool = True, user: str | None = None) \ + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: if tools is not None and len(tools) > 0: raise InvokeBadRequestError("Baichuan model doesn't support tools") diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py index 5020c58996..da4ba55881 100644 --- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py @@ -1,6 +1,6 @@ import time from json import dumps -from typing import Optional, Tuple +from typing import Optional from requests import post @@ -84,7 +84,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): return result def embedding(self, model: str, api_key, texts: list[str], user: Optional[str] = None) \ - -> Tuple[list[list[float]], int]: + -> tuple[list[list[float]], int]: """ Embed given texts diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index 7a2faae895..c6aaa24ade 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -1,6 +1,7 @@ import json import logging -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union import boto3 from botocore.config import Config @@ -37,7 +38,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -159,7 +160,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): return message_text - def _convert_messages_to_prompt(self, messages: List[PromptMessage], model_prefix: str) -> str: + def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str) -> str: """ Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models @@ -181,7 +182,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() - def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, stream: bool = True): + def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True): """ Create payload for bedrock api call depending on model provider """ @@ -231,7 +232,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke large language model diff --git a/api/core/model_runtime/model_providers/chatglm/llm/llm.py b/api/core/model_runtime/model_providers/chatglm/llm/llm.py index fd2bcd5ec3..12dc75aece 100644 --- a/api/core/model_runtime/model_providers/chatglm/llm/llm.py +++ b/api/core/model_runtime/model_providers/chatglm/llm/llm.py @@ -1,6 +1,7 @@ import logging +from collections.abc import Generator from os.path import join -from typing import Generator, List, Optional, cast +from typing import Optional, cast from httpx import Timeout from openai import ( @@ -45,7 +46,7 @@ logger = logging.getLogger(__name__) class ChatGLMLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: List[str] | None = None, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: """ @@ -138,7 +139,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: List[str] | None = None, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: """ @@ -394,7 +395,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, messages: List[PromptMessage], + def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: """Calculate num tokens for chatglm2 and chatglm3 with GPT2 tokenizer. diff --git a/api/core/model_runtime/model_providers/cohere/llm/llm.py b/api/core/model_runtime/model_providers/cohere/llm/llm.py index 95d3252b11..667ba4c78c 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/llm.py +++ b/api/core/model_runtime/model_providers/cohere/llm/llm.py @@ -1,5 +1,6 @@ import logging -from typing import Generator, List, Optional, Tuple, Union, cast +from collections.abc import Generator +from typing import Optional, Union, cast import cohere from cohere.responses import Chat, Generations @@ -38,7 +39,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -138,7 +139,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): raise CredentialsValidateFailedError(str(ex)) def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, + prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke llm model @@ -264,7 +265,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): break def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, + prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -306,7 +307,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop) def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat, - prompt_messages: list[PromptMessage], stop: Optional[List[str]] = None) \ + prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None) \ -> LLMResult: """ Handle llm chat response @@ -352,7 +353,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat, prompt_messages: list[PromptMessage], - stop: Optional[List[str]] = None) -> Generator: + stop: Optional[list[str]] = None) -> Generator: """ Handle llm chat stream response @@ -427,7 +428,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): index += 1 def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \ - -> Tuple[str, list[dict]]: + -> tuple[str, list[dict]]: """ Convert prompt messages to message and chat histories :param prompt_messages: prompt messages @@ -495,7 +496,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): return response.length - def _num_tokens_from_messages(self, model: str, credentials: dict, messages: List[PromptMessage]) -> int: + def _num_tokens_from_messages(self, model: str, credentials: dict, messages: list[PromptMessage]) -> int: """Calculate num tokens Cohere model.""" messages = [self._convert_prompt_message_to_dict(m) for m in messages] message_strs = [f"{message['role']}: {message['message']}" for message in messages] diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py index fda8b27de4..5eec721841 100644 --- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py @@ -1,5 +1,5 @@ import time -from typing import Optional, Tuple +from typing import Optional import cohere import numpy as np @@ -168,7 +168,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> Tuple[list[list[float]], int]: + def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int]: """ Invoke embedding model diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index e376e72c07..686761ab5f 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -1,5 +1,6 @@ import logging -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union import google.api_core.exceptions as exceptions import google.generativeai as genai @@ -34,7 +35,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -103,7 +104,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke large language model diff --git a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py index 381d29c7e5..f43a8aedaf 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py @@ -1,4 +1,5 @@ -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union from huggingface_hub import InferenceClient from huggingface_hub.hf_api import HfApi @@ -29,7 +30,7 @@ from core.model_runtime.model_providers.huggingface_hub._common import _CommonHu class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: client = InferenceClient(token=credentials['huggingfacehub_api_token']) diff --git a/api/core/model_runtime/model_providers/localai/llm/llm.py b/api/core/model_runtime/model_providers/localai/llm/llm.py index 8d571d20b1..694f5891f9 100644 --- a/api/core/model_runtime/model_providers/localai/llm/llm.py +++ b/api/core/model_runtime/model_providers/localai/llm/llm.py @@ -1,5 +1,6 @@ +from collections.abc import Generator from os.path import join -from typing import Generator, List, cast +from typing import cast from httpx import Timeout from openai import ( @@ -52,7 +53,7 @@ from core.model_runtime.utils import helper class LocalAILarguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: List[str] | None = None, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, @@ -63,7 +64,7 @@ class LocalAILarguageModel(LargeLanguageModel): # tools is not supported yet return self._num_tokens_from_messages(prompt_messages, tools=tools) - def _num_tokens_from_messages(self, messages: List[PromptMessage], tools: list[PromptMessageTool]) -> int: + def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ Calculate num tokens for baichuan model LocalAI does not supports @@ -241,7 +242,7 @@ class LocalAILarguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: List[str] | None = None, stream: bool = True, user: str | None = None) \ + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: kwargs = self._to_client_kwargs(credentials) @@ -346,7 +347,7 @@ class LocalAILarguageModel(LargeLanguageModel): return message_dict - def _convert_prompt_message_to_completion_prompts(self, messages: List[PromptMessage]) -> str: + def _convert_prompt_message_to_completion_prompts(self, messages: list[PromptMessage]) -> str: """ Convert PromptMessage to completion prompts """ diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py index ee73005bd7..6c41e0d2a5 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py @@ -1,5 +1,6 @@ +from collections.abc import Generator from json import dumps, loads -from typing import Any, Dict, Generator, List, Union +from typing import Any, Union from requests import Response, post @@ -14,13 +15,13 @@ from core.model_runtime.model_providers.minimax.llm.errors import ( from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage -class MinimaxChatCompletion(object): +class MinimaxChatCompletion: """ Minimax Chat Completion API """ def generate(self, model: str, api_key: str, group_id: str, - prompt_messages: List[MinimaxMessage], model_parameters: dict, - tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \ + prompt_messages: list[MinimaxMessage], model_parameters: dict, + tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \ -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: """ generate chat completion diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py index 2497a9d7b8..81ea2e165e 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py @@ -1,5 +1,6 @@ +from collections.abc import Generator from json import dumps, loads -from typing import Any, Dict, Generator, List, Union +from typing import Any, Union from requests import Response, post @@ -14,14 +15,14 @@ from core.model_runtime.model_providers.minimax.llm.errors import ( from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage -class MinimaxChatCompletionPro(object): +class MinimaxChatCompletionPro: """ Minimax Chat Completion Pro API, supports function calling however, we do not have enough time and energy to implement it, but the parameters are reserved """ def generate(self, model: str, api_key: str, group_id: str, - prompt_messages: List[MinimaxMessage], model_parameters: dict, - tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \ + prompt_messages: list[MinimaxMessage], model_parameters: dict, + tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \ -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: """ generate chat completion diff --git a/api/core/model_runtime/model_providers/minimax/llm/llm.py b/api/core/model_runtime/model_providers/minimax/llm/llm.py index bc65e756eb..cc88d15736 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/llm.py +++ b/api/core/model_runtime/model_providers/minimax/llm/llm.py @@ -1,4 +1,4 @@ -from typing import Generator, List +from collections.abc import Generator from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -42,7 +42,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: List[str] | None = None, stream: bool = True, user: str | None = None) \ + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -79,7 +79,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): tools: list[PromptMessageTool] | None = None) -> int: return self._num_tokens_from_messages(prompt_messages, tools) - def _num_tokens_from_messages(self, messages: List[PromptMessage], tools: list[PromptMessageTool]) -> int: + def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ Calculate num tokens for minimax model @@ -94,7 +94,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: List[str] | None = None, stream: bool = True, user: str | None = None) \ + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: """ use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface diff --git a/api/core/model_runtime/model_providers/minimax/llm/types.py b/api/core/model_runtime/model_providers/minimax/llm/types.py index 6229312445..b33a7ca9ac 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/types.py +++ b/api/core/model_runtime/model_providers/minimax/llm/types.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Dict +from typing import Any class MinimaxMessage: @@ -11,11 +11,11 @@ class MinimaxMessage: role: str = Role.USER.value content: str - usage: Dict[str, int] = None + usage: dict[str, int] = None stop_reason: str = '' - function_call: Dict[str, Any] = None + function_call: dict[str, Any] = None - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value: return { 'sender_type': 'BOT', diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index e1e74ea806..185ff62711 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -220,7 +220,7 @@ class ModelProviderFactory: # read _position.yaml file position_map = {} if os.path.exists(position_file_path): - with open(position_file_path, 'r', encoding='utf-8') as f: + with open(position_file_path, encoding='utf-8') as f: positions = yaml.safe_load(f) # convert list to dict with key as model provider name, value as index position_map = {position: index for index, position in enumerate(positions)} diff --git a/api/core/model_runtime/model_providers/moonshot/llm/llm.py b/api/core/model_runtime/model_providers/moonshot/llm/llm.py index 40618b7fb4..5db3e2827b 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/llm.py +++ b/api/core/model_runtime/model_providers/moonshot/llm/llm.py @@ -1,4 +1,5 @@ -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool @@ -8,7 +9,7 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) diff --git a/api/core/model_runtime/model_providers/ollama/llm/llm.py b/api/core/model_runtime/model_providers/ollama/llm/llm.py index 848ac76d33..e4388699e3 100644 --- a/api/core/model_runtime/model_providers/ollama/llm/llm.py +++ b/api/core/model_runtime/model_providers/ollama/llm/llm.py @@ -1,8 +1,9 @@ import json import logging import re +from collections.abc import Generator from decimal import Decimal -from typing import Generator, List, Optional, Union, cast +from typing import Optional, Union, cast from urllib.parse import urljoin import requests @@ -51,7 +52,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -131,7 +132,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, + prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -398,7 +399,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): return message_dict - def _num_tokens_from_messages(self, messages: List[PromptMessage]) -> int: + def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int: """ Calculate num tokens. diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 92a370e047..2a1137d443 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -1,5 +1,6 @@ import logging -from typing import Generator, List, Optional, Union, cast +from collections.abc import Generator +from typing import Optional, Union, cast import tiktoken from openai import OpenAI, Stream @@ -35,7 +36,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -215,7 +216,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return ai_model_entities def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, + prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -366,7 +367,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): def _chat_generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -706,7 +707,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, model: str, messages: List[PromptMessage], + def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. diff --git a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py index 28ab5c30ff..e23a2edf87 100644 --- a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py @@ -1,6 +1,6 @@ import base64 import time -from typing import Optional, Tuple, Union +from typing import Optional, Union import numpy as np import tiktoken @@ -162,7 +162,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): raise CredentialsValidateFailedError(str(ex)) def _embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], str], - extra_model_kwargs: dict) -> Tuple[list[list[float]], int]: + extra_model_kwargs: dict) -> tuple[list[list[float]], int]: """ Invoke embedding model diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index 9a26f3dc08..ae856c5ce9 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -1,7 +1,8 @@ import json import logging +from collections.abc import Generator from decimal import Decimal -from typing import Generator, List, Optional, Union, cast +from typing import Optional, Union, cast from urllib.parse import urljoin import requests @@ -46,7 +47,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -245,7 +246,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard. def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, \ user: Optional[str] = None) -> Union[LLMResult, Generator]: """ @@ -567,7 +568,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, model: str, messages: List[PromptMessage], + def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: """ Approximate num tokens with GPT2 tokenizer. diff --git a/api/core/model_runtime/model_providers/openllm/llm/llm.py b/api/core/model_runtime/model_providers/openllm/llm/llm.py index 3491f107ab..8ea5819bde 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/llm.py +++ b/api/core/model_runtime/model_providers/openllm/llm/llm.py @@ -1,4 +1,4 @@ -from typing import Generator, List +from collections.abc import Generator from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta @@ -40,7 +40,7 @@ from core.model_runtime.model_providers.openllm.llm.openllm_generate_errors impo class OpenLLMLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: List[str] | None = None, stream: bool = True, user: str | None = None) \ + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -77,7 +77,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): tools: list[PromptMessageTool] | None = None) -> int: return self._num_tokens_from_messages(prompt_messages, tools) - def _num_tokens_from_messages(self, messages: List[PromptMessage], tools: list[PromptMessageTool]) -> int: + def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ Calculate num tokens for OpenLLM model it's a generate model, so we just join them by spe @@ -87,7 +87,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: List[str] | None = None, stream: bool = True, user: str | None = None) \ + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: client = OpenLLMGenerate() response = client.generate( diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py index 06453cb3f8..43258d1e5e 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py @@ -1,6 +1,7 @@ +from collections.abc import Generator from enum import Enum from json import dumps, loads -from typing import Any, Dict, Generator, List, Union +from typing import Any, Union from requests import Response, post from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema @@ -19,10 +20,10 @@ class OpenLLMGenerateMessage: role: str = Role.USER.value content: str - usage: Dict[str, int] = None + usage: dict[str, int] = None stop_reason: str = '' - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { 'role': self.role, 'content': self.content, @@ -33,10 +34,10 @@ class OpenLLMGenerateMessage: self.role = role -class OpenLLMGenerate(object): +class OpenLLMGenerate: def generate( - self, server_url: str, model_name: str, stream: bool, model_parameters: Dict[str, Any], - stop: List[str], prompt_messages: List[OpenLLMGenerateMessage], user: str, + self, server_url: str, model_name: str, stream: bool, model_parameters: dict[str, Any], + stop: list[str], prompt_messages: list[OpenLLMGenerateMessage], user: str, ) -> Union[Generator[OpenLLMGenerateMessage, None, None], OpenLLMGenerateMessage]: if not server_url: raise InvalidAuthenticationError('Invalid server URL') diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py index ce69c67984..ee2de85607 100644 --- a/api/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py @@ -1,4 +1,5 @@ -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union from replicate import Client as ReplicateClient from replicate.exceptions import ReplicateError @@ -29,7 +30,7 @@ from core.model_runtime.model_providers.replicate._common import _CommonReplicat class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: version = credentials['model_version'] diff --git a/api/core/model_runtime/model_providers/spark/llm/llm.py b/api/core/model_runtime/model_providers/spark/llm/llm.py index 6dfa1e3a6b..65beae517c 100644 --- a/api/core/model_runtime/model_providers/spark/llm/llm.py +++ b/api/core/model_runtime/model_providers/spark/llm/llm.py @@ -1,5 +1,6 @@ import threading -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -27,7 +28,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -86,7 +87,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -244,7 +245,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): return message_text - def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str: + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Anthropic model diff --git a/api/core/model_runtime/model_providers/togetherai/llm/llm.py b/api/core/model_runtime/model_providers/togetherai/llm/llm.py index 89198fe4b0..b312d99b1c 100644 --- a/api/core/model_runtime/model_providers/togetherai/llm/llm.py +++ b/api/core/model_runtime/model_providers/togetherai/llm/llm.py @@ -1,4 +1,5 @@ -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool @@ -14,7 +15,7 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) @@ -27,7 +28,7 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): return super().validate_credentials(model, cred_with_endpoint) def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) diff --git a/api/core/model_runtime/model_providers/tongyi/llm/_client.py b/api/core/model_runtime/model_providers/tongyi/llm/_client.py index 2aab69af7a..cfe33558e1 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/_client.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/_client.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Optional from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms import Tongyi @@ -8,7 +8,7 @@ from langchain.schema import Generation, LLMResult class EnhanceTongyi(Tongyi): @property - def _default_params(self) -> Dict[str, Any]: + def _default_params(self) -> dict[str, Any]: """Get the default parameters for calling OpenAI API.""" normal_params = { "top_p": self.top_p, @@ -19,13 +19,13 @@ class EnhanceTongyi(Tongyi): def _generate( self, - prompts: List[str], - stop: Optional[List[str]] = None, + prompts: list[str], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> LLMResult: generations = [] - params: Dict[str, Any] = { + params: dict[str, Any] = { **{"model": self.model_name}, **self._default_params, **kwargs, diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index 8aac4412fd..7ae8b87764 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -1,4 +1,5 @@ -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union from dashscope import get_tokenizer from dashscope.api_entities.dashscope_response import DashScopeAPIResponse @@ -38,7 +39,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -100,7 +101,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -268,7 +269,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): return message_text - def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str: + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Anthropic model diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py index af04eca59b..81868aeed1 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py @@ -1,8 +1,9 @@ +from collections.abc import Generator from datetime import datetime, timedelta from enum import Enum from json import dumps, loads from threading import Lock -from typing import Any, Dict, Generator, List, Union +from typing import Any, Union from requests import Response, post @@ -16,7 +17,7 @@ from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( ) # map api_key to access_token -baidu_access_tokens: Dict[str, 'BaiduAccessToken'] = {} +baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {} baidu_access_tokens_lock = Lock() class BaiduAccessToken: @@ -105,10 +106,10 @@ class ErnieMessage: role: str = Role.USER.value content: str - usage: Dict[str, int] = None + usage: dict[str, int] = None stop_reason: str = '' - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { 'role': self.role, 'content': self.content, @@ -118,7 +119,7 @@ class ErnieMessage: self.content = content self.role = role -class ErnieBotModel(object): +class ErnieBotModel: api_bases = { 'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', 'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', @@ -138,9 +139,9 @@ class ErnieBotModel(object): self.api_key = api_key self.secret_key = secret_key - def generate(self, model: str, stream: bool, messages: List[ErnieMessage], - parameters: Dict[str, Any], timeout: int, tools: List[PromptMessageTool], \ - stop: List[str], user: str) \ + def generate(self, model: str, stream: bool, messages: list[ErnieMessage], + parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \ + stop: list[str], user: str) \ -> Union[Generator[ErnieMessage, None, None], ErnieMessage]: # check parameters @@ -216,11 +217,11 @@ class ErnieBotModel(object): token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key) return token.access_token - def _copy_messages(self, messages: List[ErnieMessage]) -> List[ErnieMessage]: + def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]: return [ErnieMessage(message.content, message.role) for message in messages] - def _check_parameters(self, model: str, parameters: Dict[str, Any], - tools: List[PromptMessageTool], stop: List[str]) -> None: + def _check_parameters(self, model: str, parameters: dict[str, Any], + tools: list[PromptMessageTool], stop: list[str]) -> None: if model not in self.api_bases: raise BadRequestError(f'Invalid model: {model}') @@ -241,16 +242,16 @@ class ErnieBotModel(object): if len(s) > 20: raise BadRequestError('stop item should not exceed 20 characters.') - def _build_request_body(self, model: str, messages: List[ErnieMessage], stream: bool, parameters: Dict[str, Any], - tools: List[PromptMessageTool], stop: List[str], user: str) -> Dict[str, Any]: + def _build_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, parameters: dict[str, Any], + tools: list[PromptMessageTool], stop: list[str], user: str) -> dict[str, Any]: # if model in self.function_calling_supports: # return self._build_function_calling_request_body(model, messages, parameters, tools, stop, user) return self._build_chat_request_body(model, messages, stream, parameters, stop, user) - def _build_function_calling_request_body(self, model: str, messages: List[ErnieMessage], stream: bool, - parameters: Dict[str, Any], tools: List[PromptMessageTool], - stop: List[str], user: str) \ - -> Dict[str, Any]: + def _build_function_calling_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, + parameters: dict[str, Any], tools: list[PromptMessageTool], + stop: list[str], user: str) \ + -> dict[str, Any]: if len(messages) % 2 == 0: raise BadRequestError('The number of messages should be odd.') if messages[0].role == 'function': @@ -260,9 +261,9 @@ class ErnieBotModel(object): TODO: implement function calling """ - def _build_chat_request_body(self, model: str, messages: List[ErnieMessage], stream: bool, - parameters: Dict[str, Any], stop: List[str], user: str) \ - -> Dict[str, Any]: + def _build_chat_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, + parameters: dict[str, Any], stop: list[str], user: str) \ + -> dict[str, Any]: if len(messages) == 0: raise BadRequestError('The number of messages should not be zero.') diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py index b13e340d91..51b3c97497 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/llm.py @@ -1,4 +1,5 @@ -from typing import Generator, List, cast +from collections.abc import Generator +from typing import cast from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -32,7 +33,7 @@ from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( class ErnieBotLarguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: List[str] | None = None, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, @@ -43,7 +44,7 @@ class ErnieBotLarguageModel(LargeLanguageModel): # tools is not supported yet return self._num_tokens_from_messages(prompt_messages) - def _num_tokens_from_messages(self, messages: List[PromptMessage],) -> int: + def _num_tokens_from_messages(self, messages: list[PromptMessage],) -> int: """Calculate num tokens for baichuan model""" def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -78,7 +79,7 @@ class ErnieBotLarguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: List[str] | None = None, stream: bool = True, user: str | None = None) \ + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: instance = ErnieBotModel( api_key=credentials['api_key'], diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 7da1b00651..83c003d051 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -1,4 +1,5 @@ -from typing import Generator, Iterator, List, cast +from collections.abc import Generator, Iterator +from typing import cast from openai import ( APIConnectionError, @@ -62,7 +63,7 @@ from core.model_runtime.utils import helper class XinferenceAILargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: List[str] | None = None, stream: bool = True, user: str | None = None) \ + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: """ invoke LLM @@ -131,7 +132,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): """ return self._num_tokens_from_messages(prompt_messages, tools) - def _num_tokens_from_messages(self, messages: List[PromptMessage], tools: list[PromptMessageTool], + def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], is_completion_model: bool = False) -> int: def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -359,7 +360,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter, tools: list[PromptMessageTool] | None = None, - stop: List[str] | None = None, stream: bool = True, user: str | None = None) \ + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: """ generate text from LLM @@ -404,7 +405,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): } for tool in tools ] - if isinstance(xinference_model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)): + if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle): resp = client.chat.completions.create( model=credentials['model_uid'], messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages], diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 089ffd691f..24a91af62c 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -1,22 +1,21 @@ from os import path from threading import Lock from time import time -from typing import List from requests.adapters import HTTPAdapter from requests.exceptions import ConnectionError, MissingSchema, Timeout from requests.sessions import Session -class XinferenceModelExtraParameter(object): +class XinferenceModelExtraParameter: model_format: str model_handle_type: str - model_ability: List[str] + model_ability: list[str] max_tokens: int = 512 context_length: int = 2048 support_function_call: bool = False - def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str], + def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str], support_function_call: bool, max_tokens: int, context_length: int) -> None: self.model_format = model_format self.model_handle_type = model_handle_type diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index 6d1f462d0f..c62422dfb0 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -1,4 +1,5 @@ -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -23,7 +24,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -89,7 +90,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): def _generate(self, model: str, credentials_kwargs: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -119,7 +120,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): prompt_messages = prompt_messages[1:] # resolve zhipuai model not support system message and user message, assistant message must be in sequence - new_prompt_messages: List[PromptMessage] = [] + new_prompt_messages: list[PromptMessage] = [] for prompt_message in prompt_messages: copy_prompt_message = prompt_message.copy() if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]: @@ -275,7 +276,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): :return: llm response """ text = '' - assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = [] + assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = [] for choice in response.choices: if choice.message.tool_calls: for tool_call in choice.message.tool_calls: @@ -335,7 +336,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): continue - assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = [] + assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = [] for tool_call in delta.delta.tool_calls or []: if tool_call.type == 'function': assistant_tool_calls.append( @@ -409,7 +410,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): return message_text - def _convert_messages_to_prompt(self, messages: List[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str: + def _convert_messages_to_prompt(self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str: """ :param messages: List of PromptMessage to combine. :return: Combined string with necessary human_prompt and ai_prompt tags. diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 30c373729a..0f9fecfc72 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -1,5 +1,5 @@ import time -from typing import List, Optional, Tuple +from typing import Optional from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult @@ -81,7 +81,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def embed_documents(self, model: str, client: ZhipuAI, texts: List[str]) -> Tuple[List[List[float]], int]: + def embed_documents(self, model: str, client: ZhipuAI, texts: list[str]) -> tuple[list[list[float]], int]: """Call out to ZhipuAI's embedding endpoint. Args: @@ -101,7 +101,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): return [list(map(float, e)) for e in embeddings], embedding_used_tokens - def embed_query(self, text: str) -> List[float]: + def embed_query(self, text: str) -> list[float]: """Call out to ZhipuAI's embedding endpoint. Args: diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py index 23fd968f30..29b1746351 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py @@ -1,7 +1,8 @@ from __future__ import annotations import os -from typing import Mapping, Union +from collections.abc import Mapping +from typing import Union import httpx from httpx import Timeout diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py index 16c4b54f1a..dab6dac5fe 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py @@ -1,9 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Literal, Optional, Union import httpx -from typing_extensions import Literal from ...core._base_api import BaseAPI from ...core._base_type import NOT_GIVEN, Headers, NotGiven @@ -15,7 +14,7 @@ if TYPE_CHECKING: class AsyncCompletions(BaseAPI): - def __init__(self, client: "ZhipuAI") -> None: + def __init__(self, client: ZhipuAI) -> None: super().__init__(client) @@ -29,8 +28,8 @@ class AsyncCompletions(BaseAPI): top_p: Optional[float] | NotGiven = NOT_GIVEN, max_tokens: int | NotGiven = NOT_GIVEN, seed: int | NotGiven = NOT_GIVEN, - messages: Union[str, List[str], List[int], List[List[int]], None], - stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN, + messages: Union[str, list[str], list[int], list[list[int]], None], + stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, tools: Optional[object] | NotGiven = NOT_GIVEN, tool_choice: str | NotGiven = NOT_GIVEN, diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py index e5bb8cdf68..5c4ed4d1ba 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py @@ -1,9 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Literal, Optional, Union import httpx -from typing_extensions import Literal from ...core._base_api import BaseAPI from ...core._base_type import NOT_GIVEN, Headers, NotGiven @@ -17,7 +16,7 @@ if TYPE_CHECKING: class Completions(BaseAPI): - def __init__(self, client: "ZhipuAI") -> None: + def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( @@ -31,8 +30,8 @@ class Completions(BaseAPI): top_p: Optional[float] | NotGiven = NOT_GIVEN, max_tokens: int | NotGiven = NOT_GIVEN, seed: int | NotGiven = NOT_GIVEN, - messages: Union[str, List[str], List[int], object, None], - stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN, + messages: Union[str, list[str], list[int], object, None], + stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, tools: Optional[object] | NotGiven = NOT_GIVEN, tool_choice: str | NotGiven = NOT_GIVEN, diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py index d5db469de4..35d54592fd 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Optional, Union import httpx @@ -14,13 +14,13 @@ if TYPE_CHECKING: class Embeddings(BaseAPI): - def __init__(self, client: "ZhipuAI") -> None: + def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( self, *, - input: Union[str, List[str], List[int], List[List[int]]], + input: Union[str, list[str], list[int], list[list[int]]], model: Union[str], encoding_format: str | NotGiven = NOT_GIVEN, user: str | NotGiven = NOT_GIVEN, diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py index 7796b778a3..5deb8d08f3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py @@ -18,7 +18,7 @@ __all__ = ["Files"] class Files(BaseAPI): - def __init__(self, client: "ZhipuAI") -> None: + def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py index ead6cdae2f..b860de192a 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py @@ -17,7 +17,7 @@ __all__ = ["Jobs"] class Jobs(BaseAPI): - def __init__(self, client: "ZhipuAI") -> None: + def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py index ce852a48c3..3201426dfa 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: class Images(BaseAPI): - def __init__(self, client: "ZhipuAI") -> None: + def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def generations( diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py index f3dde8461c..b7cf6bb7fd 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py @@ -1,21 +1,22 @@ from __future__ import annotations +from collections.abc import Mapping, Sequence from os import PathLike -from typing import IO, TYPE_CHECKING, Any, List, Mapping, Sequence, Tuple, Type, TypeVar, Union +from typing import IO, TYPE_CHECKING, Any, Literal, TypeVar, Union import pydantic -from typing_extensions import Literal, override +from typing_extensions import override Query = Mapping[str, object] Body = object AnyMapping = Mapping[str, object] PrimitiveData = Union[str, int, float, bool, None] -Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"] +Data = Union[PrimitiveData, list[Any], tuple[Any], "Mapping[str, Any]"] ModelT = TypeVar("ModelT", bound=pydantic.BaseModel) _T = TypeVar("_T") if TYPE_CHECKING: - NoneType: Type[None] + NoneType: type[None] else: NoneType = type(None) @@ -74,7 +75,7 @@ Headers = Mapping[str, Union[str, Omit]] ResponseT = TypeVar( "ResponseT", - bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]", + bound="Union[str, None, BaseModel, list[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]", ) # for user input files @@ -85,21 +86,21 @@ else: FileTypes = Union[ FileContent, # file content - Tuple[str, FileContent], # (filename, file) - Tuple[str, FileContent, str], # (filename, file , content_type) - Tuple[str, FileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) + tuple[str, FileContent], # (filename, file) + tuple[str, FileContent, str], # (filename, file , content_type) + tuple[str, FileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) ] -RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]] +RequestFiles = Union[Mapping[str, FileTypes], Sequence[tuple[str, FileTypes]]] # for httpx client supported files HttpxFileContent = Union[bytes, IO[bytes]] HttpxFileTypes = Union[ FileContent, # file content - Tuple[str, HttpxFileContent], # (filename, file) - Tuple[str, HttpxFileContent, str], # (filename, file , content_type) - Tuple[str, HttpxFileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) + tuple[str, HttpxFileContent], # (filename, file) + tuple[str, HttpxFileContent, str], # (filename, file , content_type) + tuple[str, HttpxFileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) ] -HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]] +HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[tuple[str, HttpxFileTypes]]] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py index e41ede128a..0796bfe11c 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py @@ -2,14 +2,14 @@ from __future__ import annotations import io import os +from collections.abc import Mapping, Sequence from pathlib import Path -from typing import Mapping, Sequence from ._base_type import FileTypes, HttpxFileTypes, HttpxRequestFiles, RequestFiles def is_file_content(obj: object) -> bool: - return isinstance(obj, (bytes, tuple, io.IOBase, os.PathLike)) + return isinstance(obj, bytes | tuple | io.IOBase | os.PathLike) def _transform_file(file: FileTypes) -> HttpxFileTypes: diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py index 5227d20615..e13d2b0233 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py @@ -1,8 +1,8 @@ -# -*- coding:utf-8 -*- from __future__ import annotations import inspect -from typing import Any, Mapping, Type, Union, cast +from collections.abc import Mapping +from typing import Any, Union, cast import httpx import pydantic @@ -140,7 +140,7 @@ class HttpClient: for k, v in value.items(): items.extend(self._object_to_formfata(f"{key}[{k}]", v)) return items - if isinstance(value, (list, tuple)): + if isinstance(value, list | tuple): for v in value: items.extend(self._object_to_formfata(key + "[]", v)) return items @@ -175,7 +175,7 @@ class HttpClient: def _parse_response( self, *, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], response: httpx.Response, enable_stream: bool, request_param: ClientRequestParam, @@ -224,7 +224,7 @@ class HttpClient: def request( self, *, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], params: ClientRequestParam, enable_stream: bool = False, stream_cls: type[StreamResponse[Any]] | None = None, @@ -259,7 +259,7 @@ class HttpClient: self, path: str, *, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], options: UserRequestInput = {}, enable_stream: bool = False, ) -> ResponseT | StreamResponse: @@ -274,7 +274,7 @@ class HttpClient: path: str, *, body: Body | None = None, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], options: UserRequestInput = {}, files: RequestFiles | None = None, enable_stream: bool = False, @@ -294,7 +294,7 @@ class HttpClient: path: str, *, body: Body | None = None, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], options: UserRequestInput = {}, ) -> ResponseT: opts = ClientRequestParam.construct(method="patch", url=path, json_data=body, **options) @@ -308,7 +308,7 @@ class HttpClient: path: str, *, body: Body | None = None, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], options: UserRequestInput = {}, files: RequestFiles | None = None, ) -> ResponseT | StreamResponse: @@ -324,7 +324,7 @@ class HttpClient: path: str, *, body: Body | None = None, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], options: UserRequestInput = {}, ) -> ResponseT | StreamResponse: opts = ClientRequestParam.construct(method="delete", url=path, json_data=body, **options) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py index bbf2e72e68..b0a91d04a9 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import time import cachetools.func diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py index 2406e57820..a3f49ba846 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py @@ -1,10 +1,10 @@ from __future__ import annotations -from typing import Any, Union +from typing import Any, ClassVar, Union from httpx import Timeout from pydantic import ConfigDict -from typing_extensions import ClassVar, TypedDict, Unpack +from typing_extensions import TypedDict, Unpack from ._base_type import Body, Headers, HttpxRequestFiles, NotGiven, Query from ._utils import remove_notgiven_indict @@ -17,7 +17,7 @@ class UserRequestInput(TypedDict, total=False): params: Query | None -class ClientRequestParam(): +class ClientRequestParam: method: str url: str max_retries: Union[int, NotGiven] = NotGiven() diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py index 116246e645..2f831b6fc9 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py @@ -1,11 +1,11 @@ from __future__ import annotations import datetime -from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, get_args, get_origin import httpx import pydantic -from typing_extensions import ParamSpec, get_args, get_origin +from typing_extensions import ParamSpec from ._base_type import NoneType from ._sse_client import StreamResponse @@ -19,7 +19,7 @@ R = TypeVar("R") class HttpResponse(Generic[R]): _cast_type: type[R] - _client: "HttpClient" + _client: HttpClient _parsed: R | None _enable_stream: bool _stream_cls: type[StreamResponse[Any]] @@ -30,7 +30,7 @@ class HttpResponse(Generic[R]): *, raw_response: httpx.Response, cast_type: type[R], - client: "HttpClient", + client: HttpClient, enable_stream: bool = False, stream_cls: type[StreamResponse[Any]] | None = None, ) -> None: diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py index 6efe20edcb..66afbfd107 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py @@ -1,8 +1,8 @@ -# -*- coding:utf-8 -*- from __future__ import annotations import json -from typing import TYPE_CHECKING, Generic, Iterator, Mapping +from collections.abc import Iterator, Mapping +from typing import TYPE_CHECKING, Generic import httpx @@ -36,8 +36,7 @@ class StreamResponse(Generic[ResponseT]): return self._stream_chunks.__next__() def __iter__(self) -> Iterator[ResponseT]: - for item in self._stream_chunks: - yield item + yield from self._stream_chunks def __stream__(self) -> Iterator[ResponseT]: @@ -62,7 +61,7 @@ class StreamResponse(Generic[ResponseT]): pass -class Event(object): +class Event: def __init__( self, event: str | None = None, diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py index 78c97af65b..6b610567da 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Iterable, Mapping, TypeVar +from collections.abc import Iterable, Mapping +from typing import TypeVar from ._base_type import NotGiven diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py index bae4197c50..f22f32d251 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from pydantic import BaseModel @@ -19,5 +19,5 @@ class AsyncCompletion(BaseModel): request_id: Optional[str] = None model: Optional[str] = None task_status: str - choices: List[CompletionChoice] + choices: list[CompletionChoice] usage: CompletionUsage \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py index 524e218d39..b2a847c50c 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from pydantic import BaseModel @@ -19,7 +19,7 @@ class CompletionMessageToolCall(BaseModel): class CompletionMessage(BaseModel): content: Optional[str] = None role: str - tool_calls: Optional[List[CompletionMessageToolCall]] = None + tool_calls: Optional[list[CompletionMessageToolCall]] = None class CompletionUsage(BaseModel): @@ -37,7 +37,7 @@ class CompletionChoice(BaseModel): class Completion(BaseModel): model: Optional[str] = None created: Optional[int] = None - choices: List[CompletionChoice] + choices: list[CompletionChoice] request_id: Optional[str] = None id: Optional[str] = None usage: CompletionUsage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py index c2e0c57666..c250699741 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from pydantic import BaseModel @@ -32,7 +32,7 @@ class ChoiceDeltaToolCall(BaseModel): class ChoiceDelta(BaseModel): content: Optional[str] = None role: Optional[str] = None - tool_calls: Optional[List[ChoiceDeltaToolCall]] = None + tool_calls: Optional[list[ChoiceDeltaToolCall]] = None class Choice(BaseModel): @@ -49,7 +49,7 @@ class CompletionUsage(BaseModel): class ChatCompletionChunk(BaseModel): id: Optional[str] = None - choices: List[Choice] + choices: list[Choice] created: Optional[int] = None model: Optional[str] = None usage: Optional[CompletionUsage] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py index a8737cf8dc..e01f2c815f 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Optional +from typing import Optional from pydantic import BaseModel @@ -12,11 +12,11 @@ __all__ = ["Embedding", "EmbeddingsResponded"] class Embedding(BaseModel): object: str index: Optional[int] = None - embedding: List[float] + embedding: list[float] class EmbeddingsResponded(BaseModel): object: str - data: List[Embedding] + data: list[Embedding] model: str usage: CompletionUsage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py index 94db046bd6..917bda7576 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from pydantic import BaseModel @@ -20,5 +20,5 @@ class FileObject(BaseModel): class ListOfFileObject(BaseModel): object: Optional[str] = None - data: List[FileObject] + data: list[FileObject] has_more: Optional[bool] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py index 6197b6faaf..71c00eaff0 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import Optional, Union from pydantic import BaseModel @@ -34,7 +34,7 @@ class FineTuningJob(BaseModel): object: Optional[str] = None - result_files: List[str] + result_files: list[str] status: str @@ -47,5 +47,5 @@ class FineTuningJob(BaseModel): class ListOfFineTuningJob(BaseModel): object: Optional[str] = None - data: List[FineTuningJob] + data: list[FineTuningJob] has_more: Optional[bool] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py index 6ff3f77fd7..e26b448534 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import Optional, Union from pydantic import BaseModel @@ -31,5 +31,5 @@ class JobEvent(BaseModel): class FineTuningJobEvent(BaseModel): object: Optional[str] = None - data: List[JobEvent] + data: list[JobEvent] has_more: Optional[bool] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py index c661f7cdd5..e1ebc352bc 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Union +from typing import Literal, Union -from typing_extensions import Literal, TypedDict +from typing_extensions import TypedDict __all__ = ["Hyperparameters"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py index 429a7e25bc..b352ce0954 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Optional +from typing import Optional from pydantic import BaseModel @@ -15,4 +15,4 @@ class GeneratedImage(BaseModel): class ImagesResponded(BaseModel): created: int - data: List[GeneratedImage] + data: list[GeneratedImage] diff --git a/api/core/model_runtime/utils/_compat.py b/api/core/model_runtime/utils/_compat.py index 305edcac8f..5c34152751 100644 --- a/api/core/model_runtime/utils/_compat.py +++ b/api/core/model_runtime/utils/_compat.py @@ -1,8 +1,7 @@ -from typing import Any +from typing import Any, Literal from pydantic import BaseModel from pydantic.version import VERSION as PYDANTIC_VERSION -from typing_extensions import Literal PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index d0d93c34b9..cf6c98e01a 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -1,13 +1,14 @@ import dataclasses import datetime from collections import defaultdict, deque +from collections.abc import Callable from decimal import Decimal from enum import Enum from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network from pathlib import Path, PurePath from re import Pattern from types import GeneratorType -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Optional, Union from uuid import UUID from pydantic import BaseModel @@ -46,7 +47,7 @@ def decimal_encoder(dec_value: Decimal) -> Union[int, float]: return float(dec_value) -ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { +ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = { bytes: lambda o: o.decode(), Color: str, datetime.date: isoformat, @@ -77,9 +78,9 @@ ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { def generate_encoders_by_class_tuples( - type_encoder_map: Dict[Any, Callable[[Any], Any]] -) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]: - encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict( + type_encoder_map: dict[Any, Callable[[Any], Any]] +) -> dict[Callable[[Any], Any], tuple[Any, ...]]: + encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict( tuple ) for type_, encoder in type_encoder_map.items(): @@ -96,7 +97,7 @@ def jsonable_encoder( exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, - custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None, + custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None, sqlalchemy_safe: bool = True, ) -> Any: custom_encoder = custom_encoder or {} @@ -109,7 +110,7 @@ def jsonable_encoder( return encoder_instance(obj) if isinstance(obj, BaseModel): # TODO: remove when deprecating Pydantic v1 - encoders: Dict[Any, Any] = {} + encoders: dict[Any, Any] = {} if not PYDANTIC_V2: encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined] if custom_encoder: @@ -149,7 +150,7 @@ def jsonable_encoder( return obj.value if isinstance(obj, PurePath): return str(obj) - if isinstance(obj, (str, int, float, type(None))): + if isinstance(obj, str | int | float | type(None)): return obj if isinstance(obj, Decimal): return format(obj, 'f') @@ -184,7 +185,7 @@ def jsonable_encoder( ) encoded_dict[encoded_key] = encoded_value return encoded_dict - if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)): + if isinstance(obj, list | set | frozenset | GeneratorType | tuple | deque): encoded_list = [] for item in obj: encoded_list.append( @@ -209,7 +210,7 @@ def jsonable_encoder( try: data = dict(obj) except Exception as e: - errors: List[Exception] = [] + errors: list[Exception] = [] errors.append(e) try: data = vars(obj) diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 5ffcaaec65..0a373b7c42 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -2,7 +2,7 @@ import enum import json import os import re -from typing import List, Optional, Tuple, cast +from typing import Optional, cast from core.entities.application_entities import ( AdvancedCompletionPromptTemplateEntity, @@ -67,11 +67,11 @@ class PromptTransform: prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, - files: List[FileObj], + files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigEntity) -> \ - Tuple[List[PromptMessage], Optional[List[str]]]: + tuple[list[PromptMessage], Optional[list[str]]]: app_mode = AppMode.value_of(app_mode) model_mode = ModelMode.value_of(model_config.mode) @@ -115,10 +115,10 @@ class PromptTransform: prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, - files: List[FileObj], + files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> List[PromptMessage]: + model_config: ModelConfigEntity) -> list[PromptMessage]: app_mode = AppMode.value_of(app_mode) model_mode = ModelMode.value_of(model_config.mode) @@ -182,7 +182,7 @@ class PromptTransform: ) def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory, - max_token_limit: int) -> List[PromptMessage]: + max_token_limit: int) -> list[PromptMessage]: """Get memory messages.""" return memory.get_history_prompt_messages( max_token_limit=max_token_limit @@ -217,7 +217,7 @@ class PromptTransform: json_file_path = os.path.join(prompt_path, f'{prompt_name}.json') # Open the JSON file and read its content - with open(json_file_path, 'r', encoding='utf-8') as json_file: + with open(json_file_path, encoding='utf-8') as json_file: return json.load(json_file) def _get_simple_chat_app_chat_model_prompt_messages(self, prompt_rules: dict, @@ -225,9 +225,9 @@ class PromptTransform: inputs: dict, query: str, context: Optional[str], - files: List[FileObj], + files: list[FileObj], memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> List[PromptMessage]: + model_config: ModelConfigEntity) -> list[PromptMessage]: prompt_messages = [] context_prompt_content = '' @@ -280,8 +280,8 @@ class PromptTransform: query: str, context: Optional[str], memory: Optional[TokenBufferMemory], - files: List[FileObj], - model_config: ModelConfigEntity) -> List[PromptMessage]: + files: list[FileObj], + model_config: ModelConfigEntity) -> list[PromptMessage]: context_prompt_content = '' if context and 'context_prompt' in prompt_rules: prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt']) @@ -451,10 +451,10 @@ class PromptTransform: prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, - files: List[FileObj], + files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> List[PromptMessage]: + model_config: ModelConfigEntity) -> list[PromptMessage]: raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix @@ -494,10 +494,10 @@ class PromptTransform: prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, - files: List[FileObj], + files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> List[PromptMessage]: + model_config: ModelConfigEntity) -> list[PromptMessage]: raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages prompt_messages = [] @@ -535,7 +535,7 @@ class PromptTransform: def _get_completion_app_completion_model_prompt_messages(self, prompt_template_entity: PromptTemplateEntity, inputs: dict, - context: Optional[str]) -> List[PromptMessage]: + context: Optional[str]) -> list[PromptMessage]: raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt prompt_messages = [] @@ -554,8 +554,8 @@ class PromptTransform: def _get_completion_app_chat_model_prompt_messages(self, prompt_template_entity: PromptTemplateEntity, inputs: dict, - files: List[FileObj], - context: Optional[str]) -> List[PromptMessage]: + files: list[FileObj], + context: Optional[str]) -> list[PromptMessage]: raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages prompt_messages = [] diff --git a/api/core/rerank/rerank.py b/api/core/rerank/rerank.py index 984cdb4003..a675dfc568 100644 --- a/api/core/rerank/rerank.py +++ b/api/core/rerank/rerank.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from langchain.schema import Document @@ -9,8 +9,8 @@ class RerankRunner: def __init__(self, rerank_model_instance: ModelInstance) -> None: self.rerank_model_instance = rerank_model_instance - def run(self, query: str, documents: List[Document], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, user: Optional[str] = None) -> List[Document]: + def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, + top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: """ Run rerank model :param query: search query diff --git a/api/core/splitter/fixed_text_splitter.py b/api/core/splitter/fixed_text_splitter.py index babb360a5e..285a7ba14e 100644 --- a/api/core/splitter/fixed_text_splitter.py +++ b/api/core/splitter/fixed_text_splitter.py @@ -1,7 +1,7 @@ """Functionality for splitting text.""" from __future__ import annotations -from typing import Any, List, Optional, cast +from typing import Any, Optional, cast from langchain.text_splitter import ( TS, @@ -28,8 +28,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): def from_encoder( cls: Type[TS], embedding_model_instance: Optional[ModelInstance], - allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), - disallowed_special: Union[Literal["all"], Collection[str]] = "all", + allowed_special: Union[Literal[all], AbstractSet[str]] = set(), + disallowed_special: Union[Literal[all], Collection[str]] = "all", **kwargs: Any, ): def _token_encoder(text: str) -> int: @@ -59,13 +59,13 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter): - def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any): + def __init__(self, fixed_separator: str = "\n\n", separators: Optional[list[str]] = None, **kwargs: Any): """Create a new TextSplitter.""" super().__init__(**kwargs) self._fixed_separator = fixed_separator self._separators = separators or ["\n\n", "\n", " ", ""] - def split_text(self, text: str) -> List[str]: + def split_text(self, text: str) -> list[str]: """Split incoming text and return chunks.""" if self._fixed_separator: chunks = text.split(self._fixed_separator) @@ -81,7 +81,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) return final_chunks - def recursive_split_text(self, text: str) -> List[str]: + def recursive_split_text(self, text: str) -> list[str]: """Split incoming text and return chunks.""" final_chunks = [] # Get appropriate separator to use diff --git a/api/core/third_party/langchain/llms/fake.py b/api/core/third_party/langchain/llms/fake.py index 64117477e1..ab5152b38d 100644 --- a/api/core/third_party/langchain/llms/fake.py +++ b/api/core/third_party/langchain/llms/fake.py @@ -1,5 +1,6 @@ import time -from typing import Any, List, Mapping, Optional +from collections.abc import Mapping +from typing import Any, Optional from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.chat_models.base import SimpleChatModel @@ -19,8 +20,8 @@ class FakeLLM(SimpleChatModel): def _call( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: @@ -36,8 +37,8 @@ class FakeLLM(SimpleChatModel): def _generate( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: diff --git a/api/core/tool/current_datetime_tool.py b/api/core/tool/current_datetime_tool.py index 3bb2bb5eaa..208490a5bf 100644 --- a/api/core/tool/current_datetime_tool.py +++ b/api/core/tool/current_datetime_tool.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Type from langchain.tools import BaseTool from pydantic import BaseModel, Field @@ -12,7 +11,7 @@ class DatetimeToolInput(BaseModel): class DatetimeTool(BaseTool): """Tool for querying current datetime.""" name: str = "current_datetime" - args_schema: Type[BaseModel] = DatetimeToolInput + args_schema: type[BaseModel] = DatetimeToolInput description: str = "A tool when you want to get the current date, time, week, month or year, " \ "and the time zone is UTC. Result is \"