mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 11:48:59 +08:00
generalize position helper for parsing _position.yaml and sorting objects by name (#2803)
This commit is contained in:
parent
849dc0560b
commit
8b15b742ad
@ -3,11 +3,12 @@ import importlib.util
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections import OrderedDict
|
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.utils.position_helper import sort_to_dict_by_position_map
|
||||||
|
|
||||||
|
|
||||||
class ExtensionModule(enum.Enum):
|
class ExtensionModule(enum.Enum):
|
||||||
MODERATION = 'moderation'
|
MODERATION = 'moderation'
|
||||||
@ -36,7 +37,8 @@ class Extensible:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def scan_extensions(cls):
|
def scan_extensions(cls):
|
||||||
extensions = {}
|
extensions: list[ModuleExtension] = []
|
||||||
|
position_map = {}
|
||||||
|
|
||||||
# get the path of the current class
|
# get the path of the current class
|
||||||
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py')
|
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py')
|
||||||
@ -63,6 +65,7 @@ class Extensible:
|
|||||||
if os.path.exists(builtin_file_path):
|
if os.path.exists(builtin_file_path):
|
||||||
with open(builtin_file_path, encoding='utf-8') as f:
|
with open(builtin_file_path, encoding='utf-8') as f:
|
||||||
position = int(f.read().strip())
|
position = int(f.read().strip())
|
||||||
|
position_map[extension_name] = position
|
||||||
|
|
||||||
if (extension_name + '.py') not in file_names:
|
if (extension_name + '.py') not in file_names:
|
||||||
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
|
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
|
||||||
@ -96,16 +99,15 @@ class Extensible:
|
|||||||
with open(json_path, encoding='utf-8') as f:
|
with open(json_path, encoding='utf-8') as f:
|
||||||
json_data = json.load(f)
|
json_data = json.load(f)
|
||||||
|
|
||||||
extensions[extension_name] = ModuleExtension(
|
extensions.append(ModuleExtension(
|
||||||
extension_class=extension_class,
|
extension_class=extension_class,
|
||||||
name=extension_name,
|
name=extension_name,
|
||||||
label=json_data.get('label'),
|
label=json_data.get('label'),
|
||||||
form_schema=json_data.get('form_schema'),
|
form_schema=json_data.get('form_schema'),
|
||||||
builtin=builtin,
|
builtin=builtin,
|
||||||
position=position
|
position=position
|
||||||
)
|
))
|
||||||
|
|
||||||
sorted_items = sorted(extensions.items(), key=lambda x: (x[1].position is None, x[1].position))
|
sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name)
|
||||||
sorted_extensions = OrderedDict(sorted_items)
|
|
||||||
|
|
||||||
return sorted_extensions
|
return sorted_extensions
|
||||||
|
@ -18,6 +18,7 @@ from core.model_runtime.entities.model_entities import (
|
|||||||
)
|
)
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
||||||
|
from core.utils.position_helper import get_position_map, sort_by_position_map
|
||||||
|
|
||||||
|
|
||||||
class AIModel(ABC):
|
class AIModel(ABC):
|
||||||
@ -148,15 +149,7 @@ class AIModel(ABC):
|
|||||||
]
|
]
|
||||||
|
|
||||||
# get _position.yaml file path
|
# get _position.yaml file path
|
||||||
position_file_path = os.path.join(provider_model_type_path, '_position.yaml')
|
position_map = get_position_map(provider_model_type_path)
|
||||||
|
|
||||||
# read _position.yaml file
|
|
||||||
position_map = {}
|
|
||||||
if os.path.exists(position_file_path):
|
|
||||||
with open(position_file_path, encoding='utf-8') as f:
|
|
||||||
positions = yaml.safe_load(f)
|
|
||||||
# convert list to dict with key as model provider name, value as index
|
|
||||||
position_map = {position: index for index, position in enumerate(positions)}
|
|
||||||
|
|
||||||
# traverse all model_schema_yaml_paths
|
# traverse all model_schema_yaml_paths
|
||||||
for model_schema_yaml_path in model_schema_yaml_paths:
|
for model_schema_yaml_path in model_schema_yaml_paths:
|
||||||
@ -206,8 +199,7 @@ class AIModel(ABC):
|
|||||||
model_schemas.append(model_schema)
|
model_schemas.append(model_schema)
|
||||||
|
|
||||||
# resort model schemas by position
|
# resort model schemas by position
|
||||||
if position_map:
|
model_schemas = sort_by_position_map(position_map, model_schemas, lambda x: x.model)
|
||||||
model_schemas.sort(key=lambda x: position_map.get(x.model, 999))
|
|
||||||
|
|
||||||
# cache model schemas
|
# cache model schemas
|
||||||
self.model_schemas = model_schemas
|
self.model_schemas = model_schemas
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections import OrderedDict
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import yaml
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
@ -12,6 +10,7 @@ from core.model_runtime.entities.provider_entities import ProviderConfig, Provid
|
|||||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||||
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
|
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
|
||||||
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
|
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
|
||||||
|
from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -200,7 +199,6 @@ class ModelProviderFactory:
|
|||||||
if self.model_provider_extensions:
|
if self.model_provider_extensions:
|
||||||
return self.model_provider_extensions
|
return self.model_provider_extensions
|
||||||
|
|
||||||
model_providers = {}
|
|
||||||
|
|
||||||
# get the path of current classes
|
# get the path of current classes
|
||||||
current_path = os.path.abspath(__file__)
|
current_path = os.path.abspath(__file__)
|
||||||
@ -215,17 +213,10 @@ class ModelProviderFactory:
|
|||||||
]
|
]
|
||||||
|
|
||||||
# get _position.yaml file path
|
# get _position.yaml file path
|
||||||
position_file_path = os.path.join(model_providers_path, '_position.yaml')
|
position_map = get_position_map(model_providers_path)
|
||||||
|
|
||||||
# read _position.yaml file
|
|
||||||
position_map = {}
|
|
||||||
if os.path.exists(position_file_path):
|
|
||||||
with open(position_file_path, encoding='utf-8') as f:
|
|
||||||
positions = yaml.safe_load(f)
|
|
||||||
# convert list to dict with key as model provider name, value as index
|
|
||||||
position_map = {position: index for index, position in enumerate(positions)}
|
|
||||||
|
|
||||||
# traverse all model_provider_dir_paths
|
# traverse all model_provider_dir_paths
|
||||||
|
model_providers: list[ModelProviderExtension] = []
|
||||||
for model_provider_dir_path in model_provider_dir_paths:
|
for model_provider_dir_path in model_provider_dir_paths:
|
||||||
# get model_provider dir name
|
# get model_provider dir name
|
||||||
model_provider_name = os.path.basename(model_provider_dir_path)
|
model_provider_name = os.path.basename(model_provider_dir_path)
|
||||||
@ -256,14 +247,13 @@ class ModelProviderFactory:
|
|||||||
logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.")
|
logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
model_providers[model_provider_name] = ModelProviderExtension(
|
model_providers.append(ModelProviderExtension(
|
||||||
name=model_provider_name,
|
name=model_provider_name,
|
||||||
provider_instance=model_provider_class(),
|
provider_instance=model_provider_class(),
|
||||||
position=position_map.get(model_provider_name)
|
position=position_map.get(model_provider_name)
|
||||||
)
|
))
|
||||||
|
|
||||||
sorted_items = sorted(model_providers.items(), key=lambda x: (x[1].position is None, x[1].position))
|
sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name)
|
||||||
sorted_extensions = OrderedDict(sorted_items)
|
|
||||||
|
|
||||||
self.model_provider_extensions = sorted_extensions
|
self.model_provider_extensions = sorted_extensions
|
||||||
|
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
import os.path
|
import os.path
|
||||||
|
|
||||||
from yaml import FullLoader, load
|
|
||||||
|
|
||||||
from core.tools.entities.user_entities import UserToolProvider
|
from core.tools.entities.user_entities import UserToolProvider
|
||||||
|
from core.utils.position_helper import get_position_map, sort_by_position_map
|
||||||
|
|
||||||
|
|
||||||
class BuiltinToolProviderSort:
|
class BuiltinToolProviderSort:
|
||||||
@ -11,18 +10,14 @@ class BuiltinToolProviderSort:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
|
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
|
||||||
if not cls._position:
|
if not cls._position:
|
||||||
tmp_position = {}
|
cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..'))
|
||||||
file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
|
|
||||||
with open(file_path) as f:
|
|
||||||
for pos, val in enumerate(load(f, Loader=FullLoader)):
|
|
||||||
tmp_position[val] = pos
|
|
||||||
cls._position = tmp_position
|
|
||||||
|
|
||||||
def sort_compare(provider: UserToolProvider) -> int:
|
def name_func(provider: UserToolProvider) -> str:
|
||||||
if provider.type == UserToolProvider.ProviderType.MODEL:
|
if provider.type == UserToolProvider.ProviderType.MODEL:
|
||||||
return cls._position.get(f'model.{provider.name}', 10000)
|
return f'model.{provider.name}'
|
||||||
return cls._position.get(provider.name, 10000)
|
else:
|
||||||
|
return provider.name
|
||||||
sorted_providers = sorted(providers, key=sort_compare)
|
|
||||||
|
sorted_providers = sort_by_position_map(cls._position, providers, name_func)
|
||||||
|
|
||||||
return sorted_providers
|
return sorted_providers
|
70
api/core/utils/position_helper.py
Normal file
70
api/core/utils/position_helper.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from collections import OrderedDict
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any, AnyStr
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
def get_position_map(
|
||||||
|
folder_path: AnyStr,
|
||||||
|
file_name: str = '_position.yaml',
|
||||||
|
) -> dict[str, int]:
|
||||||
|
"""
|
||||||
|
Get the mapping from name to index from a YAML file
|
||||||
|
:param folder_path:
|
||||||
|
:param file_name: the YAML file name, default to '_position.yaml'
|
||||||
|
:return: a dict with name as key and index as value
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
position_file_name = os.path.join(folder_path, file_name)
|
||||||
|
if not os.path.exists(position_file_name):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
with open(position_file_name, encoding='utf-8') as f:
|
||||||
|
positions = yaml.safe_load(f)
|
||||||
|
position_map = {}
|
||||||
|
for index, name in enumerate(positions):
|
||||||
|
if name and isinstance(name, str):
|
||||||
|
position_map[name.strip()] = index
|
||||||
|
return position_map
|
||||||
|
except:
|
||||||
|
logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.')
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def sort_by_position_map(
|
||||||
|
position_map: dict[str, int],
|
||||||
|
data: list[Any],
|
||||||
|
name_func: Callable[[Any], str],
|
||||||
|
) -> list[Any]:
|
||||||
|
"""
|
||||||
|
Sort the objects by the position map.
|
||||||
|
If the name of the object is not in the position map, it will be put at the end.
|
||||||
|
:param position_map: the map holding positions in the form of {name: index}
|
||||||
|
:param name_func: the function to get the name of the object
|
||||||
|
:param data: the data to be sorted
|
||||||
|
:return: the sorted objects
|
||||||
|
"""
|
||||||
|
if not position_map or not data:
|
||||||
|
return data
|
||||||
|
|
||||||
|
return sorted(data, key=lambda x: position_map.get(name_func(x), float('inf')))
|
||||||
|
|
||||||
|
|
||||||
|
def sort_to_dict_by_position_map(
|
||||||
|
position_map: dict[str, int],
|
||||||
|
data: list[Any],
|
||||||
|
name_func: Callable[[Any], str],
|
||||||
|
) -> OrderedDict[str, Any]:
|
||||||
|
"""
|
||||||
|
Sort the objects into a ordered dict by the position map.
|
||||||
|
If the name of the object is not in the position map, it will be put at the end.
|
||||||
|
:param position_map: the map holding positions in the form of {name: index}
|
||||||
|
:param name_func: the function to get the name of the object
|
||||||
|
:param data: the data to be sorted
|
||||||
|
:return: an OrderedDict with the sorted pairs of name and object
|
||||||
|
"""
|
||||||
|
sorted_items = sort_by_position_map(position_map, data, name_func)
|
||||||
|
return OrderedDict([(name_func(item), item) for item in sorted_items])
|
Loading…
x
Reference in New Issue
Block a user