diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 049c9d0622..b739016559 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -148,7 +148,9 @@ class AIModel(ABC): position_map = {} if os.path.exists(position_file_path): with open(position_file_path, 'r', encoding='utf-8') as f: - position_map = yaml.safe_load(f) + positions = yaml.safe_load(f) + # convert list to dict with key as model provider name, value as index + position_map = {position: index for index, position in enumerate(positions)} # traverse all model_schema_yaml_paths for model_schema_yaml_path in model_schema_yaml_paths: diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml index da9f180e62..6253498b21 100644 --- a/api/core/model_runtime/model_providers/_position.yaml +++ b/api/core/model_runtime/model_providers/_position.yaml @@ -1,19 +1,20 @@ -openai: 0 -anthropic: 1 -azure_openai: 2 -google: 3 -replicate: 4 -huggingface_hub: 5 -cohere: 6 -zhipuai: 7 -baichuan: 8 -spark: 9 -minimax: 10 -tongyi: 11 -wenxin: 12 -jina: 13 -chatglm: 14 -xinference: 15 -openllm: 16 -localai: 17 -openai_api_compatible: 18 \ No newline at end of file +- openai +- anthropic +- azure_openai +- google +- replicate +- huggingface_hub +- cohere +- togetherai +- zhipuai +- baichuan +- spark +- minimax +- tongyi +- wenxin +- jina +- chatglm +- xinference +- openllm +- localai +- openai_api_compatible \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 740425d43f..1435ab89b1 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -217,7 +217,9 @@ class ModelProviderFactory: position_map = {} if os.path.exists(position_file_path): with open(position_file_path, 'r', encoding='utf-8') as f: - position_map = yaml.safe_load(f) + positions = yaml.safe_load(f) + # convert list to dict with key as model provider name, value as index + position_map = {position: index for index, position in enumerate(positions)} # traverse all model_provider_dir_paths for model_provider_dir_path in model_provider_dir_paths: diff --git a/api/core/model_runtime/model_providers/openai/llm/_position.yaml b/api/core/model_runtime/model_providers/openai/llm/_position.yaml index a131cb4672..4f69acb30b 100644 --- a/api/core/model_runtime/model_providers/openai/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/_position.yaml @@ -1,9 +1,11 @@ -gpt-4: 0 -gpt-4-32k: 1 -gpt-4-1106-preview: 2 -gpt-4-vision-preview: 3 -gpt-3.5-turbo: 4 -gpt-3.5-turbo-16k: 5 -gpt-3.5-turbo-1106: 6 -gpt-3.5-turbo-instruct: 7 -text-davinci-003: 8 \ No newline at end of file +- gpt-4 +- gpt-4-32k +- gpt-4-1106-preview +- gpt-4-vision-preview +- gpt-3.5-turbo +- gpt-3.5-turbo-16k +- gpt-3.5-turbo-16k-0613 +- gpt-3.5-turbo-1106 +- gpt-3.5-turbo-0613 +- gpt-3.5-turbo-instruct +- text-davinci-003 \ No newline at end of file