mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 04:49:04 +08:00
refactor: tool parameter cache (#3703)
This commit is contained in:
parent
65ac4f69af
commit
3480f1c59e
@ -1,5 +1,3 @@
|
|||||||
import json
|
|
||||||
|
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restful import Resource, inputs, marshal_with, reqparse
|
from flask_restful import Resource, inputs, marshal_with, reqparse
|
||||||
from werkzeug.exceptions import BadRequest, Forbidden
|
from werkzeug.exceptions import BadRequest, Forbidden
|
||||||
@ -8,17 +6,12 @@ from controllers.console import api
|
|||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||||
from core.agent.entities import AgentToolEntity
|
|
||||||
from core.tools.tool_manager import ToolManager
|
|
||||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from fields.app_fields import (
|
from fields.app_fields import (
|
||||||
app_detail_fields,
|
app_detail_fields,
|
||||||
app_detail_fields_with_site,
|
app_detail_fields_with_site,
|
||||||
app_pagination_fields,
|
app_pagination_fields,
|
||||||
)
|
)
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.model import App, AppMode, AppModelConfig
|
|
||||||
from services.app_service import AppService
|
from services.app_service import AppService
|
||||||
|
|
||||||
ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
|
ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
|
||||||
@ -108,43 +101,9 @@ class AppApi(Resource):
|
|||||||
@marshal_with(app_detail_fields_with_site)
|
@marshal_with(app_detail_fields_with_site)
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
"""Get app detail"""
|
"""Get app detail"""
|
||||||
# get original app model config
|
app_service = AppService()
|
||||||
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
|
|
||||||
model_config: AppModelConfig = app_model.app_model_config
|
|
||||||
agent_mode = model_config.agent_mode_dict
|
|
||||||
# decrypt agent tool parameters if it's secret-input
|
|
||||||
for tool in agent_mode.get('tools') or []:
|
|
||||||
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
|
|
||||||
continue
|
|
||||||
agent_tool_entity = AgentToolEntity(**tool)
|
|
||||||
# get tool
|
|
||||||
try:
|
|
||||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
|
||||||
tenant_id=current_user.current_tenant_id,
|
|
||||||
agent_tool=agent_tool_entity,
|
|
||||||
)
|
|
||||||
manager = ToolParameterConfigurationManager(
|
|
||||||
tenant_id=current_user.current_tenant_id,
|
|
||||||
tool_runtime=tool_runtime,
|
|
||||||
provider_name=agent_tool_entity.provider_id,
|
|
||||||
provider_type=agent_tool_entity.provider_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
# get decrypted parameters
|
app_model = app_service.get_app(app_model)
|
||||||
if agent_tool_entity.tool_parameters:
|
|
||||||
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
|
|
||||||
masked_parameter = manager.mask_tool_parameters(parameters or {})
|
|
||||||
else:
|
|
||||||
masked_parameter = {}
|
|
||||||
|
|
||||||
# override tool parameters
|
|
||||||
tool['tool_parameters'] = masked_parameter
|
|
||||||
except Exception as e:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# override agent mode
|
|
||||||
model_config.agent_mode = json.dumps(agent_mode)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
@ -57,6 +57,7 @@ class ModelConfigResource(Resource):
|
|||||||
try:
|
try:
|
||||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
app_id=app_model.id,
|
||||||
agent_tool=agent_tool_entity,
|
agent_tool=agent_tool_entity,
|
||||||
)
|
)
|
||||||
manager = ToolParameterConfigurationManager(
|
manager = ToolParameterConfigurationManager(
|
||||||
@ -64,6 +65,7 @@ class ModelConfigResource(Resource):
|
|||||||
tool_runtime=tool_runtime,
|
tool_runtime=tool_runtime,
|
||||||
provider_name=agent_tool_entity.provider_id,
|
provider_name=agent_tool_entity.provider_id,
|
||||||
provider_type=agent_tool_entity.provider_type,
|
provider_type=agent_tool_entity.provider_type,
|
||||||
|
identity_id=f'AGENT.{app_model.id}'
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
continue
|
continue
|
||||||
@ -94,6 +96,7 @@ class ModelConfigResource(Resource):
|
|||||||
try:
|
try:
|
||||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
app_id=app_model.id,
|
||||||
agent_tool=agent_tool_entity,
|
agent_tool=agent_tool_entity,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -104,6 +107,7 @@ class ModelConfigResource(Resource):
|
|||||||
tool_runtime=tool_runtime,
|
tool_runtime=tool_runtime,
|
||||||
provider_name=agent_tool_entity.provider_id,
|
provider_name=agent_tool_entity.provider_id,
|
||||||
provider_type=agent_tool_entity.provider_type,
|
provider_type=agent_tool_entity.provider_type,
|
||||||
|
identity_id=f'AGENT.{app_model.id}'
|
||||||
)
|
)
|
||||||
manager.delete_tool_parameters_cache()
|
manager.delete_tool_parameters_cache()
|
||||||
|
|
||||||
@ -112,8 +116,10 @@ class ModelConfigResource(Resource):
|
|||||||
if key not in masked_parameter_map:
|
if key not in masked_parameter_map:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
|
for masked_key, masked_value in masked_parameter_map[key].items():
|
||||||
agent_tool_entity.tool_parameters = parameter_map[key]
|
if masked_key in agent_tool_entity.tool_parameters and \
|
||||||
|
agent_tool_entity.tool_parameters[masked_key] == masked_value:
|
||||||
|
agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key)
|
||||||
|
|
||||||
# encrypt parameters
|
# encrypt parameters
|
||||||
if agent_tool_entity.tool_parameters:
|
if agent_tool_entity.tool_parameters:
|
||||||
|
@ -163,6 +163,7 @@ class BaseAgentRunner(AppRunner):
|
|||||||
"""
|
"""
|
||||||
tool_entity = ToolManager.get_agent_tool_runtime(
|
tool_entity = ToolManager.get_agent_tool_runtime(
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
|
app_id=self.app_config.app_id,
|
||||||
agent_tool=tool,
|
agent_tool=tool,
|
||||||
)
|
)
|
||||||
tool_entity.load_variables(self.variables_pool)
|
tool_entity.load_variables(self.variables_pool)
|
||||||
|
@ -11,12 +11,13 @@ class ToolParameterCacheType(Enum):
|
|||||||
|
|
||||||
class ToolParameterCache:
|
class ToolParameterCache:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
provider: str,
|
provider: str,
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
cache_type: ToolParameterCacheType
|
cache_type: ToolParameterCacheType,
|
||||||
|
identity_id: str
|
||||||
):
|
):
|
||||||
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
|
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}:identity_id:{identity_id}"
|
||||||
|
|
||||||
def get(self) -> Optional[dict]:
|
def get(self) -> Optional[dict]:
|
||||||
"""
|
"""
|
||||||
|
@ -222,7 +222,7 @@ class ToolManager:
|
|||||||
return parameter_value
|
return parameter_value
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_agent_tool_runtime(cls, tenant_id: str, agent_tool: AgentToolEntity) -> Tool:
|
def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity) -> Tool:
|
||||||
"""
|
"""
|
||||||
get the agent tool runtime
|
get the agent tool runtime
|
||||||
"""
|
"""
|
||||||
@ -245,6 +245,7 @@ class ToolManager:
|
|||||||
tool_runtime=tool_entity,
|
tool_runtime=tool_entity,
|
||||||
provider_name=agent_tool.provider_id,
|
provider_name=agent_tool.provider_id,
|
||||||
provider_type=agent_tool.provider_type,
|
provider_type=agent_tool.provider_type,
|
||||||
|
identity_id=f'AGENT.{app_id}'
|
||||||
)
|
)
|
||||||
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
||||||
|
|
||||||
@ -252,7 +253,7 @@ class ToolManager:
|
|||||||
return tool_entity
|
return tool_entity
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_workflow_tool_runtime(cls, tenant_id: str, workflow_tool: ToolEntity):
|
def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity):
|
||||||
"""
|
"""
|
||||||
get the workflow tool runtime
|
get the workflow tool runtime
|
||||||
"""
|
"""
|
||||||
@ -277,6 +278,7 @@ class ToolManager:
|
|||||||
tool_runtime=tool_entity,
|
tool_runtime=tool_entity,
|
||||||
provider_name=workflow_tool.provider_id,
|
provider_name=workflow_tool.provider_id,
|
||||||
provider_type=workflow_tool.provider_type,
|
provider_type=workflow_tool.provider_type,
|
||||||
|
identity_id=f'WORKFLOW.{app_id}.{node_id}'
|
||||||
)
|
)
|
||||||
|
|
||||||
if runtime_parameters:
|
if runtime_parameters:
|
||||||
|
@ -113,12 +113,13 @@ class ToolParameterConfigurationManager(BaseModel):
|
|||||||
tool_runtime: Tool
|
tool_runtime: Tool
|
||||||
provider_name: str
|
provider_name: str
|
||||||
provider_type: str
|
provider_type: str
|
||||||
|
identity_id: str
|
||||||
|
|
||||||
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
deep copy parameters
|
deep copy parameters
|
||||||
"""
|
"""
|
||||||
return {key: value for key, value in parameters.items()}
|
return deepcopy(parameters)
|
||||||
|
|
||||||
def _merge_parameters(self) -> list[ToolParameter]:
|
def _merge_parameters(self) -> list[ToolParameter]:
|
||||||
"""
|
"""
|
||||||
@ -176,6 +177,8 @@ class ToolParameterConfigurationManager(BaseModel):
|
|||||||
# override parameters
|
# override parameters
|
||||||
current_parameters = self._merge_parameters()
|
current_parameters = self._merge_parameters()
|
||||||
|
|
||||||
|
parameters = self._deep_copy(parameters)
|
||||||
|
|
||||||
for parameter in current_parameters:
|
for parameter in current_parameters:
|
||||||
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
|
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
|
||||||
if parameter.name in parameters:
|
if parameter.name in parameters:
|
||||||
@ -194,7 +197,8 @@ class ToolParameterConfigurationManager(BaseModel):
|
|||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
provider=f'{self.provider_type}.{self.provider_name}',
|
provider=f'{self.provider_type}.{self.provider_name}',
|
||||||
tool_name=self.tool_runtime.identity.name,
|
tool_name=self.tool_runtime.identity.name,
|
||||||
cache_type=ToolParameterCacheType.PARAMETER
|
cache_type=ToolParameterCacheType.PARAMETER,
|
||||||
|
identity_id=self.identity_id
|
||||||
)
|
)
|
||||||
cached_parameters = cache.get()
|
cached_parameters = cache.get()
|
||||||
if cached_parameters:
|
if cached_parameters:
|
||||||
@ -223,7 +227,8 @@ class ToolParameterConfigurationManager(BaseModel):
|
|||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
provider=f'{self.provider_type}.{self.provider_name}',
|
provider=f'{self.provider_type}.{self.provider_name}',
|
||||||
tool_name=self.tool_runtime.identity.name,
|
tool_name=self.tool_runtime.identity.name,
|
||||||
cache_type=ToolParameterCacheType.PARAMETER
|
cache_type=ToolParameterCacheType.PARAMETER,
|
||||||
|
identity_id=self.identity_id
|
||||||
)
|
)
|
||||||
cache.delete()
|
cache.delete()
|
||||||
|
|
||||||
|
@ -39,7 +39,8 @@ class ToolNode(BaseNode):
|
|||||||
parameters = self._generate_parameters(variable_pool, node_data)
|
parameters = self._generate_parameters(variable_pool, node_data)
|
||||||
# get tool runtime
|
# get tool runtime
|
||||||
try:
|
try:
|
||||||
tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data)
|
self.app_id
|
||||||
|
tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, self.app_id, self.node_id, node_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.FAILED,
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
@ -22,5 +22,6 @@ def handle(sender, **kwargs):
|
|||||||
tool_runtime=tool_runtime,
|
tool_runtime=tool_runtime,
|
||||||
provider_name=tool_entity.provider_name,
|
provider_name=tool_entity.provider_name,
|
||||||
provider_type=tool_entity.provider_type,
|
provider_type=tool_entity.provider_type,
|
||||||
|
identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}'
|
||||||
)
|
)
|
||||||
manager.delete_tool_parameters_cache()
|
manager.delete_tool_parameters_cache()
|
||||||
|
@ -5,13 +5,17 @@ from typing import cast
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
|
from flask_login import current_user
|
||||||
from flask_sqlalchemy.pagination import Pagination
|
from flask_sqlalchemy.pagination import Pagination
|
||||||
|
|
||||||
from constants.model_template import default_app_templates
|
from constants.model_template import default_app_templates
|
||||||
|
from core.agent.entities import AgentToolEntity
|
||||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||||
from core.model_manager import ModelManager
|
from core.model_manager import ModelManager
|
||||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
|
from core.tools.tool_manager import ToolManager
|
||||||
|
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||||
from events.app_event import app_model_config_was_updated, app_was_created, app_was_deleted
|
from events.app_event import app_model_config_was_updated, app_was_created, app_was_deleted
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
@ -240,6 +244,64 @@ class AppService:
|
|||||||
|
|
||||||
return yaml.dump(export_data)
|
return yaml.dump(export_data)
|
||||||
|
|
||||||
|
def get_app(self, app: App) -> App:
|
||||||
|
"""
|
||||||
|
Get App
|
||||||
|
"""
|
||||||
|
# get original app model config
|
||||||
|
if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
|
||||||
|
model_config: AppModelConfig = app.app_model_config
|
||||||
|
agent_mode = model_config.agent_mode_dict
|
||||||
|
# decrypt agent tool parameters if it's secret-input
|
||||||
|
for tool in agent_mode.get('tools') or []:
|
||||||
|
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
|
||||||
|
continue
|
||||||
|
agent_tool_entity = AgentToolEntity(**tool)
|
||||||
|
# get tool
|
||||||
|
try:
|
||||||
|
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||||
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
app_id=app.id,
|
||||||
|
agent_tool=agent_tool_entity,
|
||||||
|
)
|
||||||
|
manager = ToolParameterConfigurationManager(
|
||||||
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
tool_runtime=tool_runtime,
|
||||||
|
provider_name=agent_tool_entity.provider_id,
|
||||||
|
provider_type=agent_tool_entity.provider_type,
|
||||||
|
identity_id=f'AGENT.{app.id}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# get decrypted parameters
|
||||||
|
if agent_tool_entity.tool_parameters:
|
||||||
|
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
|
||||||
|
masked_parameter = manager.mask_tool_parameters(parameters or {})
|
||||||
|
else:
|
||||||
|
masked_parameter = {}
|
||||||
|
|
||||||
|
# override tool parameters
|
||||||
|
tool['tool_parameters'] = masked_parameter
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# override agent mode
|
||||||
|
model_config.agent_mode = json.dumps(agent_mode)
|
||||||
|
|
||||||
|
class ModifiedApp(App):
|
||||||
|
"""
|
||||||
|
Modified App class
|
||||||
|
"""
|
||||||
|
def __init__(self, app):
|
||||||
|
self.__dict__.update(app.__dict__)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def app_model_config(self):
|
||||||
|
return model_config
|
||||||
|
|
||||||
|
app = ModifiedApp(app)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
def update_app(self, app: App, args: dict) -> App:
|
def update_app(self, app: App, args: dict) -> App:
|
||||||
"""
|
"""
|
||||||
Update app
|
Update app
|
||||||
|
Loading…
x
Reference in New Issue
Block a user