From 18edeb8e0a65003c8aa90bfe1cb9971d245a90db Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 30 Oct 2024 18:56:52 -0700 Subject: [PATCH] integrate model provider with plugin daemon --- .../workspace/load_balancing_config.py | 4 ++-- .../console/workspace/model_providers.py | 18 +++++++-------- api/controllers/console/workspace/models.py | 14 ++++++------ api/core/entities/__init__.py | 1 + .../model_providers/__base/ai_model.py | 4 ++-- .../model_providers/model_provider_factory.py | 7 +++--- api/core/plugin/manager/base.py | 14 +++++++++--- .../entities/model_provider_entities.py | 22 ++++++++++++++----- api/services/model_provider_service.py | 6 ++++- 9 files changed, 58 insertions(+), 32 deletions(-) diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 771a866624..9d2697f11d 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -113,10 +113,10 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): # Load Balancing Config api.add_resource( LoadBalancingCredentialsValidateApi, - "/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate", + "/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate", ) api.add_resource( LoadBalancingConfigCredentialsValidateApi, - "/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate", + "/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate", ) diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index c50f507d4e..b9f13e3ce4 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -1,6 +1,6 @@ import io -from flask import request, send_file +from flask import send_file from flask_login import current_user from flask_restful import Resource, reqparse from werkzeug.exceptions import Forbidden @@ -126,11 +126,7 @@ class ModelProviderIconApi(Resource): Get model provider icon """ - def get(self, provider: str, icon_type: str, lang: str): - tenant_id = request.args.get("tenant_id") - if not tenant_id: - return {"content": "Invalid request."}, 400 - + def get(self, tenant_id: str, provider: str, icon_type: str, lang: str): model_provider_service = ModelProviderService() icon, mimetype = model_provider_service.get_model_provider_icon( tenant_id=tenant_id, @@ -139,6 +135,9 @@ class ModelProviderIconApi(Resource): lang=lang, ) + if not icon: + return {"message": "Icon not found"}, 404 + return send_file(io.BytesIO(icon), mimetype=mimetype) @@ -193,11 +192,12 @@ api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers//credentials") api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers//credentials/validate") api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/") -api.add_resource( - ModelProviderIconApi, "/workspaces/current/model-providers///" -) api.add_resource( PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers//preferred-provider-type" ) api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers//checkout-url") +api.add_resource( + ModelProviderIconApi, + "/workspaces//model-providers///", +) diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 3138a260b3..7bbedc8828 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -320,7 +320,7 @@ class ModelProviderModelValidateApi(Resource): response = {"result": "success" if result else "error"} if not result: - response["error"] = error + response["error"] = error or "" return response @@ -357,26 +357,26 @@ class ModelProviderAvailableModelApi(Resource): return jsonable_encoder({"data": models}) -api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers//models") +api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers//models") api.add_resource( ModelProviderModelEnableApi, - "/workspaces/current/model-providers//models/enable", + "/workspaces/current/model-providers//models/enable", endpoint="model-provider-model-enable", ) api.add_resource( ModelProviderModelDisableApi, - "/workspaces/current/model-providers//models/disable", + "/workspaces/current/model-providers//models/disable", endpoint="model-provider-model-disable", ) api.add_resource( - ModelProviderModelCredentialApi, "/workspaces/current/model-providers//models/credentials" + ModelProviderModelCredentialApi, "/workspaces/current/model-providers//models/credentials" ) api.add_resource( - ModelProviderModelValidateApi, "/workspaces/current/model-providers//models/credentials/validate" + ModelProviderModelValidateApi, "/workspaces/current/model-providers//models/credentials/validate" ) api.add_resource( - ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers//models/parameter-rules" + ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers//models/parameter-rules" ) api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/") api.add_resource(DefaultModelApi, "/workspaces/current/default-model") diff --git a/api/core/entities/__init__.py b/api/core/entities/__init__.py index e69de29bb2..b848da3664 100644 --- a/api/core/entities/__init__.py +++ b/api/core/entities/__init__.py @@ -0,0 +1 @@ +DEFAULT_PLUGIN_ID = "langgenius" diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index a28d69ce80..bdbafc8ded 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -1,7 +1,7 @@ import decimal from typing import Optional -from pydantic import ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field from core.model_runtime.entities.model_entities import ( AIModelEntity, @@ -15,7 +15,7 @@ from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.manager.model import PluginModelManager -class AIModel: +class AIModel(BaseModel): """ Base class for all models. """ diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 85a79dc0ce..c79c3a2b62 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -5,6 +5,7 @@ from typing import Optional from pydantic import BaseModel +from core.entities import DEFAULT_PLUGIN_ID from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity @@ -132,7 +133,7 @@ class ModelProviderFactory: tenant_id=self.tenant_id, user_id="unknown", plugin_id=plugin_model_provider_entity.plugin_id, - provider=provider, + provider=plugin_model_provider_entity.provider, credentials=filtered_credentials, ) @@ -167,7 +168,7 @@ class ModelProviderFactory: tenant_id=self.tenant_id, user_id="unknown", plugin_id=plugin_model_provider_entity.plugin_id, - provider=provider, + provider=plugin_model_provider_entity.provider, model_type=model_type.value, model=model, credentials=filtered_credentials, @@ -337,7 +338,7 @@ class ModelProviderFactory: :param provider: provider name :return: plugin id and provider name """ - plugin_id = "langgenius" + plugin_id = DEFAULT_PLUGIN_ID provider_name = provider if "/" in provider: # get the plugin_id before provider diff --git a/api/core/plugin/manager/base.py b/api/core/plugin/manager/base.py index 9980f7c15d..b25282cde2 100644 --- a/api/core/plugin/manager/base.py +++ b/api/core/plugin/manager/base.py @@ -1,4 +1,5 @@ import json +import logging from collections.abc import Callable, Generator from typing import Optional, TypeVar @@ -21,6 +22,8 @@ plugin_daemon_inner_api_key = dify_config.PLUGIN_API_KEY T = TypeVar("T", bound=(BaseModel | dict | list | bool | str)) +logger = logging.getLogger(__name__) + class BasePluginManager: def _request( @@ -44,9 +47,14 @@ class BasePluginManager: if headers.get("Content-Type") == "application/json" and isinstance(data, dict): data = json.dumps(data) - response = requests.request( - method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files - ) + try: + response = requests.request( + method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files + ) + except requests.exceptions.ConnectionError as e: + logger.exception(f"Request to Plugin Daemon Service failed: {e}") + raise ValueError("Request to Plugin Daemon Service failed") + return response def _stream_request( diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 4eed26efdf..7d0d442776 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -50,6 +50,7 @@ class ProviderResponse(BaseModel): Model class for provider response. """ + tenant_id: str provider: str label: I18nObject description: Optional[I18nObject] = None @@ -71,7 +72,9 @@ class ProviderResponse(BaseModel): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" + url_prefix = ( + dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}" + ) if self.icon_small is not None: self.icon_small = I18nObject( en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" @@ -88,6 +91,7 @@ class ProviderWithModelsResponse(BaseModel): Model class for provider with models response. """ + tenant_id: str provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -98,7 +102,9 @@ class ProviderWithModelsResponse(BaseModel): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" + url_prefix = ( + dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}" + ) if self.icon_small is not None: self.icon_small = I18nObject( en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" @@ -115,10 +121,14 @@ class SimpleProviderEntityResponse(SimpleProviderEntity): Simple provider entity response. """ + tenant_id: str + def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" + url_prefix = ( + dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}" + ) if self.icon_small is not None: self.icon_small = I18nObject( en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" @@ -150,5 +160,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity): provider: SimpleProviderEntityResponse - def __init__(self, model: ModelWithProviderEntity) -> None: - super().__init__(**model.model_dump()) + def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None: + dump_model = model.model_dump() + dump_model["provider"]["tenant_id"] = tenant_id + super().__init__(**dump_model) diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 0375974041..589af9d87e 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -47,6 +47,7 @@ class ModelProviderService: continue provider_response = ProviderResponse( + tenant_id=tenant_id, provider=provider_configuration.provider.provider, label=provider_configuration.provider.label, description=provider_configuration.provider.description, @@ -90,7 +91,8 @@ class ModelProviderService: # Get provider available models return [ - ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider) + ModelWithProviderEntityResponse(tenant_id=tenant_id, model=model) + for model in provider_configurations.get_models(provider=provider) ] def get_provider_credentials(self, tenant_id: str, provider: str) -> Optional[dict]: @@ -303,6 +305,7 @@ class ModelProviderService: providers_with_models.append( ProviderWithModelsResponse( + tenant_id=tenant_id, provider=provider, label=first_model.provider.label, icon_small=first_model.provider.icon_small, @@ -373,6 +376,7 @@ class ModelProviderService: model=result.model, model_type=result.model_type, provider=SimpleProviderEntityResponse( + tenant_id=tenant_id, provider=result.provider.provider, label=result.provider.label, icon_small=result.provider.icon_small,