refactor(api/core): Improve type hints and apply ruff formatter in agent runner and model manager. (#8166)

This commit is contained in:
-LAN- 2024-09-10 15:00:25 +08:00 committed by GitHub
parent af92f19291
commit ed37439ef7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 199 additions and 197 deletions

View File

@ -1,6 +1,7 @@
import json import json
import logging import logging
import uuid import uuid
from collections.abc import Mapping, Sequence
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Optional, Union, cast from typing import Optional, Union, cast
@ -45,22 +46,25 @@ from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BaseAgentRunner(AppRunner): class BaseAgentRunner(AppRunner):
def __init__(self, tenant_id: str, def __init__(
application_generate_entity: AgentChatAppGenerateEntity, self,
conversation: Conversation, tenant_id: str,
app_config: AgentChatAppConfig, application_generate_entity: AgentChatAppGenerateEntity,
model_config: ModelConfigWithCredentialsEntity, conversation: Conversation,
config: AgentEntity, app_config: AgentChatAppConfig,
queue_manager: AppQueueManager, model_config: ModelConfigWithCredentialsEntity,
message: Message, config: AgentEntity,
user_id: str, queue_manager: AppQueueManager,
memory: Optional[TokenBufferMemory] = None, message: Message,
prompt_messages: Optional[list[PromptMessage]] = None, user_id: str,
variables_pool: Optional[ToolRuntimeVariablePool] = None, memory: Optional[TokenBufferMemory] = None,
db_variables: Optional[ToolConversationVariables] = None, prompt_messages: Optional[list[PromptMessage]] = None,
model_instance: ModelInstance = None variables_pool: Optional[ToolRuntimeVariablePool] = None,
) -> None: db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None,
) -> None:
""" """
Agent runner Agent runner
:param tenant_id: tenant id :param tenant_id: tenant id
@ -88,9 +92,7 @@ class BaseAgentRunner(AppRunner):
self.message = message self.message = message
self.user_id = user_id self.user_id = user_id
self.memory = memory self.memory = memory
self.history_prompt_messages = self.organize_agent_history( self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
prompt_messages=prompt_messages or []
)
self.variables_pool = variables_pool self.variables_pool = variables_pool
self.db_variables_pool = db_variables self.db_variables_pool = db_variables
self.model_instance = model_instance self.model_instance = model_instance
@ -111,12 +113,16 @@ class BaseAgentRunner(AppRunner):
retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None, retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
return_resource=app_config.additional_features.show_retrieve_source, return_resource=app_config.additional_features.show_retrieve_source,
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
hit_callback=hit_callback hit_callback=hit_callback,
) )
# get how many agent thoughts have been created # get how many agent thoughts have been created
self.agent_thought_count = db.session.query(MessageAgentThought).filter( self.agent_thought_count = (
MessageAgentThought.message_id == self.message.id, db.session.query(MessageAgentThought)
).count() .filter(
MessageAgentThought.message_id == self.message.id,
)
.count()
)
db.session.close() db.session.close()
# check if model supports stream tool call # check if model supports stream tool call
@ -135,25 +141,26 @@ class BaseAgentRunner(AppRunner):
self.query = None self.query = None
self._current_thoughts: list[PromptMessage] = [] self._current_thoughts: list[PromptMessage] = []
def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \ def _repack_app_generate_entity(
-> AgentChatAppGenerateEntity: self, app_generate_entity: AgentChatAppGenerateEntity
) -> AgentChatAppGenerateEntity:
""" """
Repack app generate entity Repack app generate entity
""" """
if app_generate_entity.app_config.prompt_template.simple_prompt_template is None: if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
app_generate_entity.app_config.prompt_template.simple_prompt_template = '' app_generate_entity.app_config.prompt_template.simple_prompt_template = ""
return app_generate_entity return app_generate_entity
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]: def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
""" """
convert tool to prompt message tool convert tool to prompt message tool
""" """
tool_entity = ToolManager.get_agent_tool_runtime( tool_entity = ToolManager.get_agent_tool_runtime(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
app_id=self.app_config.app_id, app_id=self.app_config.app_id,
agent_tool=tool, agent_tool=tool,
invoke_from=self.application_generate_entity.invoke_from invoke_from=self.application_generate_entity.invoke_from,
) )
tool_entity.load_variables(self.variables_pool) tool_entity.load_variables(self.variables_pool)
@ -164,7 +171,7 @@ class BaseAgentRunner(AppRunner):
"type": "object", "type": "object",
"properties": {}, "properties": {},
"required": [], "required": [],
} },
) )
parameters = tool_entity.get_all_runtime_parameters() parameters = tool_entity.get_all_runtime_parameters()
@ -177,19 +184,19 @@ class BaseAgentRunner(AppRunner):
if parameter.type == ToolParameter.ToolParameterType.SELECT: if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options] enum = [option.value for option in parameter.options]
message_tool.parameters['properties'][parameter.name] = { message_tool.parameters["properties"][parameter.name] = {
"type": parameter_type, "type": parameter_type,
"description": parameter.llm_description or '', "description": parameter.llm_description or "",
} }
if len(enum) > 0: if len(enum) > 0:
message_tool.parameters['properties'][parameter.name]['enum'] = enum message_tool.parameters["properties"][parameter.name]["enum"] = enum
if parameter.required: if parameter.required:
message_tool.parameters['required'].append(parameter.name) message_tool.parameters["required"].append(parameter.name)
return message_tool, tool_entity return message_tool, tool_entity
def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool: def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
""" """
convert dataset retriever tool to prompt message tool convert dataset retriever tool to prompt message tool
@ -201,24 +208,24 @@ class BaseAgentRunner(AppRunner):
"type": "object", "type": "object",
"properties": {}, "properties": {},
"required": [], "required": [],
} },
) )
for parameter in tool.get_runtime_parameters(): for parameter in tool.get_runtime_parameters():
parameter_type = 'string' parameter_type = "string"
prompt_tool.parameters['properties'][parameter.name] = { prompt_tool.parameters["properties"][parameter.name] = {
"type": parameter_type, "type": parameter_type,
"description": parameter.llm_description or '', "description": parameter.llm_description or "",
} }
if parameter.required: if parameter.required:
if parameter.name not in prompt_tool.parameters['required']: if parameter.name not in prompt_tool.parameters["required"]:
prompt_tool.parameters['required'].append(parameter.name) prompt_tool.parameters["required"].append(parameter.name)
return prompt_tool return prompt_tool
def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]: def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
""" """
Init tools Init tools
""" """
@ -261,51 +268,51 @@ class BaseAgentRunner(AppRunner):
enum = [] enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT: if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options] enum = [option.value for option in parameter.options]
prompt_tool.parameters['properties'][parameter.name] = { prompt_tool.parameters["properties"][parameter.name] = {
"type": parameter_type, "type": parameter_type,
"description": parameter.llm_description or '', "description": parameter.llm_description or "",
} }
if len(enum) > 0: if len(enum) > 0:
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
if parameter.required: if parameter.required:
if parameter.name not in prompt_tool.parameters['required']: if parameter.name not in prompt_tool.parameters["required"]:
prompt_tool.parameters['required'].append(parameter.name) prompt_tool.parameters["required"].append(parameter.name)
return prompt_tool return prompt_tool
def create_agent_thought(self, message_id: str, message: str, def create_agent_thought(
tool_name: str, tool_input: str, messages_ids: list[str] self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
) -> MessageAgentThought: ) -> MessageAgentThought:
""" """
Create agent thought Create agent thought
""" """
thought = MessageAgentThought( thought = MessageAgentThought(
message_id=message_id, message_id=message_id,
message_chain_id=None, message_chain_id=None,
thought='', thought="",
tool=tool_name, tool=tool_name,
tool_labels_str='{}', tool_labels_str="{}",
tool_meta_str='{}', tool_meta_str="{}",
tool_input=tool_input, tool_input=tool_input,
message=message, message=message,
message_token=0, message_token=0,
message_unit_price=0, message_unit_price=0,
message_price_unit=0, message_price_unit=0,
message_files=json.dumps(messages_ids) if messages_ids else '', message_files=json.dumps(messages_ids) if messages_ids else "",
answer='', answer="",
observation='', observation="",
answer_token=0, answer_token=0,
answer_unit_price=0, answer_unit_price=0,
answer_price_unit=0, answer_price_unit=0,
tokens=0, tokens=0,
total_price=0, total_price=0,
position=self.agent_thought_count + 1, position=self.agent_thought_count + 1,
currency='USD', currency="USD",
latency=0, latency=0,
created_by_role='account', created_by_role="account",
created_by=self.user_id, created_by=self.user_id,
) )
@ -318,22 +325,22 @@ class BaseAgentRunner(AppRunner):
return thought return thought
def save_agent_thought(self, def save_agent_thought(
agent_thought: MessageAgentThought, self,
tool_name: str, agent_thought: MessageAgentThought,
tool_input: Union[str, dict], tool_name: str,
thought: str, tool_input: Union[str, dict],
observation: Union[str, dict], thought: str,
tool_invoke_meta: Union[str, dict], observation: Union[str, dict],
answer: str, tool_invoke_meta: Union[str, dict],
messages_ids: list[str], answer: str,
llm_usage: LLMUsage = None) -> MessageAgentThought: messages_ids: list[str],
llm_usage: LLMUsage = None,
) -> MessageAgentThought:
""" """
Save agent thought Save agent thought
""" """
agent_thought = db.session.query(MessageAgentThought).filter( agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
MessageAgentThought.id == agent_thought.id
).first()
if thought is not None: if thought is not None:
agent_thought.thought = thought agent_thought.thought = thought
@ -356,7 +363,7 @@ class BaseAgentRunner(AppRunner):
observation = json.dumps(observation, ensure_ascii=False) observation = json.dumps(observation, ensure_ascii=False)
except Exception as e: except Exception as e:
observation = json.dumps(observation) observation = json.dumps(observation)
agent_thought.observation = observation agent_thought.observation = observation
if answer is not None: if answer is not None:
@ -364,7 +371,7 @@ class BaseAgentRunner(AppRunner):
if messages_ids is not None and len(messages_ids) > 0: if messages_ids is not None and len(messages_ids) > 0:
agent_thought.message_files = json.dumps(messages_ids) agent_thought.message_files = json.dumps(messages_ids)
if llm_usage: if llm_usage:
agent_thought.message_token = llm_usage.prompt_tokens agent_thought.message_token = llm_usage.prompt_tokens
agent_thought.message_price_unit = llm_usage.prompt_price_unit agent_thought.message_price_unit = llm_usage.prompt_price_unit
@ -377,7 +384,7 @@ class BaseAgentRunner(AppRunner):
# check if tool labels is not empty # check if tool labels is not empty
labels = agent_thought.tool_labels or {} labels = agent_thought.tool_labels or {}
tools = agent_thought.tool.split(';') if agent_thought.tool else [] tools = agent_thought.tool.split(";") if agent_thought.tool else []
for tool in tools: for tool in tools:
if not tool: if not tool:
continue continue
@ -386,7 +393,7 @@ class BaseAgentRunner(AppRunner):
if tool_label: if tool_label:
labels[tool] = tool_label.to_dict() labels[tool] = tool_label.to_dict()
else: else:
labels[tool] = {'en_US': tool, 'zh_Hans': tool} labels[tool] = {"en_US": tool, "zh_Hans": tool}
agent_thought.tool_labels_str = json.dumps(labels) agent_thought.tool_labels_str = json.dumps(labels)
@ -401,14 +408,18 @@ class BaseAgentRunner(AppRunner):
db.session.commit() db.session.commit()
db.session.close() db.session.close()
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables): def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
""" """
convert tool variables to db variables convert tool variables to db variables
""" """
db_variables = db.session.query(ToolConversationVariables).filter( db_variables = (
ToolConversationVariables.conversation_id == self.message.conversation_id, db.session.query(ToolConversationVariables)
).first() .filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
)
.first()
)
db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
@ -425,9 +436,14 @@ class BaseAgentRunner(AppRunner):
if isinstance(prompt_message, SystemPromptMessage): if isinstance(prompt_message, SystemPromptMessage):
result.append(prompt_message) result.append(prompt_message)
messages: list[Message] = db.session.query(Message).filter( messages: list[Message] = (
Message.conversation_id == self.message.conversation_id, db.session.query(Message)
).order_by(Message.created_at.asc()).all() .filter(
Message.conversation_id == self.message.conversation_id,
)
.order_by(Message.created_at.asc())
.all()
)
for message in messages: for message in messages:
if message.id == self.message.id: if message.id == self.message.id:
@ -439,13 +455,13 @@ class BaseAgentRunner(AppRunner):
for agent_thought in agent_thoughts: for agent_thought in agent_thoughts:
tools = agent_thought.tool tools = agent_thought.tool
if tools: if tools:
tools = tools.split(';') tools = tools.split(";")
tool_calls: list[AssistantPromptMessage.ToolCall] = [] tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = [] tool_call_response: list[ToolPromptMessage] = []
try: try:
tool_inputs = json.loads(agent_thought.tool_input) tool_inputs = json.loads(agent_thought.tool_input)
except Exception as e: except Exception as e:
tool_inputs = { tool: {} for tool in tools } tool_inputs = {tool: {} for tool in tools}
try: try:
tool_responses = json.loads(agent_thought.observation) tool_responses = json.loads(agent_thought.observation)
except Exception as e: except Exception as e:
@ -454,27 +470,33 @@ class BaseAgentRunner(AppRunner):
for tool in tools: for tool in tools:
# generate a uuid for tool call # generate a uuid for tool call
tool_call_id = str(uuid.uuid4()) tool_call_id = str(uuid.uuid4())
tool_calls.append(AssistantPromptMessage.ToolCall( tool_calls.append(
id=tool_call_id, AssistantPromptMessage.ToolCall(
type='function', id=tool_call_id,
function=AssistantPromptMessage.ToolCall.ToolCallFunction( type="function",
name=tool, function=AssistantPromptMessage.ToolCall.ToolCallFunction(
arguments=json.dumps(tool_inputs.get(tool, {})), name=tool,
arguments=json.dumps(tool_inputs.get(tool, {})),
),
) )
)) )
tool_call_response.append(ToolPromptMessage( tool_call_response.append(
content=tool_responses.get(tool, agent_thought.observation), ToolPromptMessage(
name=tool, content=tool_responses.get(tool, agent_thought.observation),
tool_call_id=tool_call_id, name=tool,
)) tool_call_id=tool_call_id,
)
)
result.extend([ result.extend(
AssistantPromptMessage( [
content=agent_thought.thought, AssistantPromptMessage(
tool_calls=tool_calls, content=agent_thought.thought,
), tool_calls=tool_calls,
*tool_call_response ),
]) *tool_call_response,
]
)
if not tools: if not tools:
result.append(AssistantPromptMessage(content=agent_thought.thought)) result.append(AssistantPromptMessage(content=agent_thought.thought))
else: else:
@ -496,10 +518,7 @@ class BaseAgentRunner(AppRunner):
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
if file_extra_config: if file_extra_config:
file_objs = message_file_parser.transform_message_files( file_objs = message_file_parser.transform_message_files(files, file_extra_config)
files,
file_extra_config
)
else: else:
file_objs = [] file_objs = []

View File

@ -1,6 +1,6 @@
import logging import logging
import os import os
from collections.abc import Callable, Generator from collections.abc import Callable, Generator, Sequence
from typing import IO, Optional, Union, cast from typing import IO, Optional, Union, cast
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
@ -41,7 +41,7 @@ class ModelInstance:
configuration=provider_model_bundle.configuration, configuration=provider_model_bundle.configuration,
model_type=provider_model_bundle.model_type_instance.model_type, model_type=provider_model_bundle.model_type_instance.model_type,
model=model, model=model,
credentials=self.credentials credentials=self.credentials,
) )
@staticmethod @staticmethod
@ -54,10 +54,7 @@ class ModelInstance:
""" """
configuration = provider_model_bundle.configuration configuration = provider_model_bundle.configuration
model_type = provider_model_bundle.model_type_instance.model_type model_type = provider_model_bundle.model_type_instance.model_type
credentials = configuration.get_current_credentials( credentials = configuration.get_current_credentials(model_type=model_type, model=model)
model_type=model_type,
model=model
)
if credentials is None: if credentials is None:
raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.") raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")
@ -65,10 +62,9 @@ class ModelInstance:
return credentials return credentials
@staticmethod @staticmethod
def _get_load_balancing_manager(configuration: ProviderConfiguration, def _get_load_balancing_manager(
model_type: ModelType, configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict
model: str, ) -> Optional["LBModelManager"]:
credentials: dict) -> Optional["LBModelManager"]:
""" """
Get load balancing model credentials Get load balancing model credentials
:param configuration: provider configuration :param configuration: provider configuration
@ -81,8 +77,7 @@ class ModelInstance:
current_model_setting = None current_model_setting = None
# check if model is disabled by admin # check if model is disabled by admin
for model_setting in configuration.model_settings: for model_setting in configuration.model_settings:
if (model_setting.model_type == model_type if model_setting.model_type == model_type and model_setting.model == model:
and model_setting.model == model):
current_model_setting = model_setting current_model_setting = model_setting
break break
@ -95,17 +90,23 @@ class ModelInstance:
model_type=model_type, model_type=model_type,
model=model, model=model,
load_balancing_configs=current_model_setting.load_balancing_configs, load_balancing_configs=current_model_setting.load_balancing_configs,
managed_credentials=credentials if configuration.custom_configuration.provider else None managed_credentials=credentials if configuration.custom_configuration.provider else None,
) )
return lb_model_manager return lb_model_manager
return None return None
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, def invoke_llm(
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, self,
stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \ prompt_messages: list[PromptMessage],
-> Union[LLMResult, Generator]: model_parameters: Optional[dict] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> Union[LLMResult, Generator]:
""" """
Invoke large language model Invoke large language model
@ -132,11 +133,12 @@ class ModelInstance:
stop=stop, stop=stop,
stream=stream, stream=stream,
user=user, user=user,
callbacks=callbacks callbacks=callbacks,
) )
def get_llm_num_tokens(self, prompt_messages: list[PromptMessage], def get_llm_num_tokens(
tools: Optional[list[PromptMessageTool]] = None) -> int: self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
) -> int:
""" """
Get number of tokens for llm Get number of tokens for llm
@ -153,11 +155,10 @@ class ModelInstance:
model=self.model, model=self.model,
credentials=self.credentials, credentials=self.credentials,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
tools=tools tools=tools,
) )
def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \ def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) -> TextEmbeddingResult:
-> TextEmbeddingResult:
""" """
Invoke large language model Invoke large language model
@ -174,7 +175,7 @@ class ModelInstance:
model=self.model, model=self.model,
credentials=self.credentials, credentials=self.credentials,
texts=texts, texts=texts,
user=user user=user,
) )
def get_text_embedding_num_tokens(self, texts: list[str]) -> int: def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
@ -192,13 +193,17 @@ class ModelInstance:
function=self.model_type_instance.get_num_tokens, function=self.model_type_instance.get_num_tokens,
model=self.model, model=self.model,
credentials=self.credentials, credentials=self.credentials,
texts=texts texts=texts,
) )
def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, def invoke_rerank(
top_n: Optional[int] = None, self,
user: Optional[str] = None) \ query: str,
-> RerankResult: docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
""" """
Invoke rerank model Invoke rerank model
@ -221,11 +226,10 @@ class ModelInstance:
docs=docs, docs=docs,
score_threshold=score_threshold, score_threshold=score_threshold,
top_n=top_n, top_n=top_n,
user=user user=user,
) )
def invoke_moderation(self, text: str, user: Optional[str] = None) \ def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool:
-> bool:
""" """
Invoke moderation model Invoke moderation model
@ -242,11 +246,10 @@ class ModelInstance:
model=self.model, model=self.model,
credentials=self.credentials, credentials=self.credentials,
text=text, text=text,
user=user user=user,
) )
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \ def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str:
-> str:
""" """
Invoke large language model Invoke large language model
@ -263,11 +266,10 @@ class ModelInstance:
model=self.model, model=self.model,
credentials=self.credentials, credentials=self.credentials,
file=file, file=file,
user=user user=user,
) )
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) \ def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> str:
-> str:
""" """
Invoke large language tts model Invoke large language tts model
@ -288,7 +290,7 @@ class ModelInstance:
content_text=content_text, content_text=content_text,
user=user, user=user,
tenant_id=tenant_id, tenant_id=tenant_id,
voice=voice voice=voice,
) )
def _round_robin_invoke(self, function: Callable, *args, **kwargs): def _round_robin_invoke(self, function: Callable, *args, **kwargs):
@ -312,8 +314,8 @@ class ModelInstance:
raise last_exception raise last_exception
try: try:
if 'credentials' in kwargs: if "credentials" in kwargs:
del kwargs['credentials'] del kwargs["credentials"]
return function(*args, **kwargs, credentials=lb_config.credentials) return function(*args, **kwargs, credentials=lb_config.credentials)
except InvokeRateLimitError as e: except InvokeRateLimitError as e:
# expire in 60 seconds # expire in 60 seconds
@ -340,9 +342,7 @@ class ModelInstance:
self.model_type_instance = cast(TTSModel, self.model_type_instance) self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.get_tts_model_voices( return self.model_type_instance.get_tts_model_voices(
model=self.model, model=self.model, credentials=self.credentials, language=language
credentials=self.credentials,
language=language
) )
@ -363,9 +363,7 @@ class ModelManager:
return self.get_default_model_instance(tenant_id, model_type) return self.get_default_model_instance(tenant_id, model_type)
provider_model_bundle = self._provider_manager.get_provider_model_bundle( provider_model_bundle = self._provider_manager.get_provider_model_bundle(
tenant_id=tenant_id, tenant_id=tenant_id, provider=provider, model_type=model_type
provider=provider,
model_type=model_type
) )
return ModelInstance(provider_model_bundle, model) return ModelInstance(provider_model_bundle, model)
@ -386,10 +384,7 @@ class ModelManager:
:param model_type: model type :param model_type: model type
:return: :return:
""" """
default_model_entity = self._provider_manager.get_default_model( default_model_entity = self._provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type)
tenant_id=tenant_id,
model_type=model_type
)
if not default_model_entity: if not default_model_entity:
raise ProviderTokenNotInitError(f"Default model not found for {model_type}") raise ProviderTokenNotInitError(f"Default model not found for {model_type}")
@ -398,17 +393,20 @@ class ModelManager:
tenant_id=tenant_id, tenant_id=tenant_id,
provider=default_model_entity.provider.provider, provider=default_model_entity.provider.provider,
model_type=model_type, model_type=model_type,
model=default_model_entity.model model=default_model_entity.model,
) )
class LBModelManager: class LBModelManager:
def __init__(self, tenant_id: str, def __init__(
provider: str, self,
model_type: ModelType, tenant_id: str,
model: str, provider: str,
load_balancing_configs: list[ModelLoadBalancingConfiguration], model_type: ModelType,
managed_credentials: Optional[dict] = None) -> None: model: str,
load_balancing_configs: list[ModelLoadBalancingConfiguration],
managed_credentials: Optional[dict] = None,
) -> None:
""" """
Load balancing model manager Load balancing model manager
:param tenant_id: tenant_id :param tenant_id: tenant_id
@ -439,10 +437,7 @@ class LBModelManager:
:return: :return:
""" """
cache_key = "model_lb_index:{}:{}:{}:{}".format( cache_key = "model_lb_index:{}:{}:{}:{}".format(
self._tenant_id, self._tenant_id, self._provider, self._model_type.value, self._model
self._provider,
self._model_type.value,
self._model
) )
cooldown_load_balancing_configs = [] cooldown_load_balancing_configs = []
@ -473,10 +468,12 @@ class LBModelManager:
continue continue
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): if bool(os.environ.get("DEBUG", "False").lower() == "true"):
logger.info(f"Model LB\nid: {config.id}\nname:{config.name}\n" logger.info(
f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n" f"Model LB\nid: {config.id}\nname:{config.name}\n"
f"model_type: {self._model_type.value}\nmodel: {self._model}") f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"
f"model_type: {self._model_type.value}\nmodel: {self._model}"
)
return config return config
@ -490,14 +487,10 @@ class LBModelManager:
:return: :return:
""" """
cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format( cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
self._tenant_id, self._tenant_id, self._provider, self._model_type.value, self._model, config.id
self._provider,
self._model_type.value,
self._model,
config.id
) )
redis_client.setex(cooldown_cache_key, expire, 'true') redis_client.setex(cooldown_cache_key, expire, "true")
def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool: def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool:
""" """
@ -506,11 +499,7 @@ class LBModelManager:
:return: :return:
""" """
cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format( cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
self._tenant_id, self._tenant_id, self._provider, self._model_type.value, self._model, config.id
self._provider,
self._model_type.value,
self._model,
config.id
) )
res = redis_client.exists(cooldown_cache_key) res = redis_client.exists(cooldown_cache_key)
@ -518,11 +507,9 @@ class LBModelManager:
return res return res
@staticmethod @staticmethod
def get_config_in_cooldown_and_ttl(tenant_id: str, def get_config_in_cooldown_and_ttl(
provider: str, tenant_id: str, provider: str, model_type: ModelType, model: str, config_id: str
model_type: ModelType, ) -> tuple[bool, int]:
model: str,
config_id: str) -> tuple[bool, int]:
""" """
Get model load balancing config is in cooldown and ttl Get model load balancing config is in cooldown and ttl
:param tenant_id: workspace id :param tenant_id: workspace id
@ -533,11 +520,7 @@ class LBModelManager:
:return: :return:
""" """
cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format( cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
tenant_id, tenant_id, provider, model_type.value, model, config_id
provider,
model_type.value,
model,
config_id
) )
ttl = redis_client.ttl(cooldown_cache_key) ttl = redis_client.ttl(cooldown_cache_key)