From 4e7b6aec3a7ab0c67777c2e44eede7dc500b7391 Mon Sep 17 00:00:00 2001 From: Xiyuan Chen <52963600+GareArc@users.noreply.github.com> Date: Tue, 20 Aug 2024 23:16:43 -0400 Subject: [PATCH] feat: support pinning, including, and excluding for model providers and tools (#7419) Co-authored-by: GareArc --- api/.env.example | 11 ++- api/configs/feature/__init__.py | 59 +++++++++++++ api/core/helper/position_helper.py | 82 ++++++++++++++++++ api/core/model_manager.py | 10 ++- .../model_providers/__base/ai_model.py | 6 +- .../model_providers/model_provider_factory.py | 4 +- api/core/provider_manager.py | 37 ++++++-- api/core/tools/provider/builtin/_positions.py | 6 +- api/core/tools/tool_manager.py | 29 ++++--- api/services/app_service.py | 9 +- api/services/model_provider_service.py | 28 +++--- .../tools/builtin_tools_manage_service.py | 34 +++++--- .../position_helper/test_position_helper.py | 86 +++++++++++++++++-- docker/.env.example | 19 ++++ 14 files changed, 363 insertions(+), 57 deletions(-) diff --git a/api/.env.example b/api/.env.example index 775149f8fd..f81675fd53 100644 --- a/api/.env.example +++ b/api/.env.example @@ -267,4 +267,13 @@ APP_MAX_ACTIVE_REQUESTS=0 # Celery beat configuration -CELERY_BEAT_SCHEDULER_TIME=1 \ No newline at end of file +CELERY_BEAT_SCHEDULER_TIME=1 + +# Position configuration +POSITION_TOOL_PINS= +POSITION_TOOL_INCLUDES= +POSITION_TOOL_EXCLUDES= + +POSITION_PROVIDER_PINS= +POSITION_PROVIDER_INCLUDES= +POSITION_PROVIDER_EXCLUDES= diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 369b25d788..27de2c0461 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -406,6 +406,7 @@ class DataSetConfig(BaseSettings): default=False, ) + class WorkspaceConfig(BaseSettings): """ Workspace configs @@ -442,6 +443,63 @@ class CeleryBeatConfig(BaseSettings): ) +class PositionConfig(BaseSettings): + + POSITION_PROVIDER_PINS: str = Field( + description='The heads of model providers', + default='', + ) + + POSITION_PROVIDER_INCLUDES: str = Field( + description='The included model providers', + default='', + ) + + POSITION_PROVIDER_EXCLUDES: str = Field( + description='The excluded model providers', + default='', + ) + + POSITION_TOOL_PINS: str = Field( + description='The heads of tools', + default='', + ) + + POSITION_TOOL_INCLUDES: str = Field( + description='The included tools', + default='', + ) + + POSITION_TOOL_EXCLUDES: str = Field( + description='The excluded tools', + default='', + ) + + @computed_field + def POSITION_PROVIDER_PINS_LIST(self) -> list[str]: + return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(',') if item.strip() != ''] + + @computed_field + def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]: + return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != ''} + + @computed_field + def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]: + return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != ''} + + @computed_field + def POSITION_TOOL_PINS_LIST(self) -> list[str]: + return [item.strip() for item in self.POSITION_TOOL_PINS.split(',') if item.strip() != ''] + + @computed_field + def POSITION_TOOL_INCLUDES_SET(self) -> set[str]: + return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != ''} + + @computed_field + def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]: + return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != ''} + + class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, @@ -466,6 +524,7 @@ class FeatureConfig( UpdateConfig, WorkflowConfig, WorkspaceConfig, + PositionConfig, # hosted services config HostedServiceConfig, diff --git a/api/core/helper/position_helper.py b/api/core/helper/position_helper.py index dd1534c791..8cf184ac44 100644 --- a/api/core/helper/position_helper.py +++ b/api/core/helper/position_helper.py @@ -3,6 +3,7 @@ from collections import OrderedDict from collections.abc import Callable from typing import Any +from configs import dify_config from core.tools.utils.yaml_utils import load_yaml_file @@ -19,6 +20,87 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> return {name: index for index, name in enumerate(positions)} +def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: + """ + Get the mapping for tools from name to index from a YAML file. + :param folder_path: + :param file_name: the YAML file name, default to '_position.yaml' + :return: a dict with name as key and index as value + """ + position_map = get_position_map(folder_path, file_name=file_name) + + return pin_position_map( + position_map, + pin_list=dify_config.POSITION_TOOL_PINS_LIST, + ) + + +def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: + """ + Get the mapping for providers from name to index from a YAML file. + :param folder_path: + :param file_name: the YAML file name, default to '_position.yaml' + :return: a dict with name as key and index as value + """ + position_map = get_position_map(folder_path, file_name=file_name) + return pin_position_map( + position_map, + pin_list=dify_config.POSITION_PROVIDER_PINS_LIST, + ) + + +def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]: + """ + Pin the items in the pin list to the beginning of the position map. + Overall logic: exclude > include > pin + :param position_map: the position map to be sorted and filtered + :param pin_list: the list of pins to be put at the beginning + :return: the sorted position map + """ + positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x]) + + # Add pins to position map + position_map = {name: idx for idx, name in enumerate(pin_list)} + + # Add remaining positions to position map + start_idx = len(position_map) + for name in positions: + if name not in position_map: + position_map[name] = start_idx + start_idx += 1 + + return position_map + + +def is_filtered( + include_set: set[str], + exclude_set: set[str], + data: Any, + name_func: Callable[[Any], str], +) -> bool: + """ + Chcek if the object should be filtered out. + Overall logic: exclude > include > pin + :param include_set: the set of names to be included + :param exclude_set: the set of names to be excluded + :param name_func: the function to get the name of the object + :param data: the data to be filtered + :return: True if the object should be filtered out, False otherwise + """ + if not data: + return False + if not include_set and not exclude_set: + return False + + name = name_func(data) + + if name in exclude_set: # exclude_set is prioritized + return True + if include_set and name not in include_set: # filter out only if include_set is not empty + return True + return False + + def sort_by_position_map( position_map: dict[str, int], data: list[Any], diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 7c23e14297..1ceed8043c 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -368,6 +368,15 @@ class ModelManager: return ModelInstance(provider_model_bundle, model) + def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]: + """ + Return first provider and the first model in the provider + :param tenant_id: tenant id + :param model_type: model type + :return: provider name, model name + """ + return self._provider_manager.get_first_provider_first_model(tenant_id, model_type) + def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance: """ Get default model instance @@ -502,7 +511,6 @@ class LBModelManager: config.id ) - res = redis_client.exists(cooldown_cache_key) res = cast(bool, res) return res diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 0de216bf89..716bb63566 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -151,9 +151,9 @@ class AIModel(ABC): os.path.join(provider_model_type_path, model_schema_yaml) for model_schema_yaml in os.listdir(provider_model_type_path) if not model_schema_yaml.startswith('__') - and not model_schema_yaml.startswith('_') - and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) - and model_schema_yaml.endswith('.yaml') + and not model_schema_yaml.startswith('_') + and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) + and model_schema_yaml.endswith('.yaml') ] # get _position.yaml file path diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index b1660afafb..e2d17e3257 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -6,7 +6,7 @@ from typing import Optional from pydantic import BaseModel, ConfigDict from core.helper.module_import_helper import load_single_subclass_from_source -from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map +from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.model_providers.__base.model_provider import ModelProvider @@ -234,7 +234,7 @@ class ModelProviderFactory: ] # get _position.yaml file path - position_map = get_position_map(model_providers_path) + position_map = get_provider_position_map(model_providers_path) # traverse all model_provider_dir_paths model_providers: list[ModelProviderExtension] = [] diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 65a4cada88..6c68cee7be 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -5,6 +5,7 @@ from typing import Optional from sqlalchemy.exc import IntegrityError +from configs import dify_config from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle from core.entities.provider_entities import ( @@ -18,12 +19,9 @@ from core.entities.provider_entities import ( ) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType +from core.helper.position_helper import is_filtered from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import ( - CredentialFormSchema, - FormType, - ProviderEntity, -) +from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity from core.model_runtime.model_providers import model_provider_factory from extensions import ext_hosting_provider from extensions.ext_database import db @@ -45,6 +43,7 @@ class ProviderManager: """ ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers. """ + def __init__(self) -> None: self.decoding_rsa_key = None self.decoding_cipher_rsa = None @@ -117,6 +116,16 @@ class ProviderManager: # Construct ProviderConfiguration objects for each provider for provider_entity in provider_entities: + + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, + exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, + data=provider_entity, + name_func=lambda x: x.provider, + ): + continue + provider_name = provider_entity.provider provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, []) provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, []) @@ -271,6 +280,24 @@ class ProviderManager: ) ) + def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]: + """ + Get names of first model and its provider + + :param tenant_id: workspace id + :param model_type: model type + :return: provider name, model name + """ + provider_configurations = self.get_configurations(tenant_id) + + # get available models from provider_configurations + all_models = provider_configurations.get_models( + model_type=model_type, + only_active=False + ) + + return all_models[0].provider.provider, all_models[0].model + def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \ -> TenantDefaultModel: """ diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index ae806eaff4..062668fc5b 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -1,6 +1,6 @@ import os.path -from core.helper.position_helper import get_position_map, sort_by_position_map +from core.helper.position_helper import get_tool_position_map, sort_by_position_map from core.tools.entities.api_entities import UserToolProvider @@ -10,11 +10,11 @@ class BuiltinToolProviderSort: @classmethod def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: if not cls._position: - cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..')) + cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..')) def name_func(provider: UserToolProvider) -> str: return provider.name sorted_providers = sort_by_position_map(cls._position, providers, name_func) - return sorted_providers \ No newline at end of file + return sorted_providers diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index d7ddb40e6b..4a0188af49 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -10,14 +10,11 @@ from configs import dify_config from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source +from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ( - ApiProviderAuthType, - ToolInvokeFrom, - ToolParameter, -) +from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter from core.tools.errors import ToolProviderNotFoundError from core.tools.provider.api_tool_provider import ApiToolProviderController from core.tools.provider.builtin._positions import BuiltinToolProviderSort @@ -26,10 +23,7 @@ from core.tools.tool.api_tool import ApiTool from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.configuration import ( - ToolConfigurationManager, - ToolParameterConfigurationManager, -) +from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager from core.tools.utils.tool_parameter_converter import ToolParameterConverter from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db @@ -38,6 +32,7 @@ from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) + class ToolManager: _builtin_provider_lock = Lock() _builtin_providers = {} @@ -107,7 +102,7 @@ class ToolManager: tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \ - -> Union[BuiltinTool, ApiTool]: + -> Union[BuiltinTool, ApiTool]: """ get the tool runtime @@ -346,7 +341,7 @@ class ToolManager: provider_class = load_single_subclass_from_source( module_name=f'core.tools.provider.builtin.{provider}.{provider}', script_path=path.join(path.dirname(path.realpath(__file__)), - 'provider', 'builtin', provider, f'{provider}.py'), + 'provider', 'builtin', provider, f'{provider}.py'), parent_type=BuiltinToolProviderController) provider: BuiltinToolProviderController = provider_class() cls._builtin_providers[provider.identity.name] = provider @@ -414,6 +409,15 @@ class ToolManager: # append builtin providers for provider in builtin_providers: + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, + data=provider, + name_func=lambda x: x.identity.name + ): + continue + user_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider, db_provider=find_db_builtin_provider(provider.identity.name), @@ -473,7 +477,7 @@ class ToolManager: @classmethod def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[ - ApiToolProviderController, dict[str, Any]]: + ApiToolProviderController, dict[str, Any]]: """ get the api provider @@ -593,4 +597,5 @@ class ToolManager: else: raise ValueError(f"provider type {provider_type} not found") + ToolManager.load_builtin_providers_cache() diff --git a/api/services/app_service.py b/api/services/app_service.py index f33ef9b001..93f7169c1f 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -111,6 +111,12 @@ class AppService: 'completion_params': {} } else: + provider, model = model_manager.get_default_provider_model_name( + tenant_id=account.current_tenant_id, + model_type=ModelType.LLM + ) + default_model_config['model']['provider'] = provider + default_model_config['model']['name'] = model default_model_dict = default_model_config['model'] default_model_config['model'] = json.dumps(default_model_dict) @@ -190,13 +196,14 @@ class AppService: """ Modified App class """ + def __init__(self, app): self.__dict__.update(app.__dict__) @property def app_model_config(self): return model_config - + app = ModifiedApp(app) return app diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 385af685f9..86c059fe90 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -30,6 +30,7 @@ class ModelProviderService: """ Model Provider Service """ + def __init__(self) -> None: self.provider_manager = ProviderManager() @@ -387,18 +388,21 @@ class ModelProviderService: tenant_id=tenant_id, model_type=model_type_enum ) - - return DefaultModelResponse( - model=result.model, - model_type=result.model_type, - provider=SimpleProviderEntityResponse( - provider=result.provider.provider, - label=result.provider.label, - icon_small=result.provider.icon_small, - icon_large=result.provider.icon_large, - supported_model_types=result.provider.supported_model_types - ) - ) if result else None + try: + return DefaultModelResponse( + model=result.model, + model_type=result.model_type, + provider=SimpleProviderEntityResponse( + provider=result.provider.provider, + label=result.provider.label, + icon_small=result.provider.icon_small, + icon_large=result.provider.icon_large, + supported_model_types=result.provider.supported_model_types + ) + ) if result else None + except Exception as e: + logger.info(f"get_default_model_of_model_type error: {e}") + return None def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None: """ diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index ea6ecf0c69..ebadbd9be0 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -1,6 +1,8 @@ import json import logging +from configs import dify_config +from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.api_entities import UserTool, UserToolProvider from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError @@ -43,14 +45,14 @@ class BuiltinToolManageService: result = [] for tool in tools: result.append(ToolTransformService.tool_to_user_tool( - tool=tool, - credentials=credentials, + tool=tool, + credentials=credentials, tenant_id=tenant_id, labels=ToolLabelManager.get_tool_labels(provider_controller) )) return result - + @staticmethod def list_builtin_provider_credentials_schema( provider_name @@ -78,7 +80,7 @@ class BuiltinToolManageService: BuiltinToolProvider.provider == provider_name, ).first() - try: + try: # get provider provider_controller = ToolManager.get_builtin_provider(provider_name) if not provider_controller.need_credentials: @@ -119,8 +121,8 @@ class BuiltinToolManageService: # delete cache tool_configuration.delete_tool_credentials_cache() - return { 'result': 'success' } - + return {'result': 'success'} + @staticmethod def get_builtin_tool_provider_credentials( user_id: str, tenant_id: str, provider: str @@ -135,7 +137,7 @@ class BuiltinToolManageService: if provider is None: return {} - + provider_controller = ToolManager.get_builtin_provider(provider.provider) tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) @@ -156,7 +158,7 @@ class BuiltinToolManageService: if provider is None: raise ValueError(f'you have not added provider {provider_name}') - + db.session.delete(provider) db.session.commit() @@ -165,8 +167,8 @@ class BuiltinToolManageService: tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) tool_configuration.delete_tool_credentials_cache() - return { 'result': 'success' } - + return {'result': 'success'} + @staticmethod def get_builtin_tool_provider_icon( provider: str @@ -179,7 +181,7 @@ class BuiltinToolManageService: icon_bytes = f.read() return icon_bytes, mime_type - + @staticmethod def list_builtin_tools( user_id: str, tenant_id: str @@ -202,6 +204,15 @@ class BuiltinToolManageService: for provider_controller in provider_controllers: try: + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, + data=provider_controller, + name_func=lambda x: x.identity.name + ): + continue + # convert provider controller to user provider user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider_controller, @@ -226,4 +237,3 @@ class BuiltinToolManageService: raise e return BuiltinToolProviderSort.sort(result) - \ No newline at end of file diff --git a/api/tests/unit_tests/utils/position_helper/test_position_helper.py b/api/tests/unit_tests/utils/position_helper/test_position_helper.py index 2237319904..1235e559c9 100644 --- a/api/tests/unit_tests/utils/position_helper/test_position_helper.py +++ b/api/tests/unit_tests/utils/position_helper/test_position_helper.py @@ -2,7 +2,7 @@ from textwrap import dedent import pytest -from core.helper.position_helper import get_position_map +from core.helper.position_helper import get_position_map, is_filtered, pin_position_map, sort_by_position_map @pytest.fixture @@ -14,7 +14,7 @@ def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str: - second # - commented - third - + - 9999999999999 - forth """)) @@ -28,9 +28,9 @@ def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str: """\ # - commented1 # - commented2 - - - - - + - + - + """)) return str(tmp_path) @@ -53,3 +53,79 @@ def test_position_helper_with_all_commented(prepare_empty_commented_positions_ya folder_path=prepare_empty_commented_positions_yaml, file_name='example_positions_all_commented.yaml') assert position_map == {} + + +def test_excluded_position_data(prepare_example_positions_yaml): + position_map = get_position_map( + folder_path=prepare_example_positions_yaml, + file_name='example_positions.yaml' + ) + pin_list = ['forth', 'first'] + include_set = set() + exclude_set = {'9999999999999'} + + position_map = pin_position_map( + original_position_map=position_map, + pin_list=pin_list + ) + + data = [ + "forth", + "first", + "second", + "third", + "9999999999999", + "extra1", + "extra2", + ] + + # filter out the data + data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)] + + # sort data by position map + sorted_data = sort_by_position_map( + position_map=position_map, + data=data, + name_func=lambda x: x, + ) + + # assert the result in the correct order + assert sorted_data == ['forth', 'first', 'second', 'third', 'extra1', 'extra2'] + + +def test_included_position_data(prepare_example_positions_yaml): + position_map = get_position_map( + folder_path=prepare_example_positions_yaml, + file_name='example_positions.yaml' + ) + pin_list = ['forth', 'first'] + include_set = {'forth', 'first'} + exclude_set = {} + + position_map = pin_position_map( + original_position_map=position_map, + pin_list=pin_list + ) + + data = [ + "forth", + "first", + "second", + "third", + "9999999999999", + "extra1", + "extra2", + ] + + # filter out the data + data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)] + + # sort data by position map + sorted_data = sort_by_position_map( + position_map=position_map, + data=data, + name_func=lambda x: x, + ) + + # assert the result in the correct order + assert sorted_data == ['forth', 'first'] diff --git a/docker/.env.example b/docker/.env.example index 03e1e4e50e..ed9621854a 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -701,3 +701,22 @@ COMPOSE_PROFILES=${VECTOR_STORE:-weaviate} # ------------------------------ EXPOSE_NGINX_PORT=80 EXPOSE_NGINX_SSL_PORT=443 + +# ---------------------------------------------------------------------------- +# ModelProvider & Tool Position Configuration +# Used to specify the model providers and tools that can be used in the app. +# ---------------------------------------------------------------------------- + +# Pin, include, and exclude tools +# Use comma-separated values with no spaces between items. +# Example: POSITION_TOOL_PINS=bing,google +POSITION_TOOL_PINS= +POSITION_TOOL_INCLUDES= +POSITION_TOOL_EXCLUDES= + +# Pin, include, and exclude model providers +# Use comma-separated values with no spaces between items. +# Example: POSITION_PROVIDER_PINS=openai,openllm +POSITION_PROVIDER_PINS= +POSITION_PROVIDER_INCLUDES= +POSITION_PROVIDER_EXCLUDES= \ No newline at end of file