mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-16 03:35:56 +08:00
feat: support pinning, including, and excluding for model providers and tools (#7419)
Co-authored-by: GareArc <chen4851@purude.edu>
This commit is contained in:
parent
6c25d7bed3
commit
4e7b6aec3a
@ -268,3 +268,12 @@ APP_MAX_ACTIVE_REQUESTS=0
|
|||||||
|
|
||||||
# Celery beat configuration
|
# Celery beat configuration
|
||||||
CELERY_BEAT_SCHEDULER_TIME=1
|
CELERY_BEAT_SCHEDULER_TIME=1
|
||||||
|
|
||||||
|
# Position configuration
|
||||||
|
POSITION_TOOL_PINS=
|
||||||
|
POSITION_TOOL_INCLUDES=
|
||||||
|
POSITION_TOOL_EXCLUDES=
|
||||||
|
|
||||||
|
POSITION_PROVIDER_PINS=
|
||||||
|
POSITION_PROVIDER_INCLUDES=
|
||||||
|
POSITION_PROVIDER_EXCLUDES=
|
||||||
|
@ -406,6 +406,7 @@ class DataSetConfig(BaseSettings):
|
|||||||
default=False,
|
default=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceConfig(BaseSettings):
|
class WorkspaceConfig(BaseSettings):
|
||||||
"""
|
"""
|
||||||
Workspace configs
|
Workspace configs
|
||||||
@ -442,6 +443,63 @@ class CeleryBeatConfig(BaseSettings):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PositionConfig(BaseSettings):
|
||||||
|
|
||||||
|
POSITION_PROVIDER_PINS: str = Field(
|
||||||
|
description='The heads of model providers',
|
||||||
|
default='',
|
||||||
|
)
|
||||||
|
|
||||||
|
POSITION_PROVIDER_INCLUDES: str = Field(
|
||||||
|
description='The included model providers',
|
||||||
|
default='',
|
||||||
|
)
|
||||||
|
|
||||||
|
POSITION_PROVIDER_EXCLUDES: str = Field(
|
||||||
|
description='The excluded model providers',
|
||||||
|
default='',
|
||||||
|
)
|
||||||
|
|
||||||
|
POSITION_TOOL_PINS: str = Field(
|
||||||
|
description='The heads of tools',
|
||||||
|
default='',
|
||||||
|
)
|
||||||
|
|
||||||
|
POSITION_TOOL_INCLUDES: str = Field(
|
||||||
|
description='The included tools',
|
||||||
|
default='',
|
||||||
|
)
|
||||||
|
|
||||||
|
POSITION_TOOL_EXCLUDES: str = Field(
|
||||||
|
description='The excluded tools',
|
||||||
|
default='',
|
||||||
|
)
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
def POSITION_PROVIDER_PINS_LIST(self) -> list[str]:
|
||||||
|
return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(',') if item.strip() != '']
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
|
||||||
|
return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != ''}
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
|
||||||
|
return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != ''}
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
def POSITION_TOOL_PINS_LIST(self) -> list[str]:
|
||||||
|
return [item.strip() for item in self.POSITION_TOOL_PINS.split(',') if item.strip() != '']
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
|
||||||
|
return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != ''}
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
|
||||||
|
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != ''}
|
||||||
|
|
||||||
|
|
||||||
class FeatureConfig(
|
class FeatureConfig(
|
||||||
# place the configs in alphabet order
|
# place the configs in alphabet order
|
||||||
AppExecutionConfig,
|
AppExecutionConfig,
|
||||||
@ -466,6 +524,7 @@ class FeatureConfig(
|
|||||||
UpdateConfig,
|
UpdateConfig,
|
||||||
WorkflowConfig,
|
WorkflowConfig,
|
||||||
WorkspaceConfig,
|
WorkspaceConfig,
|
||||||
|
PositionConfig,
|
||||||
|
|
||||||
# hosted services config
|
# hosted services config
|
||||||
HostedServiceConfig,
|
HostedServiceConfig,
|
||||||
|
@ -3,6 +3,7 @@ from collections import OrderedDict
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
from core.tools.utils.yaml_utils import load_yaml_file
|
from core.tools.utils.yaml_utils import load_yaml_file
|
||||||
|
|
||||||
|
|
||||||
@ -19,6 +20,87 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") ->
|
|||||||
return {name: index for index, name in enumerate(positions)}
|
return {name: index for index, name in enumerate(positions)}
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
|
||||||
|
"""
|
||||||
|
Get the mapping for tools from name to index from a YAML file.
|
||||||
|
:param folder_path:
|
||||||
|
:param file_name: the YAML file name, default to '_position.yaml'
|
||||||
|
:return: a dict with name as key and index as value
|
||||||
|
"""
|
||||||
|
position_map = get_position_map(folder_path, file_name=file_name)
|
||||||
|
|
||||||
|
return pin_position_map(
|
||||||
|
position_map,
|
||||||
|
pin_list=dify_config.POSITION_TOOL_PINS_LIST,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
|
||||||
|
"""
|
||||||
|
Get the mapping for providers from name to index from a YAML file.
|
||||||
|
:param folder_path:
|
||||||
|
:param file_name: the YAML file name, default to '_position.yaml'
|
||||||
|
:return: a dict with name as key and index as value
|
||||||
|
"""
|
||||||
|
position_map = get_position_map(folder_path, file_name=file_name)
|
||||||
|
return pin_position_map(
|
||||||
|
position_map,
|
||||||
|
pin_list=dify_config.POSITION_PROVIDER_PINS_LIST,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]:
|
||||||
|
"""
|
||||||
|
Pin the items in the pin list to the beginning of the position map.
|
||||||
|
Overall logic: exclude > include > pin
|
||||||
|
:param position_map: the position map to be sorted and filtered
|
||||||
|
:param pin_list: the list of pins to be put at the beginning
|
||||||
|
:return: the sorted position map
|
||||||
|
"""
|
||||||
|
positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x])
|
||||||
|
|
||||||
|
# Add pins to position map
|
||||||
|
position_map = {name: idx for idx, name in enumerate(pin_list)}
|
||||||
|
|
||||||
|
# Add remaining positions to position map
|
||||||
|
start_idx = len(position_map)
|
||||||
|
for name in positions:
|
||||||
|
if name not in position_map:
|
||||||
|
position_map[name] = start_idx
|
||||||
|
start_idx += 1
|
||||||
|
|
||||||
|
return position_map
|
||||||
|
|
||||||
|
|
||||||
|
def is_filtered(
|
||||||
|
include_set: set[str],
|
||||||
|
exclude_set: set[str],
|
||||||
|
data: Any,
|
||||||
|
name_func: Callable[[Any], str],
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Chcek if the object should be filtered out.
|
||||||
|
Overall logic: exclude > include > pin
|
||||||
|
:param include_set: the set of names to be included
|
||||||
|
:param exclude_set: the set of names to be excluded
|
||||||
|
:param name_func: the function to get the name of the object
|
||||||
|
:param data: the data to be filtered
|
||||||
|
:return: True if the object should be filtered out, False otherwise
|
||||||
|
"""
|
||||||
|
if not data:
|
||||||
|
return False
|
||||||
|
if not include_set and not exclude_set:
|
||||||
|
return False
|
||||||
|
|
||||||
|
name = name_func(data)
|
||||||
|
|
||||||
|
if name in exclude_set: # exclude_set is prioritized
|
||||||
|
return True
|
||||||
|
if include_set and name not in include_set: # filter out only if include_set is not empty
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def sort_by_position_map(
|
def sort_by_position_map(
|
||||||
position_map: dict[str, int],
|
position_map: dict[str, int],
|
||||||
data: list[Any],
|
data: list[Any],
|
||||||
|
@ -368,6 +368,15 @@ class ModelManager:
|
|||||||
|
|
||||||
return ModelInstance(provider_model_bundle, model)
|
return ModelInstance(provider_model_bundle, model)
|
||||||
|
|
||||||
|
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Return first provider and the first model in the provider
|
||||||
|
:param tenant_id: tenant id
|
||||||
|
:param model_type: model type
|
||||||
|
:return: provider name, model name
|
||||||
|
"""
|
||||||
|
return self._provider_manager.get_first_provider_first_model(tenant_id, model_type)
|
||||||
|
|
||||||
def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
|
def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
|
||||||
"""
|
"""
|
||||||
Get default model instance
|
Get default model instance
|
||||||
@ -502,7 +511,6 @@ class LBModelManager:
|
|||||||
config.id
|
config.id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
res = redis_client.exists(cooldown_cache_key)
|
res = redis_client.exists(cooldown_cache_key)
|
||||||
res = cast(bool, res)
|
res = cast(bool, res)
|
||||||
return res
|
return res
|
||||||
|
@ -6,7 +6,7 @@ from typing import Optional
|
|||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||||
from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map
|
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
|
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
|
||||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||||
@ -234,7 +234,7 @@ class ModelProviderFactory:
|
|||||||
]
|
]
|
||||||
|
|
||||||
# get _position.yaml file path
|
# get _position.yaml file path
|
||||||
position_map = get_position_map(model_providers_path)
|
position_map = get_provider_position_map(model_providers_path)
|
||||||
|
|
||||||
# traverse all model_provider_dir_paths
|
# traverse all model_provider_dir_paths
|
||||||
model_providers: list[ModelProviderExtension] = []
|
model_providers: list[ModelProviderExtension] = []
|
||||||
|
@ -5,6 +5,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
||||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
|
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
|
||||||
from core.entities.provider_entities import (
|
from core.entities.provider_entities import (
|
||||||
@ -18,12 +19,9 @@ from core.entities.provider_entities import (
|
|||||||
)
|
)
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||||
|
from core.helper.position_helper import is_filtered
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.entities.provider_entities import (
|
from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity
|
||||||
CredentialFormSchema,
|
|
||||||
FormType,
|
|
||||||
ProviderEntity,
|
|
||||||
)
|
|
||||||
from core.model_runtime.model_providers import model_provider_factory
|
from core.model_runtime.model_providers import model_provider_factory
|
||||||
from extensions import ext_hosting_provider
|
from extensions import ext_hosting_provider
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -45,6 +43,7 @@ class ProviderManager:
|
|||||||
"""
|
"""
|
||||||
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
|
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.decoding_rsa_key = None
|
self.decoding_rsa_key = None
|
||||||
self.decoding_cipher_rsa = None
|
self.decoding_cipher_rsa = None
|
||||||
@ -117,6 +116,16 @@ class ProviderManager:
|
|||||||
|
|
||||||
# Construct ProviderConfiguration objects for each provider
|
# Construct ProviderConfiguration objects for each provider
|
||||||
for provider_entity in provider_entities:
|
for provider_entity in provider_entities:
|
||||||
|
|
||||||
|
# handle include, exclude
|
||||||
|
if is_filtered(
|
||||||
|
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
|
||||||
|
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
|
||||||
|
data=provider_entity,
|
||||||
|
name_func=lambda x: x.provider,
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
provider_name = provider_entity.provider
|
provider_name = provider_entity.provider
|
||||||
provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, [])
|
provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, [])
|
||||||
provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, [])
|
provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, [])
|
||||||
@ -271,6 +280,24 @@ class ProviderManager:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Get names of first model and its provider
|
||||||
|
|
||||||
|
:param tenant_id: workspace id
|
||||||
|
:param model_type: model type
|
||||||
|
:return: provider name, model name
|
||||||
|
"""
|
||||||
|
provider_configurations = self.get_configurations(tenant_id)
|
||||||
|
|
||||||
|
# get available models from provider_configurations
|
||||||
|
all_models = provider_configurations.get_models(
|
||||||
|
model_type=model_type,
|
||||||
|
only_active=False
|
||||||
|
)
|
||||||
|
|
||||||
|
return all_models[0].provider.provider, all_models[0].model
|
||||||
|
|
||||||
def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \
|
def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \
|
||||||
-> TenantDefaultModel:
|
-> TenantDefaultModel:
|
||||||
"""
|
"""
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import os.path
|
import os.path
|
||||||
|
|
||||||
from core.helper.position_helper import get_position_map, sort_by_position_map
|
from core.helper.position_helper import get_tool_position_map, sort_by_position_map
|
||||||
from core.tools.entities.api_entities import UserToolProvider
|
from core.tools.entities.api_entities import UserToolProvider
|
||||||
|
|
||||||
|
|
||||||
@ -10,7 +10,7 @@ class BuiltinToolProviderSort:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
|
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
|
||||||
if not cls._position:
|
if not cls._position:
|
||||||
cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..'))
|
cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..'))
|
||||||
|
|
||||||
def name_func(provider: UserToolProvider) -> str:
|
def name_func(provider: UserToolProvider) -> str:
|
||||||
return provider.name
|
return provider.name
|
||||||
|
@ -10,14 +10,11 @@ from configs import dify_config
|
|||||||
from core.agent.entities import AgentToolEntity
|
from core.agent.entities import AgentToolEntity
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||||
|
from core.helper.position_helper import is_filtered
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
|
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter
|
||||||
ApiProviderAuthType,
|
|
||||||
ToolInvokeFrom,
|
|
||||||
ToolParameter,
|
|
||||||
)
|
|
||||||
from core.tools.errors import ToolProviderNotFoundError
|
from core.tools.errors import ToolProviderNotFoundError
|
||||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||||
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
|
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
|
||||||
@ -26,10 +23,7 @@ from core.tools.tool.api_tool import ApiTool
|
|||||||
from core.tools.tool.builtin_tool import BuiltinTool
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
from core.tools.tool.tool import Tool
|
from core.tools.tool.tool import Tool
|
||||||
from core.tools.tool_label_manager import ToolLabelManager
|
from core.tools.tool_label_manager import ToolLabelManager
|
||||||
from core.tools.utils.configuration import (
|
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
|
||||||
ToolConfigurationManager,
|
|
||||||
ToolParameterConfigurationManager,
|
|
||||||
)
|
|
||||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||||
from core.workflow.nodes.tool.entities import ToolEntity
|
from core.workflow.nodes.tool.entities import ToolEntity
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -38,6 +32,7 @@ from services.tools.tools_transform_service import ToolTransformService
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ToolManager:
|
class ToolManager:
|
||||||
_builtin_provider_lock = Lock()
|
_builtin_provider_lock = Lock()
|
||||||
_builtin_providers = {}
|
_builtin_providers = {}
|
||||||
@ -414,6 +409,15 @@ class ToolManager:
|
|||||||
|
|
||||||
# append builtin providers
|
# append builtin providers
|
||||||
for provider in builtin_providers:
|
for provider in builtin_providers:
|
||||||
|
# handle include, exclude
|
||||||
|
if is_filtered(
|
||||||
|
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||||
|
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||||
|
data=provider,
|
||||||
|
name_func=lambda x: x.identity.name
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||||
provider_controller=provider,
|
provider_controller=provider,
|
||||||
db_provider=find_db_builtin_provider(provider.identity.name),
|
db_provider=find_db_builtin_provider(provider.identity.name),
|
||||||
@ -593,4 +597,5 @@ class ToolManager:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"provider type {provider_type} not found")
|
raise ValueError(f"provider type {provider_type} not found")
|
||||||
|
|
||||||
|
|
||||||
ToolManager.load_builtin_providers_cache()
|
ToolManager.load_builtin_providers_cache()
|
||||||
|
@ -111,6 +111,12 @@ class AppService:
|
|||||||
'completion_params': {}
|
'completion_params': {}
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
|
provider, model = model_manager.get_default_provider_model_name(
|
||||||
|
tenant_id=account.current_tenant_id,
|
||||||
|
model_type=ModelType.LLM
|
||||||
|
)
|
||||||
|
default_model_config['model']['provider'] = provider
|
||||||
|
default_model_config['model']['name'] = model
|
||||||
default_model_dict = default_model_config['model']
|
default_model_dict = default_model_config['model']
|
||||||
|
|
||||||
default_model_config['model'] = json.dumps(default_model_dict)
|
default_model_config['model'] = json.dumps(default_model_dict)
|
||||||
@ -190,6 +196,7 @@ class AppService:
|
|||||||
"""
|
"""
|
||||||
Modified App class
|
Modified App class
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, app):
|
def __init__(self, app):
|
||||||
self.__dict__.update(app.__dict__)
|
self.__dict__.update(app.__dict__)
|
||||||
|
|
||||||
|
@ -30,6 +30,7 @@ class ModelProviderService:
|
|||||||
"""
|
"""
|
||||||
Model Provider Service
|
Model Provider Service
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.provider_manager = ProviderManager()
|
self.provider_manager = ProviderManager()
|
||||||
|
|
||||||
@ -387,7 +388,7 @@ class ModelProviderService:
|
|||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
model_type=model_type_enum
|
model_type=model_type_enum
|
||||||
)
|
)
|
||||||
|
try:
|
||||||
return DefaultModelResponse(
|
return DefaultModelResponse(
|
||||||
model=result.model,
|
model=result.model,
|
||||||
model_type=result.model_type,
|
model_type=result.model_type,
|
||||||
@ -399,6 +400,9 @@ class ModelProviderService:
|
|||||||
supported_model_types=result.provider.supported_model_types
|
supported_model_types=result.provider.supported_model_types
|
||||||
)
|
)
|
||||||
) if result else None
|
) if result else None
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(f"get_default_model_of_model_type error: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
|
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from core.helper.position_helper import is_filtered
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||||
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
||||||
@ -119,7 +121,7 @@ class BuiltinToolManageService:
|
|||||||
# delete cache
|
# delete cache
|
||||||
tool_configuration.delete_tool_credentials_cache()
|
tool_configuration.delete_tool_credentials_cache()
|
||||||
|
|
||||||
return { 'result': 'success' }
|
return {'result': 'success'}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_builtin_tool_provider_credentials(
|
def get_builtin_tool_provider_credentials(
|
||||||
@ -165,7 +167,7 @@ class BuiltinToolManageService:
|
|||||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||||
tool_configuration.delete_tool_credentials_cache()
|
tool_configuration.delete_tool_credentials_cache()
|
||||||
|
|
||||||
return { 'result': 'success' }
|
return {'result': 'success'}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_builtin_tool_provider_icon(
|
def get_builtin_tool_provider_icon(
|
||||||
@ -202,6 +204,15 @@ class BuiltinToolManageService:
|
|||||||
|
|
||||||
for provider_controller in provider_controllers:
|
for provider_controller in provider_controllers:
|
||||||
try:
|
try:
|
||||||
|
# handle include, exclude
|
||||||
|
if is_filtered(
|
||||||
|
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||||
|
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||||
|
data=provider_controller,
|
||||||
|
name_func=lambda x: x.identity.name
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
# convert provider controller to user provider
|
# convert provider controller to user provider
|
||||||
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||||
provider_controller=provider_controller,
|
provider_controller=provider_controller,
|
||||||
@ -226,4 +237,3 @@ class BuiltinToolManageService:
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
return BuiltinToolProviderSort.sort(result)
|
return BuiltinToolProviderSort.sort(result)
|
||||||
|
|
@ -2,7 +2,7 @@ from textwrap import dedent
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from core.helper.position_helper import get_position_map
|
from core.helper.position_helper import get_position_map, is_filtered, pin_position_map, sort_by_position_map
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -53,3 +53,79 @@ def test_position_helper_with_all_commented(prepare_empty_commented_positions_ya
|
|||||||
folder_path=prepare_empty_commented_positions_yaml,
|
folder_path=prepare_empty_commented_positions_yaml,
|
||||||
file_name='example_positions_all_commented.yaml')
|
file_name='example_positions_all_commented.yaml')
|
||||||
assert position_map == {}
|
assert position_map == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_excluded_position_data(prepare_example_positions_yaml):
|
||||||
|
position_map = get_position_map(
|
||||||
|
folder_path=prepare_example_positions_yaml,
|
||||||
|
file_name='example_positions.yaml'
|
||||||
|
)
|
||||||
|
pin_list = ['forth', 'first']
|
||||||
|
include_set = set()
|
||||||
|
exclude_set = {'9999999999999'}
|
||||||
|
|
||||||
|
position_map = pin_position_map(
|
||||||
|
original_position_map=position_map,
|
||||||
|
pin_list=pin_list
|
||||||
|
)
|
||||||
|
|
||||||
|
data = [
|
||||||
|
"forth",
|
||||||
|
"first",
|
||||||
|
"second",
|
||||||
|
"third",
|
||||||
|
"9999999999999",
|
||||||
|
"extra1",
|
||||||
|
"extra2",
|
||||||
|
]
|
||||||
|
|
||||||
|
# filter out the data
|
||||||
|
data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)]
|
||||||
|
|
||||||
|
# sort data by position map
|
||||||
|
sorted_data = sort_by_position_map(
|
||||||
|
position_map=position_map,
|
||||||
|
data=data,
|
||||||
|
name_func=lambda x: x,
|
||||||
|
)
|
||||||
|
|
||||||
|
# assert the result in the correct order
|
||||||
|
assert sorted_data == ['forth', 'first', 'second', 'third', 'extra1', 'extra2']
|
||||||
|
|
||||||
|
|
||||||
|
def test_included_position_data(prepare_example_positions_yaml):
|
||||||
|
position_map = get_position_map(
|
||||||
|
folder_path=prepare_example_positions_yaml,
|
||||||
|
file_name='example_positions.yaml'
|
||||||
|
)
|
||||||
|
pin_list = ['forth', 'first']
|
||||||
|
include_set = {'forth', 'first'}
|
||||||
|
exclude_set = {}
|
||||||
|
|
||||||
|
position_map = pin_position_map(
|
||||||
|
original_position_map=position_map,
|
||||||
|
pin_list=pin_list
|
||||||
|
)
|
||||||
|
|
||||||
|
data = [
|
||||||
|
"forth",
|
||||||
|
"first",
|
||||||
|
"second",
|
||||||
|
"third",
|
||||||
|
"9999999999999",
|
||||||
|
"extra1",
|
||||||
|
"extra2",
|
||||||
|
]
|
||||||
|
|
||||||
|
# filter out the data
|
||||||
|
data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)]
|
||||||
|
|
||||||
|
# sort data by position map
|
||||||
|
sorted_data = sort_by_position_map(
|
||||||
|
position_map=position_map,
|
||||||
|
data=data,
|
||||||
|
name_func=lambda x: x,
|
||||||
|
)
|
||||||
|
|
||||||
|
# assert the result in the correct order
|
||||||
|
assert sorted_data == ['forth', 'first']
|
||||||
|
@ -701,3 +701,22 @@ COMPOSE_PROFILES=${VECTOR_STORE:-weaviate}
|
|||||||
# ------------------------------
|
# ------------------------------
|
||||||
EXPOSE_NGINX_PORT=80
|
EXPOSE_NGINX_PORT=80
|
||||||
EXPOSE_NGINX_SSL_PORT=443
|
EXPOSE_NGINX_SSL_PORT=443
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
# ModelProvider & Tool Position Configuration
|
||||||
|
# Used to specify the model providers and tools that can be used in the app.
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Pin, include, and exclude tools
|
||||||
|
# Use comma-separated values with no spaces between items.
|
||||||
|
# Example: POSITION_TOOL_PINS=bing,google
|
||||||
|
POSITION_TOOL_PINS=
|
||||||
|
POSITION_TOOL_INCLUDES=
|
||||||
|
POSITION_TOOL_EXCLUDES=
|
||||||
|
|
||||||
|
# Pin, include, and exclude model providers
|
||||||
|
# Use comma-separated values with no spaces between items.
|
||||||
|
# Example: POSITION_PROVIDER_PINS=openai,openllm
|
||||||
|
POSITION_PROVIDER_PINS=
|
||||||
|
POSITION_PROVIDER_INCLUDES=
|
||||||
|
POSITION_PROVIDER_EXCLUDES=
|
Loading…
x
Reference in New Issue
Block a user