chore: apply ruff's pyupgrade linter rules to modernize Python code with targeted version (#2419)

This commit is contained in:
Bowen Liang 2024-02-09 15:21:33 +08:00 committed by GitHub
parent 589099a005
commit 063191889d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
246 changed files with 912 additions and 937 deletions

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import os
from werkzeug.exceptions import Unauthorized

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import os
import dotenv

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import json
import logging
from datetime import datetime

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import logging
from flask import request

View File

@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
import json
import logging
from typing import Generator, Union
from collections.abc import Generator
from typing import Union
import flask_login
from flask import Response, stream_with_context
@ -169,8 +169,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
for chunk in response:
yield chunk
yield from response
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')

View File

@ -1,6 +1,7 @@
import json
import logging
from typing import Generator, Union
from collections.abc import Generator
from typing import Union
from flask import Response, stream_with_context
from flask_login import current_user
@ -246,8 +247,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
for chunk in response:
yield chunk
yield from response
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from flask import request
from flask_login import current_user

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from datetime import datetime
from decimal import Decimal

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import flask_login
from flask import current_app, request
from flask_restful import Resource, reqparse

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import flask_restful
from flask import current_app, request
from flask_login import current_user

View File

@ -1,6 +1,4 @@
# -*- coding:utf-8 -*-
from datetime import datetime
from typing import List
from flask import request
from flask_login import current_user
@ -71,7 +69,7 @@ class DocumentResource(Resource):
return document
def get_batch_documents(self, dataset_id: str, batch: str) -> List[Document]:
def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]:
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import uuid
from datetime import datetime

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import logging
from flask import request

View File

@ -1,8 +1,8 @@
# -*- coding:utf-8 -*-
import json
import logging
from collections.abc import Generator
from datetime import datetime
from typing import Generator, Union
from typing import Union
from flask import Response, stream_with_context
from flask_login import current_user
@ -164,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
for chunk in response:
yield chunk
yield from response
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from flask_login import current_user
from flask_restful import marshal_with, reqparse
from flask_restful.inputs import int_range

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from libs.exception import BaseHTTPException

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from datetime import datetime
from flask_login import current_user

View File

@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
import json
import logging
from typing import Generator, Union
from collections.abc import Generator
from typing import Union
from flask import Response, stream_with_context
from flask_login import current_user
@ -123,8 +123,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
for chunk in response:
yield chunk
yield from response
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import json
from flask import current_app

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from flask_login import current_user
from flask_restful import Resource, fields, marshal_with
from sqlalchemy import and_

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from functools import wraps
from flask import current_app, request

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import json
import logging

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from datetime import datetime
import pytz

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from flask import current_app
from flask_login import current_user
from flask_restful import Resource, abort, fields, marshal_with, reqparse

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import logging
from flask import request

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import json
from functools import wraps

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import json
from flask import current_app

View File

@ -1,6 +1,7 @@
import json
import logging
from typing import Generator, Union
from collections.abc import Generator
from typing import Union
from flask import Response, stream_with_context
from flask_restful import reqparse
@ -182,8 +183,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
for chunk in response:
yield chunk
yield from response
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from flask import request
from flask_restful import marshal_with, reqparse
from flask_restful.inputs import int_range

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from libs.exception import BaseHTTPException

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from flask_restful import fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from datetime import datetime
from functools import wraps

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import json
from flask import current_app

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import logging
from flask import request

View File

@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
import json
import logging
from typing import Generator, Union
from collections.abc import Generator
from typing import Union
from flask import Response, stream_with_context
from flask_restful import reqparse
@ -154,8 +154,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
for chunk in response:
yield chunk
yield from response
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from flask_restful import marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from libs.exception import BaseHTTPException

View File

@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
import json
import logging
from typing import Generator, Union
from collections.abc import Generator
from typing import Union
from flask import Response, stream_with_context
from flask_restful import fields, marshal_with, reqparse
@ -160,8 +160,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
for chunk in response:
yield chunk
yield from response
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import uuid
from flask import request

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from flask import current_app
from flask_restful import fields, marshal_with

View File

@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from functools import wraps
from flask import request

View File

@ -1,5 +1,5 @@
import logging
from typing import List, Optional
from typing import Optional
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.model_runtime.callbacks.base_callback import Callback
@ -17,7 +17,7 @@ class AgentLLMCallback(Callback):
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
Before invoke callback
@ -38,7 +38,7 @@ class AgentLLMCallback(Callback):
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None):
"""
On new chunk callback
@ -58,7 +58,7 @@ class AgentLLMCallback(Callback):
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
After invoke callback
@ -80,7 +80,7 @@ class AgentLLMCallback(Callback):
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
Invoke error callback

View File

@ -1,4 +1,4 @@
from typing import List, cast
from typing import cast
from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.message_entities import PromptMessage
@ -8,7 +8,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
class CalcTokenMixin:
def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int:
def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
"""
Got the rest tokens available for the model after excluding messages tokens and completion max tokens

View File

@ -1,4 +1,5 @@
from typing import Any, List, Optional, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import Any, Optional, Union
from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
@ -42,7 +43,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@ -85,7 +86,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
def real_plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@ -146,7 +147,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@ -158,7 +159,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
model_config: ModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),

View File

@ -1,4 +1,5 @@
from typing import Any, List, Optional, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import Any, Optional, Union
from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
@ -51,7 +52,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
model_config: ModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
@ -125,7 +126,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@ -207,7 +208,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
def return_stopped_response(
self,
early_stopping_method: str,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
**kwargs: Any,
) -> AgentFinish:
try:
@ -215,7 +216,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
except ValueError:
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]:
def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(
self.model_config,
@ -264,7 +265,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
return new_messages
def predict_new_summary(
self, messages: List[BaseMessage], existing_summary: str
self, messages: list[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
@ -275,7 +276,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int:
def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/

View File

@ -1,5 +1,6 @@
import re
from typing import Any, List, Optional, Sequence, Tuple, Union, cast
from collections.abc import Sequence
from typing import Any, Optional, Union, cast
from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
@ -68,7 +69,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@ -125,8 +126,8 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
@ -153,7 +154,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
tools: Sequence[BaseTool],
prefix: str = PREFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
input_variables: Optional[list[str]] = None,
) -> PromptTemplate:
"""Create prompt in the style of the zero shot agent.
@ -180,7 +181,7 @@ Thought: {agent_scratchpad}
return PromptTemplate(template=template, input_variables=input_variables)
def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
self, intermediate_steps: list[tuple[AgentAction, str]]
) -> str:
agent_scratchpad = ""
for action, observation in intermediate_steps:
@ -213,8 +214,8 @@ Thought: {agent_scratchpad}
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""

View File

@ -1,5 +1,6 @@
import re
from typing import Any, List, Optional, Sequence, Tuple, Union, cast
from collections.abc import Sequence
from typing import Any, Optional, Union, cast
from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
@ -82,7 +83,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@ -127,7 +128,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "")
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs):
if len(intermediate_steps) >= 2 and self.summary_model_config:
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
should_summary_messages = [AIMessage(content=observation)
@ -154,7 +155,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
def predict_new_summary(
self, messages: List[BaseMessage], existing_summary: str
self, messages: list[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
@ -173,8 +174,8 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
@ -200,7 +201,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
tools: Sequence[BaseTool],
prefix: str = PREFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
input_variables: Optional[list[str]] = None,
) -> PromptTemplate:
"""Create prompt in the style of the zero shot agent.
@ -227,7 +228,7 @@ Thought: {agent_scratchpad}
return PromptTemplate(template=template, input_variables=input_variables)
def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
self, intermediate_steps: list[tuple[AgentAction, str]]
) -> str:
agent_scratchpad = ""
for action, observation in intermediate_steps:
@ -260,8 +261,8 @@ Thought: {agent_scratchpad}
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
agent_llm_callback: Optional[AgentLLMCallback] = None,
**kwargs: Any,
) -> Agent:

View File

@ -1,5 +1,6 @@
import time
from typing import Generator, List, Optional, Tuple, Union, cast
from collections.abc import Generator
from typing import Optional, Union, cast
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import (
@ -84,7 +85,7 @@ class AppRunner:
return rest_tokens
def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
prompt_messages: List[PromptMessage]):
prompt_messages: list[PromptMessage]):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
@ -126,7 +127,7 @@ class AppRunner:
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None) \
-> Tuple[List[PromptMessage], Optional[List[str]]]:
-> tuple[list[PromptMessage], Optional[list[str]]]:
"""
Organize prompt messages
:param context:
@ -295,7 +296,7 @@ class AppRunner:
tenant_id: str,
app_orchestration_config_entity: AppOrchestrationConfigEntity,
inputs: dict,
query: str) -> Tuple[bool, dict, str]:
query: str) -> tuple[bool, dict, str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id

View File

@ -1,7 +1,8 @@
import json
import logging
import time
from typing import Generator, Optional, Union, cast
from collections.abc import Generator
from typing import Optional, Union, cast
from pydantic import BaseModel
@ -118,7 +119,7 @@ class GenerateTaskPipeline:
}
self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
if isinstance(event, QueueMessageEndEvent):
self._task_state.llm_result = event.llm_result
else:
@ -201,7 +202,7 @@ class GenerateTaskPipeline:
data = self._error_to_stream_response_data(self._handle_error(event))
yield self._yield_response(data)
break
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
if isinstance(event, QueueMessageEndEvent):
self._task_state.llm_result = event.llm_result
else:
@ -353,7 +354,7 @@ class GenerateTaskPipeline:
yield self._yield_response(response)
elif isinstance(event, (QueueMessageEvent, QueueAgentMessageEvent)):
elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent):
chunk = event.chunk
delta_text = chunk.delta.message.content
if delta_text is None:

View File

@ -1,7 +1,7 @@
import logging
import threading
import time
from typing import Any, Dict, Optional
from typing import Any, Optional
from flask import Flask, current_app
from pydantic import BaseModel
@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
class ModerationRule(BaseModel):
type: str
config: Dict[str, Any]
config: dict[str, Any]
class OutputModerationHandler(BaseModel):

View File

@ -2,7 +2,8 @@ import json
import logging
import threading
import uuid
from typing import Any, Generator, Optional, Tuple, Union, cast
from collections.abc import Generator
from typing import Any, Optional, Union, cast
from flask import Flask, current_app
from pydantic import ValidationError
@ -585,7 +586,7 @@ class ApplicationManager:
return AppOrchestrationConfigEntity(**properties)
def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
-> Tuple[Conversation, Message]:
-> tuple[Conversation, Message]:
"""
Initialize generate records
:param application_generate_entity: application generate entity

View File

@ -1,7 +1,8 @@
import queue
import time
from collections.abc import Generator
from enum import Enum
from typing import Any, Generator
from typing import Any
from sqlalchemy.orm import DeclarativeMeta

View File

@ -1,7 +1,7 @@
import json
import logging
import time
from typing import Any, Dict, List, Optional, Union, cast
from typing import Any, Optional, Union, cast
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
from langchain.callbacks.base import BaseCallbackHandler
@ -37,7 +37,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._message_agent_thought = None
@property
def agent_loops(self) -> List[AgentLoop]:
def agent_loops(self) -> list[AgentLoop]:
return self._agent_loops
def clear_agent_loops(self) -> None:
@ -95,14 +95,14 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
**kwargs: Any
) -> Any:
pass
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
pass
@ -120,7 +120,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
def on_tool_start(
self,
serialized: Dict[str, Any],
serialized: dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:

View File

@ -1,5 +1,5 @@
import os
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
@ -21,7 +21,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
def on_tool_start(
self,
tool_name: str,
tool_inputs: Dict[str, Any],
tool_inputs: dict[str, Any],
) -> None:
"""Do nothing."""
print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)
@ -29,7 +29,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
def on_tool_end(
self,
tool_name: str,
tool_inputs: Dict[str, Any],
tool_inputs: dict[str, Any],
tool_outputs: str,
) -> None:
"""If not the final action, print out observation."""

View File

@ -1,4 +1,3 @@
from typing import List
from langchain.schema import Document
@ -40,7 +39,7 @@ class DatasetIndexToolCallbackHandler:
db.session.add(dataset_query)
db.session.commit()
def on_tool_end(self, documents: List[Document]) -> None:
def on_tool_end(self, documents: list[Document]) -> None:
"""Handle tool end."""
for document in documents:
doc_id = document.metadata['doc_id']
@ -55,7 +54,7 @@ class DatasetIndexToolCallbackHandler:
db.session.commit()
def return_retriever_resource_info(self, resource: List):
def return_retriever_resource_info(self, resource: list):
"""Handle return_retriever_resource_info."""
if resource and len(resource) > 0:
for item in resource:

View File

@ -1,6 +1,6 @@
import os
import sys
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
@ -16,8 +16,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
**kwargs: Any
) -> Any:
print_text("\n[on_chat_model_start]\n", color='blue')
@ -26,7 +26,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
print_text(str(sub_message) + "\n", color='blue')
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
print_text("\n[on_llm_start]\n", color='blue')
@ -48,13 +48,13 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue')
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
chain_type = serialized['id'][-1]
print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink')
@ -66,7 +66,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
def on_tool_start(
self,
serialized: Dict[str, Any],
serialized: dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from langchain import LLMChain as LCLLMChain
from langchain.callbacks.manager import CallbackManagerForChainRun
@ -16,12 +16,12 @@ class LLMChain(LCLLMChain):
model_config: ModelConfigEntity
"""The language model instance to use."""
llm: BaseLanguageModel = FakeLLM(response="")
parameters: Dict[str, Any] = {}
parameters: dict[str, Any] = {}
agent_llm_callback: Optional[AgentLLMCallback] = None
def generate(
self,
input_list: List[Dict[str, Any]],
input_list: list[dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""

View File

@ -1,6 +1,6 @@
import tempfile
from pathlib import Path
from typing import List, Optional, Union
from typing import Optional, Union
import requests
from flask import current_app
@ -28,7 +28,7 @@ USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTM
class FileExtractor:
@classmethod
def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document], str]:
def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[list[Document], str]:
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
@ -37,7 +37,7 @@ class FileExtractor:
return cls.load_from_file(file_path, return_text, upload_file, is_automatic)
@classmethod
def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document], str]:
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
response = requests.get(url, headers={
"User-Agent": USER_AGENT
})
@ -53,7 +53,7 @@ class FileExtractor:
@classmethod
def load_from_file(cls, file_path: str, return_text: bool = False,
upload_file: Optional[UploadFile] = None,
is_automatic: bool = False) -> Union[List[Document], str]:
is_automatic: bool = False) -> Union[list[Document], str]:
input_file = Path(file_path)
delimiter = '\n'
file_extension = input_file.suffix.lower()

View File

@ -1,6 +1,6 @@
import csv
import logging
from typing import Dict, List, Optional
from typing import Optional
from langchain.document_loaders import CSVLoader as LCCSVLoader
from langchain.document_loaders.helpers import detect_file_encodings
@ -14,7 +14,7 @@ class CSVLoader(LCCSVLoader):
self,
file_path: str,
source_column: Optional[str] = None,
csv_args: Optional[Dict] = None,
csv_args: Optional[dict] = None,
encoding: Optional[str] = None,
autodetect_encoding: bool = True,
):
@ -24,7 +24,7 @@ class CSVLoader(LCCSVLoader):
self.csv_args = csv_args or {}
self.autodetect_encoding = autodetect_encoding
def load(self) -> List[Document]:
def load(self) -> list[Document]:
"""Load data into document objects."""
try:
with open(self.file_path, newline="", encoding=self.encoding) as csvfile:

View File

@ -1,5 +1,4 @@
import logging
from typing import List
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
@ -23,7 +22,7 @@ class ExcelLoader(BaseLoader):
"""Initialize with file path."""
self._file_path = file_path
def load(self) -> List[Document]:
def load(self) -> list[Document]:
data = []
keys = []
wb = load_workbook(filename=self._file_path, read_only=True)

View File

@ -1,5 +1,4 @@
import logging
from typing import List
from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseLoader
@ -23,7 +22,7 @@ class HTMLLoader(BaseLoader):
"""Initialize with file path."""
self._file_path = file_path
def load(self) -> List[Document]:
def load(self) -> list[Document]:
return [Document(page_content=self._load_as_text())]
def _load_as_text(self) -> str:

View File

@ -1,6 +1,6 @@
import logging
import re
from typing import List, Optional, Tuple, cast
from typing import Optional, cast
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
@ -42,7 +42,7 @@ class MarkdownLoader(BaseLoader):
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
def load(self) -> List[Document]:
def load(self) -> list[Document]:
tups = self.parse_tups(self._file_path)
documents = []
for header, value in tups:
@ -54,13 +54,13 @@ class MarkdownLoader(BaseLoader):
return documents
def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
def markdown_to_tups(self, markdown_text: str) -> list[tuple[Optional[str], str]]:
"""Convert a markdown file to a dictionary.
The keys are the headers and the values are the text under each header.
"""
markdown_tups: List[Tuple[Optional[str], str]] = []
markdown_tups: list[tuple[Optional[str], str]] = []
lines = markdown_text.split("\n")
current_header = None
@ -103,11 +103,11 @@ class MarkdownLoader(BaseLoader):
content = re.sub(pattern, r"\1", content)
return content
def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]:
def parse_tups(self, filepath: str) -> list[tuple[Optional[str], str]]:
"""Parse file into tuples."""
content = ""
try:
with open(filepath, "r", encoding=self._encoding) as f:
with open(filepath, encoding=self._encoding) as f:
content = f.read()
except UnicodeDecodeError as e:
if self._autodetect_encoding:

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import requests
from flask import current_app
@ -67,7 +67,7 @@ class NotionLoader(BaseLoader):
document_model=document_model
)
def load(self) -> List[Document]:
def load(self) -> list[Document]:
self.update_last_edited_time(
self._document_model
)
@ -78,7 +78,7 @@ class NotionLoader(BaseLoader):
def _load_data_as_documents(
self, notion_obj_id: str, notion_page_type: str
) -> List[Document]:
) -> list[Document]:
docs = []
if notion_page_type == 'database':
# get all the pages in the database
@ -94,8 +94,8 @@ class NotionLoader(BaseLoader):
return docs
def _get_notion_database_data(
self, database_id: str, query_dict: Dict[str, Any] = {}
) -> List[Document]:
self, database_id: str, query_dict: dict[str, Any] = {}
) -> list[Document]:
"""Get all the pages from a Notion database."""
res = requests.post(
DATABASE_URL_TMPL.format(database_id=database_id),
@ -149,12 +149,12 @@ class NotionLoader(BaseLoader):
return database_content_list
def _get_notion_block_data(self, page_id: str) -> List[str]:
def _get_notion_block_data(self, page_id: str) -> list[str]:
result_lines_arr = []
cur_block_id = page_id
while True:
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
query_dict: Dict[str, Any] = {}
query_dict: dict[str, Any] = {}
res = requests.request(
"GET",
@ -216,7 +216,7 @@ class NotionLoader(BaseLoader):
cur_block_id = block_id
while True:
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
query_dict: Dict[str, Any] = {}
query_dict: dict[str, Any] = {}
res = requests.request(
"GET",
@ -280,7 +280,7 @@ class NotionLoader(BaseLoader):
cur_block_id = block_id
while not done:
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
query_dict: Dict[str, Any] = {}
query_dict: dict[str, Any] = {}
res = requests.request(
"GET",
@ -346,7 +346,7 @@ class NotionLoader(BaseLoader):
else:
retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)
query_dict: Dict[str, Any] = {}
query_dict: dict[str, Any] = {}
res = requests.request(
"GET",

View File

@ -1,5 +1,5 @@
import logging
from typing import List, Optional
from typing import Optional
from langchain.document_loaders import PyPDFium2Loader
from langchain.document_loaders.base import BaseLoader
@ -28,7 +28,7 @@ class PdfLoader(BaseLoader):
self._file_path = file_path
self._upload_file = upload_file
def load(self) -> List[Document]:
def load(self) -> list[Document]:
plaintext_file_key = ''
plaintext_file_exists = False
if self._upload_file:

View File

@ -1,6 +1,5 @@
import base64
import logging
from typing import List
from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseLoader
@ -24,7 +23,7 @@ class UnstructuredEmailLoader(BaseLoader):
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
def load(self) -> list[Document]:
from unstructured.partition.email import partition_email
elements = partition_email(filename=self._file_path, api_url=self._api_url)

View File

@ -1,5 +1,4 @@
import logging
from typing import List
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
@ -34,7 +33,7 @@ class UnstructuredMarkdownLoader(BaseLoader):
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
def load(self) -> list[Document]:
from unstructured.partition.md import partition_md
elements = partition_md(filename=self._file_path, api_url=self._api_url)

View File

@ -1,5 +1,4 @@
import logging
from typing import List
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
@ -24,7 +23,7 @@ class UnstructuredMsgLoader(BaseLoader):
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
def load(self) -> list[Document]:
from unstructured.partition.msg import partition_msg
elements = partition_msg(filename=self._file_path, api_url=self._api_url)

View File

@ -1,5 +1,4 @@
import logging
from typing import List
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
@ -23,7 +22,7 @@ class UnstructuredPPTLoader(BaseLoader):
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
def load(self) -> list[Document]:
from unstructured.partition.ppt import partition_ppt
elements = partition_ppt(filename=self._file_path, api_url=self._api_url)

View File

@ -1,5 +1,4 @@
import logging
from typing import List
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
@ -22,7 +21,7 @@ class UnstructuredPPTXLoader(BaseLoader):
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
def load(self) -> list[Document]:
from unstructured.partition.pptx import partition_pptx
elements = partition_pptx(filename=self._file_path, api_url=self._api_url)

View File

@ -1,5 +1,4 @@
import logging
from typing import List
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
@ -24,7 +23,7 @@ class UnstructuredTextLoader(BaseLoader):
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
def load(self) -> list[Document]:
from unstructured.partition.text import partition_text
elements = partition_text(filename=self._file_path, api_url=self._api_url)

View File

@ -1,5 +1,4 @@
import logging
from typing import List
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
@ -24,7 +23,7 @@ class UnstructuredXmlLoader(BaseLoader):
self._file_path = file_path
self._api_url = api_url
def load(self) -> List[Document]:
def load(self) -> list[Document]:
from unstructured.partition.xml import partition_xml
elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url)

View File

@ -1,4 +1,5 @@
from typing import Any, Dict, Optional, Sequence, cast
from collections.abc import Sequence
from typing import Any, Optional, cast
from langchain.schema import Document
from sqlalchemy import func
@ -22,10 +23,10 @@ class DatasetDocumentStore:
self._document_id = document_id
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "DatasetDocumentStore":
def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore":
return cls(**config_dict)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Serialize to dict."""
return {
"dataset_id": self._dataset.id,
@ -40,7 +41,7 @@ class DatasetDocumentStore:
return self._user_id
@property
def docs(self) -> Dict[str, Document]:
def docs(self) -> dict[str, Document]:
document_segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id
).all()

View File

@ -1,6 +1,6 @@
import base64
import logging
from typing import List, Optional, cast
from typing import Optional, cast
import numpy as np
from langchain.embeddings.base import Embeddings
@ -21,7 +21,7 @@ class CacheEmbedding(Embeddings):
self._model_instance = model_instance
self._user = user
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs in batches of 10."""
text_embeddings = []
try:
@ -52,7 +52,7 @@ class CacheEmbedding(Embeddings):
return text_embeddings
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)

View File

@ -1,8 +1,9 @@
import datetime
import json
import logging
from collections.abc import Iterator
from json import JSONDecodeError
from typing import Dict, Iterator, List, Optional, Tuple
from typing import Optional
from pydantic import BaseModel
@ -135,7 +136,7 @@ class ProviderConfiguration(BaseModel):
if self.provider.provider_credential_schema else []
)
def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]:
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
"""
Validate custom credentials.
:param credentials: provider credentials
@ -282,7 +283,7 @@ class ProviderConfiguration(BaseModel):
return None
def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
-> Tuple[ProviderModel, dict]:
-> tuple[ProviderModel, dict]:
"""
Validate custom model credentials.
@ -711,7 +712,7 @@ class ProviderConfigurations(BaseModel):
Model class for provider configuration dict.
"""
tenant_id: str
configurations: Dict[str, ProviderConfiguration] = {}
configurations: dict[str, ProviderConfiguration] = {}
def __init__(self, tenant_id: str):
super().__init__(tenant_id=tenant_id)
@ -759,7 +760,7 @@ class ProviderConfigurations(BaseModel):
return all_models
def to_list(self) -> List[ProviderConfiguration]:
def to_list(self) -> list[ProviderConfiguration]:
"""
Convert to list.

View File

@ -61,7 +61,7 @@ class Extensible:
builtin_file_path = os.path.join(subdir_path, '__builtin__')
if os.path.exists(builtin_file_path):
with open(builtin_file_path, 'r', encoding='utf-8') as f:
with open(builtin_file_path, encoding='utf-8') as f:
position = int(f.read().strip())
if (extension_name + '.py') not in file_names:
@ -93,7 +93,7 @@ class Extensible:
json_path = os.path.join(subdir_path, 'schema.json')
json_data = {}
if os.path.exists(json_path):
with open(json_path, 'r', encoding='utf-8') as f:
with open(json_path, encoding='utf-8') as f:
json_data = json.load(f)
extensions[extension_name] = ModuleExtension(

View File

@ -2,7 +2,7 @@ import json
import logging
from datetime import datetime
from mimetypes import guess_extension
from typing import List, Optional, Tuple, Union, cast
from typing import Optional, Union, cast
from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager
@ -50,7 +50,7 @@ class BaseAssistantApplicationRunner(AppRunner):
message: Message,
user_id: str,
memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[List[PromptMessage]] = None,
prompt_messages: Optional[list[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None
@ -122,7 +122,7 @@ class BaseAssistantApplicationRunner(AppRunner):
return app_orchestration_config
def _convert_tool_response_to_str(self, tool_response: List[ToolInvokeMessage]) -> str:
def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
"""
Handle tool response
"""
@ -140,7 +140,7 @@ class BaseAssistantApplicationRunner(AppRunner):
return result
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> Tuple[PromptMessageTool, Tool]:
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
"""
convert tool to prompt message tool
"""
@ -325,7 +325,7 @@ class BaseAssistantApplicationRunner(AppRunner):
return prompt_tool
def extract_tool_response_binary(self, tool_response: List[ToolInvokeMessage]) -> List[ToolInvokeMessageBinary]:
def extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
"""
Extract tool response binary
"""
@ -356,7 +356,7 @@ class BaseAssistantApplicationRunner(AppRunner):
return result
def create_message_files(self, messages: List[ToolInvokeMessageBinary]) -> List[Tuple[MessageFile, bool]]:
def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[tuple[MessageFile, bool]]:
"""
Create message file
@ -404,7 +404,7 @@ class BaseAssistantApplicationRunner(AppRunner):
return result
def create_agent_thought(self, message_id: str, message: str,
tool_name: str, tool_input: str, messages_ids: List[str]
tool_name: str, tool_input: str, messages_ids: list[str]
) -> MessageAgentThought:
"""
Create agent thought
@ -449,7 +449,7 @@ class BaseAssistantApplicationRunner(AppRunner):
thought: str,
observation: str,
answer: str,
messages_ids: List[str],
messages_ids: list[str],
llm_usage: LLMUsage = None) -> MessageAgentThought:
"""
Save agent thought
@ -505,7 +505,7 @@ class BaseAssistantApplicationRunner(AppRunner):
db.session.commit()
def get_history_prompt_messages(self) -> List[PromptMessage]:
def get_history_prompt_messages(self) -> list[PromptMessage]:
"""
Get history prompt messages
"""
@ -516,7 +516,7 @@ class BaseAssistantApplicationRunner(AppRunner):
return self.history_prompt_messages
def transform_tool_invoke_messages(self, messages: List[ToolInvokeMessage]) -> List[ToolInvokeMessage]:
def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
"""
Transform tool message into agent thought
"""

View File

@ -1,6 +1,7 @@
import json
import re
from typing import Dict, Generator, List, Literal, Union
from collections.abc import Generator
from typing import Literal, Union
from core.application_queue_manager import PublishFrom
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
@ -29,7 +30,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
def run(self, conversation: Conversation,
message: Message,
query: str,
inputs: Dict[str, str],
inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
"""
Run Cot agent application
@ -37,7 +38,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
app_orchestration_config = self.app_orchestration_config
self._repack_app_orchestration_config(app_orchestration_config)
agent_scratchpad: List[AgentScratchpadUnit] = []
agent_scratchpad: list[AgentScratchpadUnit] = []
# check model mode
if self.app_orchestration_config.model_config.mode == "completion":
@ -56,7 +57,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
prompt_messages = self.history_prompt_messages
# convert tools into ModelRuntime Tool format
prompt_messages_tools: List[PromptMessageTool] = []
prompt_messages_tools: list[PromptMessageTool] = []
tool_instances = {}
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
try:
@ -83,7 +84,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
}
final_answer = ''
def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict['usage']:
final_llm_usage_dict['usage'] = usage
else:
@ -493,7 +494,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
if not next_iteration.find("{{observation}}") >= 0:
raise ValueError("{{observation}} is required in next_iteration")
def _convert_scratchpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str:
def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
"""
convert agent scratchpad list to str
"""
@ -506,13 +507,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
return result
def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
prompt_messages: List[PromptMessage],
tools: List[PromptMessageTool],
agent_scratchpad: List[AgentScratchpadUnit],
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
agent_scratchpad: list[AgentScratchpadUnit],
agent_prompt_message: AgentPromptEntity,
instruction: str,
input: str,
) -> List[PromptMessage]:
) -> list[PromptMessage]:
"""
organize chain of thought prompt messages, a standard prompt message is like:
Respond to the human as helpfully and accurately as possible.

View File

@ -1,6 +1,7 @@
import json
import logging
from typing import Any, Dict, Generator, List, Tuple, Union
from collections.abc import Generator
from typing import Any, Union
from core.application_queue_manager import PublishFrom
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
@ -44,7 +45,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
)
# convert tools into ModelRuntime Tool format
prompt_messages_tools: List[PromptMessageTool] = []
prompt_messages_tools: list[PromptMessageTool] = []
tool_instances = {}
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
try:
@ -70,13 +71,13 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
# continue to run until there is not any tool call
function_call_state = True
agent_thoughts: List[MessageAgentThought] = []
agent_thoughts: list[MessageAgentThought] = []
llm_usage = {
'usage': None
}
final_answer = ''
def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict['usage']:
final_llm_usage_dict['usage'] = usage
else:
@ -117,7 +118,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
callbacks=[],
)
tool_calls: List[Tuple[str, str, Dict[str, Any]]] = []
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
# save full response
response = ''
@ -364,7 +365,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
return True
return False
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
"""
Extract tool calls from llm result chunk
@ -381,7 +382,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
"""
Extract blocking tool calls from llm result

View File

@ -1,4 +1,4 @@
from typing import List, Optional, cast
from typing import Optional, cast
from langchain.tools import BaseTool
@ -96,7 +96,7 @@ class DatasetRetrievalFeature:
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler) \
-> Optional[List[BaseTool]]:
-> Optional[list[BaseTool]]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tenant_id: tenant id

View File

@ -2,7 +2,7 @@ import concurrent
import json
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Tuple
from typing import Optional
from flask import Flask, current_app
@ -62,7 +62,7 @@ class ExternalDataFetchFeature:
app_id: str,
external_data_tool: ExternalDataVariableEntity,
inputs: dict,
query: str) -> Tuple[Optional[str], Optional[str]]:
query: str) -> tuple[Optional[str], Optional[str]]:
"""
Query external data tool.
:param flask_app: flask app

View File

@ -1,5 +1,4 @@
import logging
from typing import Tuple
from core.entities.application_entities import AppOrchestrationConfigEntity
from core.moderation.base import ModerationAction, ModerationException
@ -13,7 +12,7 @@ class ModerationFeature:
tenant_id: str,
app_orchestration_config_entity: AppOrchestrationConfigEntity,
inputs: dict,
query: str) -> Tuple[bool, dict, str]:
query: str) -> tuple[bool, dict, str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Union
from typing import Optional, Union
import requests
@ -15,8 +15,8 @@ class MessageFileParser:
self.tenant_id = tenant_id
self.app_id = app_id
def validate_and_transform_files_arg(self, files: List[dict], app_model_config: AppModelConfig,
user: Union[Account, EndUser]) -> List[FileObj]:
def validate_and_transform_files_arg(self, files: list[dict], app_model_config: AppModelConfig,
user: Union[Account, EndUser]) -> list[FileObj]:
"""
validate and transform files arg
@ -96,7 +96,7 @@ class MessageFileParser:
# return all file objs
return new_files
def transform_message_files(self, files: List[MessageFile], app_model_config: Optional[AppModelConfig]) -> List[FileObj]:
def transform_message_files(self, files: list[MessageFile], app_model_config: Optional[AppModelConfig]) -> list[FileObj]:
"""
transform message files
@ -110,8 +110,8 @@ class MessageFileParser:
# return all file objs
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
def _to_file_objs(self, files: List[Union[Dict, MessageFile]],
file_upload_config: dict) -> Dict[FileType, List[FileObj]]:
def _to_file_objs(self, files: list[Union[dict, MessageFile]],
file_upload_config: dict) -> dict[FileType, list[FileObj]]:
"""
transform files to file objs
@ -119,7 +119,7 @@ class MessageFileParser:
:param file_upload_config:
:return:
"""
type_file_objs: Dict[FileType, List[FileObj]] = {
type_file_objs: dict[FileType, list[FileObj]] = {
# Currently only support image
FileType.IMAGE: []
}

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, List
from typing import Any
from langchain.schema import BaseRetriever, Document
@ -53,7 +53,7 @@ class BaseIndex(ABC):
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
) -> list[Document]:
raise NotImplementedError
def delete(self) -> None:

View File

@ -1,5 +1,4 @@
import re
from typing import Set
import jieba
from jieba.analyse import default_tfidf
@ -12,7 +11,7 @@ class JiebaKeywordTableHandler:
def __init__(self):
default_tfidf.stop_words = STOPWORDS
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]:
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> set[str]:
"""Extract keywords with JIEBA tfidf."""
keywords = jieba.analyse.extract_tags(
sentence=text,
@ -21,7 +20,7 @@ class JiebaKeywordTableHandler:
return set(self._expand_tokens_with_subtokens(keywords))
def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]:
def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
"""Get subtokens from a list of tokens., filtering for stopwords."""
results = set()
for token in tokens:

View File

@ -1,6 +1,6 @@
import json
from collections import defaultdict
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from langchain.schema import BaseRetriever, Document
from pydantic import BaseModel, Extra, Field
@ -116,7 +116,7 @@ class KeywordTableIndex(BaseIndex):
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
) -> list[Document]:
keyword_table = self._get_dataset_keyword_table()
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
@ -221,7 +221,7 @@ class KeywordTableIndex(BaseIndex):
keywords = keyword_table_handler.extract_keywords(query)
# go through text chunks in order of most matching keywords
chunk_indices_count: Dict[str, int] = defaultdict(int)
chunk_indices_count: dict[str, int] = defaultdict(int)
keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
for keyword in keywords:
for node_id in keyword_table[keyword]:
@ -235,7 +235,7 @@ class KeywordTableIndex(BaseIndex):
return sorted_chunk_indices[: k]
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: List[str]):
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
document_segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id == node_id
@ -244,7 +244,7 @@ class KeywordTableIndex(BaseIndex):
document_segment.keywords = keywords
db.session.commit()
def create_segment_keywords(self, node_id: str, keywords: List[str]):
def create_segment_keywords(self, node_id: str, keywords: list[str]):
keyword_table = self._get_dataset_keyword_table()
self._update_segment_keywords(self.dataset.id, node_id, keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
@ -266,7 +266,7 @@ class KeywordTableIndex(BaseIndex):
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
self._save_dataset_keyword_table(keyword_table)
def update_segment_keywords_index(self, node_id: str, keywords: List[str]):
def update_segment_keywords_index(self, node_id: str, keywords: list[str]):
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)
@ -282,7 +282,7 @@ class KeywordTableRetriever(BaseRetriever, BaseModel):
extra = Extra.forbid
arbitrary_types_allowed = True
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(self, query: str) -> list[Document]:
"""Get documents relevant for a query.
Args:
@ -293,7 +293,7 @@ class KeywordTableRetriever(BaseRetriever, BaseModel):
"""
return self.index.search(query, **self.search_kwargs)
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(self, query: str) -> list[Document]:
raise NotImplementedError("KeywordTableRetriever does not support async")

View File

@ -1,7 +1,7 @@
import json
import logging
from abc import abstractmethod
from typing import Any, List, cast
from typing import Any, cast
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
@ -43,13 +43,13 @@ class BaseVectorIndex(BaseIndex):
def search_by_full_text_index(
self, query: str,
**kwargs: Any
) -> List[Document]:
) -> list[Document]:
raise NotImplementedError
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
) -> list[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)

View File

@ -1,4 +1,4 @@
from typing import Any, List, cast
from typing import Any, cast
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
@ -160,6 +160,6 @@ class MilvusVectorIndex(BaseVectorIndex):
],
))
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz doesn't support bm25 search
return []

View File

@ -1,5 +1,5 @@
import os
from typing import Any, List, Optional, cast
from typing import Any, Optional, cast
import qdrant_client
from langchain.embeddings.base import Embeddings
@ -210,7 +210,7 @@ class QdrantVectorIndex(BaseVectorIndex):
return False
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional, cast
from typing import Any, Optional, cast
import requests
import weaviate
@ -172,7 +172,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
return False
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)

View File

@ -5,7 +5,7 @@ import re
import threading
import time
import uuid
from typing import List, Optional, cast
from typing import Optional, cast
from flask import Flask, current_app
from flask_login import current_user
@ -40,7 +40,7 @@ class IndexingRunner:
self.storage = storage
self.model_manager = ModelManager()
def run(self, dataset_documents: List[DatasetDocument]):
def run(self, dataset_documents: list[DatasetDocument]):
"""Run the indexing process."""
for dataset_document in dataset_documents:
try:
@ -238,7 +238,7 @@ class IndexingRunner:
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
def file_indexing_estimate(self, tenant_id: str, file_details: list[UploadFile], tmp_processing_rule: dict,
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
indexing_technique: str = 'economy') -> dict:
"""
@ -494,7 +494,7 @@ class IndexingRunner:
"preview": preview_texts
}
def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> List[Document]:
def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> list[Document]:
# load file
if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
return []
@ -526,7 +526,7 @@ class IndexingRunner:
)
# replace doc id to document model id
text_docs = cast(List[Document], text_docs)
text_docs = cast(list[Document], text_docs)
for text_doc in text_docs:
# remove invalid symbol
text_doc.page_content = self.filter_string(text_doc.page_content)
@ -540,7 +540,7 @@ class IndexingRunner:
text = re.sub(r'\|>', '>', text)
text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
# Unicode U+FFFE
text = re.sub(u'\uFFFE', '', text)
text = re.sub('\uFFFE', '', text)
return text
def _get_splitter(self, processing_rule: DatasetProcessRule,
@ -577,9 +577,9 @@ class IndexingRunner:
return character_splitter
def _step_split(self, text_docs: List[Document], splitter: TextSplitter,
def _step_split(self, text_docs: list[Document], splitter: TextSplitter,
dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
-> List[Document]:
-> list[Document]:
"""
Split the text documents into documents and save them to the document segment.
"""
@ -624,9 +624,9 @@ class IndexingRunner:
return documents
def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter,
processing_rule: DatasetProcessRule, tenant_id: str,
document_form: str, document_language: str) -> List[Document]:
document_form: str, document_language: str) -> list[Document]:
"""
Split the text documents into nodes.
"""
@ -699,8 +699,8 @@ class IndexingRunner:
all_qa_documents.extend(format_documents)
def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
processing_rule: DatasetProcessRule) -> List[Document]:
def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: TextSplitter,
processing_rule: DatasetProcessRule) -> list[Document]:
"""
Split the text documents into nodes.
"""
@ -770,7 +770,7 @@ class IndexingRunner:
for q, a in matches if q and a
]
def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None:
"""
Build the index for the document.
"""
@ -877,7 +877,7 @@ class IndexingRunner:
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit()
def batch_add_segments(self, segments: List[DocumentSegment], dataset: Dataset):
def batch_add_segments(self, segments: list[DocumentSegment], dataset: Dataset):
"""
Batch add segments index processing
"""

View File

@ -1,4 +1,5 @@
from typing import IO, Generator, List, Optional, Union, cast
from collections.abc import Generator
from typing import IO, Optional, Union, cast
from core.entities.provider_configuration import ProviderModelBundle
from core.errors.error import ProviderTokenNotInitError
@ -47,7 +48,7 @@ class ModelInstance:
return credentials
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
-> Union[LLMResult, Generator]:
"""

View File

@ -1,5 +1,5 @@
from abc import ABC
from typing import List, Optional
from typing import Optional
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
@ -23,7 +23,7 @@ class Callback(ABC):
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
Before invoke callback
@ -42,7 +42,7 @@ class Callback(ABC):
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None):
"""
On new chunk callback
@ -62,7 +62,7 @@ class Callback(ABC):
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
After invoke callback
@ -82,7 +82,7 @@ class Callback(ABC):
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
Invoke error callback

View File

@ -1,7 +1,7 @@
import json
import logging
import sys
from typing import List, Optional
from typing import Optional
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
class LoggingCallback(Callback):
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
Before invoke callback
@ -60,7 +60,7 @@ class LoggingCallback(Callback):
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None):
"""
On new chunk callback
@ -81,7 +81,7 @@ class LoggingCallback(Callback):
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
After invoke callback
@ -113,7 +113,7 @@ class LoggingCallback(Callback):
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
Invoke error callback

View File

@ -1,8 +1,7 @@
from typing import Dict
from core.model_runtime.entities.model_entities import DefaultParameterName
PARAMETER_RULE_TEMPLATE: Dict[DefaultParameterName, dict] = {
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
DefaultParameterName.TEMPERATURE: {
'label': {
'en_US': 'Temperature',

View File

@ -153,7 +153,7 @@ class AIModel(ABC):
# read _position.yaml file
position_map = {}
if os.path.exists(position_file_path):
with open(position_file_path, 'r', encoding='utf-8') as f:
with open(position_file_path, encoding='utf-8') as f:
positions = yaml.safe_load(f)
# convert list to dict with key as model provider name, value as index
position_map = {position: index for index, position in enumerate(positions)}
@ -161,7 +161,7 @@ class AIModel(ABC):
# traverse all model_schema_yaml_paths
for model_schema_yaml_path in model_schema_yaml_paths:
# read yaml data from yaml file
with open(model_schema_yaml_path, 'r', encoding='utf-8') as f:
with open(model_schema_yaml_path, encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
new_parameter_rules = []

View File

@ -3,7 +3,8 @@ import os
import re
import time
from abc import abstractmethod
from typing import Generator, List, Optional, Union
from collections.abc import Generator
from typing import Optional, Union
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.callbacks.logging_callback import LoggingCallback
@ -29,7 +30,7 @@ class LargeLanguageModel(AIModel):
def invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
-> Union[LLMResult, Generator]:
"""
@ -122,7 +123,7 @@ class LargeLanguageModel(AIModel):
def _invoke_result_generator(self, model: str, result: Generator, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[List[str]] = None, stream: bool = True,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator:
"""
Invoke result generator
@ -186,7 +187,7 @@ class LargeLanguageModel(AIModel):
@abstractmethod
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
@ -218,7 +219,7 @@ class LargeLanguageModel(AIModel):
"""
raise NotImplementedError
def enforce_stop_tokens(self, text: str, stop: List[str]) -> str:
def enforce_stop_tokens(self, text: str, stop: list[str]) -> str:
"""Cut off the text as soon as any stop words occur."""
return re.split("|".join(stop), text, maxsplit=1)[0]
@ -329,7 +330,7 @@ class LargeLanguageModel(AIModel):
def _trigger_before_invoke_callbacks(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[List[str]] = None, stream: bool = True,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
"""
Trigger before invoke callbacks
@ -367,7 +368,7 @@ class LargeLanguageModel(AIModel):
def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[List[str]] = None, stream: bool = True,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
"""
Trigger new chunk callbacks
@ -406,7 +407,7 @@ class LargeLanguageModel(AIModel):
def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[List[str]] = None, stream: bool = True,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
"""
Trigger after invoke callbacks
@ -446,7 +447,7 @@ class LargeLanguageModel(AIModel):
def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[List[str]] = None, stream: bool = True,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
"""
Trigger invoke error callbacks
@ -527,7 +528,7 @@ class LargeLanguageModel(AIModel):
raise ValueError(
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.")
elif parameter_rule.type == ParameterType.FLOAT:
if not isinstance(parameter_value, (float, int)):
if not isinstance(parameter_value, float | int):
raise ValueError(f"Model Parameter {parameter_name} should be float.")
# validate parameter value precision

View File

@ -1,7 +1,6 @@
import importlib
import os
from abc import ABC, abstractmethod
from typing import Dict
import yaml
@ -12,7 +11,7 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel
class ModelProvider(ABC):
provider_schema: ProviderEntity = None
model_instance_map: Dict[str, AIModel] = {}
model_instance_map: dict[str, AIModel] = {}
@abstractmethod
def validate_provider_credentials(self, credentials: dict) -> None:
@ -47,7 +46,7 @@ class ModelProvider(ABC):
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
yaml_data = {}
if os.path.exists(yaml_path):
with open(yaml_path, 'r', encoding='utf-8') as f:
with open(yaml_path, encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
try:

View File

@ -1,4 +1,5 @@
from typing import Generator, List, Optional, Union
from collections.abc import Generator
from typing import Optional, Union
import anthropic
from anthropic import Anthropic, Stream
@ -29,7 +30,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
@ -90,7 +91,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
stop: Optional[List[str]] = None, stream: bool = True,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -255,7 +256,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
return message_text
def _convert_messages_to_prompt_anthropic(self, messages: List[PromptMessage]) -> str:
def _convert_messages_to_prompt_anthropic(self, messages: list[PromptMessage]) -> str:
"""
Format a list of messages into a full prompt for the Anthropic model

View File

@ -1,6 +1,7 @@
import copy
import logging
from typing import Generator, List, Optional, Union, cast
from collections.abc import Generator
from typing import Optional, Union, cast
import tiktoken
from openai import AzureOpenAI, Stream
@ -34,7 +35,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
@ -121,7 +122,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return ai_model_entity.entity if ai_model_entity else None
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
@ -239,7 +240,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
def _chat_generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
@ -537,7 +538,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
return num_tokens
def _num_tokens_from_messages(self, credentials: dict, messages: List[PromptMessage],
def _num_tokens_from_messages(self, credentials: dict, messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

Some files were not shown because too many files have changed in this diff Show More