diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py
index bdc08ae6dd..749ecd6422 100644
--- a/api/controllers/console/workspace/model_providers.py
+++ b/api/controllers/console/workspace/model_providers.py
@@ -21,8 +21,12 @@ class ModelProviderListApi(Resource):
     def get(self):
         tenant_id = current_user.current_tenant_id
 
+        parser = reqparse.RequestParser()
+        parser.add_argument('model_type', type=str, required=False, nullable=True, location='args')
+        args = parser.parse_args()
+
         provider_service = ProviderService()
-        provider_list = provider_service.get_provider_list(tenant_id)
+        provider_list = provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get('model_type'))
 
         return provider_list
 
diff --git a/api/core/model_providers/rules/anthropic.json b/api/core/model_providers/rules/anthropic.json
index e617842b94..a302b1de13 100644
--- a/api/core/model_providers/rules/anthropic.json
+++ b/api/core/model_providers/rules/anthropic.json
@@ -12,6 +12,9 @@
     "quota_limit": 0
   },
   "model_flexibility": "fixed",
+  "supported_model_types": [
+    "text-generation"
+  ],
   "price_config": {
     "claude-instant-1": {
       "prompt": "1.63",
diff --git a/api/core/model_providers/rules/azure_openai.json b/api/core/model_providers/rules/azure_openai.json
index fe4dc10c56..05a8007855 100644
--- a/api/core/model_providers/rules/azure_openai.json
+++ b/api/core/model_providers/rules/azure_openai.json
@@ -4,6 +4,10 @@
   ],
   "system_config": null,
   "model_flexibility": "configurable",
+  "supported_model_types": [
+    "text-generation",
+    "embeddings"
+  ],
   "price_config":{
     "gpt-4": {
       "prompt": "0.03",
diff --git a/api/core/model_providers/rules/baichuan.json b/api/core/model_providers/rules/baichuan.json
index 237b0d24d2..70b847cd8a 100644
--- a/api/core/model_providers/rules/baichuan.json
+++ b/api/core/model_providers/rules/baichuan.json
@@ -4,6 +4,9 @@
   ],
   "system_config": null,
   "model_flexibility": "fixed",
+  "supported_model_types": [
+    "text-generation"
+  ],
   "price_config": {
     "baichuan2-53b": {
       "prompt": "0.01",
diff --git a/api/core/model_providers/rules/chatglm.json b/api/core/model_providers/rules/chatglm.json
index 0af3e61ec7..3ddfb8cf53 100644
--- a/api/core/model_providers/rules/chatglm.json
+++ b/api/core/model_providers/rules/chatglm.json
@@ -3,5 +3,8 @@
     "custom"
   ],
   "system_config": null,
-  "model_flexibility": "fixed"
+  "model_flexibility": "fixed",
+  "supported_model_types": [
+    "text-generation"
+  ]
 }
\ No newline at end of file
diff --git a/api/core/model_providers/rules/cohere.json b/api/core/model_providers/rules/cohere.json
index 0af3e61ec7..5ce0c9cc5b 100644
--- a/api/core/model_providers/rules/cohere.json
+++ b/api/core/model_providers/rules/cohere.json
@@ -3,5 +3,8 @@
     "custom"
   ],
   "system_config": null,
-  "model_flexibility": "fixed"
+  "model_flexibility": "fixed",
+  "supported_model_types": [
+    "reranking"
+  ]
 }
\ No newline at end of file
diff --git a/api/core/model_providers/rules/huggingface_hub.json b/api/core/model_providers/rules/huggingface_hub.json
index 5badb07178..3f1ee225f1 100644
--- a/api/core/model_providers/rules/huggingface_hub.json
+++ b/api/core/model_providers/rules/huggingface_hub.json
@@ -3,5 +3,9 @@
     "custom"
   ],
   "system_config": null,
-  "model_flexibility": "configurable"
+  "model_flexibility": "configurable",
+  "supported_model_types": [
+    "text-generation",
+    "embeddings"
+  ]
 }
\ No newline at end of file
diff --git a/api/core/model_providers/rules/localai.json b/api/core/model_providers/rules/localai.json
index 5badb07178..3f1ee225f1 100644
--- a/api/core/model_providers/rules/localai.json
+++ b/api/core/model_providers/rules/localai.json
@@ -3,5 +3,9 @@
     "custom"
   ],
   "system_config": null,
-  "model_flexibility": "configurable"
+  "model_flexibility": "configurable",
+  "supported_model_types": [
+    "text-generation",
+    "embeddings"
+  ]
 }
\ No newline at end of file
diff --git a/api/core/model_providers/rules/minimax.json b/api/core/model_providers/rules/minimax.json
index 765d6712e1..0348ec3dfb 100644
--- a/api/core/model_providers/rules/minimax.json
+++ b/api/core/model_providers/rules/minimax.json
@@ -10,6 +10,10 @@
     "quota_unit": "tokens"
   },
   "model_flexibility": "fixed",
+  "supported_model_types": [
+    "text-generation",
+    "embeddings"
+  ],
   "price_config": {
     "abab5.5-chat": {
       "prompt": "0.015",
diff --git a/api/core/model_providers/rules/openai.json b/api/core/model_providers/rules/openai.json
index 17a3db72b8..4f1f39b792 100644
--- a/api/core/model_providers/rules/openai.json
+++ b/api/core/model_providers/rules/openai.json
@@ -11,6 +11,12 @@
     "quota_limit": 200
   },
   "model_flexibility": "fixed",
+  "supported_model_types": [
+    "text-generation",
+    "embeddings",
+    "speech2text",
+    "moderation"
+  ],
   "price_config": {
     "gpt-4": {
       "prompt": "0.03",
diff --git a/api/core/model_providers/rules/openllm.json b/api/core/model_providers/rules/openllm.json
index 5badb07178..3f1ee225f1 100644
--- a/api/core/model_providers/rules/openllm.json
+++ b/api/core/model_providers/rules/openllm.json
@@ -3,5 +3,9 @@
     "custom"
   ],
   "system_config": null,
-  "model_flexibility": "configurable"
+  "model_flexibility": "configurable",
+  "supported_model_types": [
+    "text-generation",
+    "embeddings"
+  ]
 }
\ No newline at end of file
diff --git a/api/core/model_providers/rules/replicate.json b/api/core/model_providers/rules/replicate.json
index 5badb07178..3f1ee225f1 100644
--- a/api/core/model_providers/rules/replicate.json
+++ b/api/core/model_providers/rules/replicate.json
@@ -3,5 +3,9 @@
     "custom"
  ],
   "system_config": null,
-  "model_flexibility": "configurable"
+  "model_flexibility": "configurable",
+  "supported_model_types": [
+    "text-generation",
+    "embeddings"
+  ]
 }
\ No newline at end of file
diff --git a/api/core/model_providers/rules/spark.json b/api/core/model_providers/rules/spark.json
index 24133107f8..4fa4d9a569 100644
--- a/api/core/model_providers/rules/spark.json
+++ b/api/core/model_providers/rules/spark.json
@@ -10,6 +10,9 @@
     "quota_unit": "tokens"
   },
   "model_flexibility": "fixed",
+  "supported_model_types": [
+    "text-generation"
+  ],
   "price_config": {
     "spark": {
       "prompt": "0.18",
diff --git a/api/core/model_providers/rules/tongyi.json b/api/core/model_providers/rules/tongyi.json
index c431f50b3f..319fbcaf9f 100644
--- a/api/core/model_providers/rules/tongyi.json
+++ b/api/core/model_providers/rules/tongyi.json
@@ -4,6 +4,9 @@
   ],
   "system_config": null,
   "model_flexibility": "fixed",
+  "supported_model_types": [
+    "text-generation"
+  ],
   "price_config": {
     "qwen-turbo": {
       "prompt": "0.012",
diff --git a/api/core/model_providers/rules/wenxin.json b/api/core/model_providers/rules/wenxin.json
index dbb692fb42..193dccc411 100644
--- a/api/core/model_providers/rules/wenxin.json
+++ b/api/core/model_providers/rules/wenxin.json
@@ -4,6 +4,9 @@
   ],
   "system_config": null,
   "model_flexibility": "fixed",
+  "supported_model_types": [
+    "text-generation"
+  ],
   "price_config": {
     "ernie-bot-4": {
       "prompt": "0",
diff --git a/api/core/model_providers/rules/xinference.json b/api/core/model_providers/rules/xinference.json
index 5badb07178..3f1ee225f1 100644
--- a/api/core/model_providers/rules/xinference.json
+++ b/api/core/model_providers/rules/xinference.json
@@ -3,5 +3,9 @@
     "custom"
   ],
   "system_config": null,
-  "model_flexibility": "configurable"
+  "model_flexibility": "configurable",
+  "supported_model_types": [
+    "text-generation",
+    "embeddings"
+  ]
 }
\ No newline at end of file
diff --git a/api/core/model_providers/rules/zhipuai.json b/api/core/model_providers/rules/zhipuai.json
index af0e5debba..07badcc313 100644
--- a/api/core/model_providers/rules/zhipuai.json
+++ b/api/core/model_providers/rules/zhipuai.json
@@ -10,6 +10,10 @@
     "quota_unit": "tokens"
   },
   "model_flexibility": "fixed",
+  "supported_model_types": [
+    "text-generation",
+    "embeddings"
+  ],
   "price_config": {
     "chatglm_turbo": {
       "prompt": "0.005",
diff --git a/api/services/provider_service.py b/api/services/provider_service.py
index f9acedf8c2..4ba1b0fe91 100644
--- a/api/services/provider_service.py
+++ b/api/services/provider_service.py
@@ -17,11 +17,12 @@ from models.provider import Provider, ProviderModel, TenantPreferredModelProvide
 
 class ProviderService:
 
-    def get_provider_list(self, tenant_id: str):
+    def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list:
         """
         get provider list of tenant.
 
-        :param tenant_id:
+        :param tenant_id: workspace id
+        :param model_type: filter by model type
         :return:
         """
         # get rules for all providers
@@ -79,6 +80,9 @@ class ProviderService:
         providers_list = {}
 
         for model_provider_name, model_provider_rule in model_provider_rules.items():
+            if model_type and model_type not in model_provider_rule.get('supported_model_types', []):
+                continue
+
             # get preferred provider type
             preferred_model_provider = provider_name_to_preferred_provider_type_dict.get(model_provider_name)
             preferred_provider_type = ModelProviderFactory.get_preferred_type_by_preferred_model_provider(
@@ -90,6 +94,7 @@ class ProviderService:
             provider_config_dict = {
                 "preferred_provider_type": preferred_provider_type,
                 "model_flexibility": model_provider_rule['model_flexibility'],
+                "supported_model_types": model_provider_rule.get("supported_model_types", []),
             }
 
             provider_parameter_dict = {}
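The behavior this patch adds to `ProviderService.get_provider_list` boils down to checking an optional `model_type` query argument against each provider rule's `supported_model_types` list. Below is a minimal standalone sketch of that filter; the `PROVIDER_RULES` dict and the `filter_providers` helper are hypothetical stand-ins for the rules loaded from `api/core/model_providers/rules/*.json` and are not part of the patched code.

```python
from typing import Optional

# Hypothetical stand-in for the rules loaded from api/core/model_providers/rules/*.json
PROVIDER_RULES = {
    "openai": {"supported_model_types": ["text-generation", "embeddings", "speech2text", "moderation"]},
    "cohere": {"supported_model_types": ["reranking"]},
    "anthropic": {"supported_model_types": ["text-generation"]},
}


def filter_providers(rules: dict, model_type: Optional[str] = None) -> list:
    """Keep only providers whose rule lists the requested model type.

    Mirrors the `if model_type and model_type not in ...: continue` check
    added in the patch; a rule without the key is treated as no match.
    """
    return [
        name
        for name, rule in rules.items()
        if not model_type or model_type in rule.get("supported_model_types", [])
    ]


print(filter_providers(PROVIDER_RULES, "embeddings"))  # ['openai']
print(filter_providers(PROVIDER_RULES))                # no filter: all providers
```

With no `model_type` supplied, the endpoint keeps its previous behavior and returns every provider, so existing callers are unaffected.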