generalize position helper for parsing _position.yaml and sorting objects by name (#2803)

This commit is contained in:
Bowen Liang 2024-03-13 20:29:38 +08:00 committed by GitHub
parent 849dc0560b
commit 8b15b742ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 95 additions and 46 deletions

View File

@ -3,11 +3,12 @@ import importlib.util
import json import json
import logging import logging
import os import os
from collections import OrderedDict
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel from pydantic import BaseModel
from core.utils.position_helper import sort_to_dict_by_position_map
class ExtensionModule(enum.Enum): class ExtensionModule(enum.Enum):
MODERATION = 'moderation' MODERATION = 'moderation'
@ -36,7 +37,8 @@ class Extensible:
@classmethod @classmethod
def scan_extensions(cls): def scan_extensions(cls):
extensions = {} extensions: list[ModuleExtension] = []
position_map = {}
# get the path of the current class # get the path of the current class
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py') current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py')
@ -63,6 +65,7 @@ class Extensible:
if os.path.exists(builtin_file_path): if os.path.exists(builtin_file_path):
with open(builtin_file_path, encoding='utf-8') as f: with open(builtin_file_path, encoding='utf-8') as f:
position = int(f.read().strip()) position = int(f.read().strip())
position_map[extension_name] = position
if (extension_name + '.py') not in file_names: if (extension_name + '.py') not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
@ -96,16 +99,15 @@ class Extensible:
with open(json_path, encoding='utf-8') as f: with open(json_path, encoding='utf-8') as f:
json_data = json.load(f) json_data = json.load(f)
extensions[extension_name] = ModuleExtension( extensions.append(ModuleExtension(
extension_class=extension_class, extension_class=extension_class,
name=extension_name, name=extension_name,
label=json_data.get('label'), label=json_data.get('label'),
form_schema=json_data.get('form_schema'), form_schema=json_data.get('form_schema'),
builtin=builtin, builtin=builtin,
position=position position=position
) ))
sorted_items = sorted(extensions.items(), key=lambda x: (x[1].position is None, x[1].position)) sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name)
sorted_extensions = OrderedDict(sorted_items)
return sorted_extensions return sorted_extensions

View File

@ -18,6 +18,7 @@ from core.model_runtime.entities.model_entities import (
) )
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.utils.position_helper import get_position_map, sort_by_position_map
class AIModel(ABC): class AIModel(ABC):
@ -148,15 +149,7 @@ class AIModel(ABC):
] ]
# get _position.yaml file path # get _position.yaml file path
position_file_path = os.path.join(provider_model_type_path, '_position.yaml') position_map = get_position_map(provider_model_type_path)
# read _position.yaml file
position_map = {}
if os.path.exists(position_file_path):
with open(position_file_path, encoding='utf-8') as f:
positions = yaml.safe_load(f)
# convert list to dict with key as model provider name, value as index
position_map = {position: index for index, position in enumerate(positions)}
# traverse all model_schema_yaml_paths # traverse all model_schema_yaml_paths
for model_schema_yaml_path in model_schema_yaml_paths: for model_schema_yaml_path in model_schema_yaml_paths:
@ -206,8 +199,7 @@ class AIModel(ABC):
model_schemas.append(model_schema) model_schemas.append(model_schema)
# resort model schemas by position # resort model schemas by position
if position_map: model_schemas = sort_by_position_map(position_map, model_schemas, lambda x: x.model)
model_schemas.sort(key=lambda x: position_map.get(x.model, 999))
# cache model schemas # cache model schemas
self.model_schemas = model_schemas self.model_schemas = model_schemas

View File

@ -1,10 +1,8 @@
import importlib import importlib
import logging import logging
import os import os
from collections import OrderedDict
from typing import Optional from typing import Optional
import yaml
from pydantic import BaseModel from pydantic import BaseModel
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
@ -12,6 +10,7 @@ from core.model_runtime.entities.provider_entities import ProviderConfig, Provid
from core.model_runtime.model_providers.__base.model_provider import ModelProvider from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -200,7 +199,6 @@ class ModelProviderFactory:
if self.model_provider_extensions: if self.model_provider_extensions:
return self.model_provider_extensions return self.model_provider_extensions
model_providers = {}
# get the path of current classes # get the path of current classes
current_path = os.path.abspath(__file__) current_path = os.path.abspath(__file__)
@ -215,17 +213,10 @@ class ModelProviderFactory:
] ]
# get _position.yaml file path # get _position.yaml file path
position_file_path = os.path.join(model_providers_path, '_position.yaml') position_map = get_position_map(model_providers_path)
# read _position.yaml file
position_map = {}
if os.path.exists(position_file_path):
with open(position_file_path, encoding='utf-8') as f:
positions = yaml.safe_load(f)
# convert list to dict with key as model provider name, value as index
position_map = {position: index for index, position in enumerate(positions)}
# traverse all model_provider_dir_paths # traverse all model_provider_dir_paths
model_providers: list[ModelProviderExtension] = []
for model_provider_dir_path in model_provider_dir_paths: for model_provider_dir_path in model_provider_dir_paths:
# get model_provider dir name # get model_provider dir name
model_provider_name = os.path.basename(model_provider_dir_path) model_provider_name = os.path.basename(model_provider_dir_path)
@ -256,14 +247,13 @@ class ModelProviderFactory:
logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.") logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.")
continue continue
model_providers[model_provider_name] = ModelProviderExtension( model_providers.append(ModelProviderExtension(
name=model_provider_name, name=model_provider_name,
provider_instance=model_provider_class(), provider_instance=model_provider_class(),
position=position_map.get(model_provider_name) position=position_map.get(model_provider_name)
) ))
sorted_items = sorted(model_providers.items(), key=lambda x: (x[1].position is None, x[1].position)) sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name)
sorted_extensions = OrderedDict(sorted_items)
self.model_provider_extensions = sorted_extensions self.model_provider_extensions = sorted_extensions

View File

@ -1,8 +1,7 @@
import os.path import os.path
from yaml import FullLoader, load
from core.tools.entities.user_entities import UserToolProvider from core.tools.entities.user_entities import UserToolProvider
from core.utils.position_helper import get_position_map, sort_by_position_map
class BuiltinToolProviderSort: class BuiltinToolProviderSort:
@ -11,18 +10,14 @@ class BuiltinToolProviderSort:
@classmethod @classmethod
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
if not cls._position: if not cls._position:
tmp_position = {} cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..'))
file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
with open(file_path) as f:
for pos, val in enumerate(load(f, Loader=FullLoader)):
tmp_position[val] = pos
cls._position = tmp_position
def sort_compare(provider: UserToolProvider) -> int: def name_func(provider: UserToolProvider) -> str:
if provider.type == UserToolProvider.ProviderType.MODEL: if provider.type == UserToolProvider.ProviderType.MODEL:
return cls._position.get(f'model.{provider.name}', 10000) return f'model.{provider.name}'
return cls._position.get(provider.name, 10000) else:
return provider.name
sorted_providers = sorted(providers, key=sort_compare)
sorted_providers = sort_by_position_map(cls._position, providers, name_func)
return sorted_providers return sorted_providers

View File

@ -0,0 +1,70 @@
import logging
import os
from collections import OrderedDict
from collections.abc import Callable
from typing import Any, AnyStr
import yaml
def get_position_map(
folder_path: AnyStr,
file_name: str = '_position.yaml',
) -> dict[str, int]:
"""
Get the mapping from name to index from a YAML file
:param folder_path:
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
try:
position_file_name = os.path.join(folder_path, file_name)
if not os.path.exists(position_file_name):
return {}
with open(position_file_name, encoding='utf-8') as f:
positions = yaml.safe_load(f)
position_map = {}
for index, name in enumerate(positions):
if name and isinstance(name, str):
position_map[name.strip()] = index
return position_map
except:
logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.')
return {}
def sort_by_position_map(
position_map: dict[str, int],
data: list[Any],
name_func: Callable[[Any], str],
) -> list[Any]:
"""
Sort the objects by the position map.
If the name of the object is not in the position map, it will be put at the end.
:param position_map: the map holding positions in the form of {name: index}
:param name_func: the function to get the name of the object
:param data: the data to be sorted
:return: the sorted objects
"""
if not position_map or not data:
return data
return sorted(data, key=lambda x: position_map.get(name_func(x), float('inf')))
def sort_to_dict_by_position_map(
position_map: dict[str, int],
data: list[Any],
name_func: Callable[[Any], str],
) -> OrderedDict[str, Any]:
"""
Sort the objects into a ordered dict by the position map.
If the name of the object is not in the position map, it will be put at the end.
:param position_map: the map holding positions in the form of {name: index}
:param name_func: the function to get the name of the object
:param data: the data to be sorted
:return: an OrderedDict with the sorted pairs of name and object
"""
sorted_items = sort_by_position_map(position_map, data, name_func)
return OrderedDict([(name_func(item), item) for item in sorted_items])