integrate model provider with plugin daemon

This commit is contained in:
takatost 2024-10-30 18:56:52 -07:00
parent 459cb9dd72
commit 18edeb8e0a
9 changed files with 58 additions and 32 deletions

View File

@ -113,10 +113,10 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
# Load Balancing Config
api.add_resource(
LoadBalancingCredentialsValidateApi,
"/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate",
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate",
)
api.add_resource(
LoadBalancingConfigCredentialsValidateApi,
"/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
)

View File

@ -1,6 +1,6 @@
import io
from flask import request, send_file
from flask import send_file
from flask_login import current_user
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
@ -126,11 +126,7 @@ class ModelProviderIconApi(Resource):
Get model provider icon
"""
def get(self, provider: str, icon_type: str, lang: str):
tenant_id = request.args.get("tenant_id")
if not tenant_id:
return {"content": "Invalid request."}, 400
def get(self, tenant_id: str, provider: str, icon_type: str, lang: str):
model_provider_service = ModelProviderService()
icon, mimetype = model_provider_service.get_model_provider_icon(
tenant_id=tenant_id,
@ -139,6 +135,9 @@ class ModelProviderIconApi(Resource):
lang=lang,
)
if not icon:
return {"message": "Icon not found"}, 404
return send_file(io.BytesIO(icon), mimetype=mimetype)
@ -193,11 +192,12 @@ api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<path:provider>")
api.add_resource(
ModelProviderIconApi, "/workspaces/current/model-providers/<path:provider>/<string:icon_type>/<string:lang>"
)
api.add_resource(
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"
)
api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<path:provider>/checkout-url")
api.add_resource(
ModelProviderIconApi,
"/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>",
)

View File

@ -320,7 +320,7 @@ class ModelProviderModelValidateApi(Resource):
response = {"result": "success" if result else "error"}
if not result:
response["error"] = error
response["error"] = error or ""
return response
@ -357,26 +357,26 @@ class ModelProviderAvailableModelApi(Resource):
return jsonable_encoder({"data": models})
api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<string:provider>/models")
api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models")
api.add_resource(
ModelProviderModelEnableApi,
"/workspaces/current/model-providers/<string:provider>/models/enable",
"/workspaces/current/model-providers/<path:provider>/models/enable",
endpoint="model-provider-model-enable",
)
api.add_resource(
ModelProviderModelDisableApi,
"/workspaces/current/model-providers/<string:provider>/models/disable",
"/workspaces/current/model-providers/<path:provider>/models/disable",
endpoint="model-provider-model-disable",
)
api.add_resource(
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<string:provider>/models/credentials"
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials"
)
api.add_resource(
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<string:provider>/models/credentials/validate"
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate"
)
api.add_resource(
ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<string:provider>/models/parameter-rules"
ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<path:provider>/models/parameter-rules"
)
api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
api.add_resource(DefaultModelApi, "/workspaces/current/default-model")

View File

@ -0,0 +1 @@
DEFAULT_PLUGIN_ID = "langgenius"

View File

@ -1,7 +1,7 @@
import decimal
from typing import Optional
from pydantic import ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.model_entities import (
AIModelEntity,
@ -15,7 +15,7 @@ from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.manager.model import PluginModelManager
class AIModel:
class AIModel(BaseModel):
"""
Base class for all models.
"""

View File

@ -5,6 +5,7 @@ from typing import Optional
from pydantic import BaseModel
from core.entities import DEFAULT_PLUGIN_ID
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
@ -132,7 +133,7 @@ class ModelProviderFactory:
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_model_provider_entity.plugin_id,
provider=provider,
provider=plugin_model_provider_entity.provider,
credentials=filtered_credentials,
)
@ -167,7 +168,7 @@ class ModelProviderFactory:
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_model_provider_entity.plugin_id,
provider=provider,
provider=plugin_model_provider_entity.provider,
model_type=model_type.value,
model=model,
credentials=filtered_credentials,
@ -337,7 +338,7 @@ class ModelProviderFactory:
:param provider: provider name
:return: plugin id and provider name
"""
plugin_id = "langgenius"
plugin_id = DEFAULT_PLUGIN_ID
provider_name = provider
if "/" in provider:
# get the plugin_id before provider

View File

@ -1,4 +1,5 @@
import json
import logging
from collections.abc import Callable, Generator
from typing import Optional, TypeVar
@ -21,6 +22,8 @@ plugin_daemon_inner_api_key = dify_config.PLUGIN_API_KEY
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
logger = logging.getLogger(__name__)
class BasePluginManager:
def _request(
@ -44,9 +47,14 @@ class BasePluginManager:
if headers.get("Content-Type") == "application/json" and isinstance(data, dict):
data = json.dumps(data)
response = requests.request(
method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files
)
try:
response = requests.request(
method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files
)
except requests.exceptions.ConnectionError as e:
logger.exception(f"Request to Plugin Daemon Service failed: {e}")
raise ValueError("Request to Plugin Daemon Service failed")
return response
def _stream_request(

View File

@ -50,6 +50,7 @@ class ProviderResponse(BaseModel):
Model class for provider response.
"""
tenant_id: str
provider: str
label: I18nObject
description: Optional[I18nObject] = None
@ -71,7 +72,9 @@ class ProviderResponse(BaseModel):
def __init__(self, **data) -> None:
super().__init__(**data)
url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
url_prefix = (
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
)
if self.icon_small is not None:
self.icon_small = I18nObject(
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
@ -88,6 +91,7 @@ class ProviderWithModelsResponse(BaseModel):
Model class for provider with models response.
"""
tenant_id: str
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
@ -98,7 +102,9 @@ class ProviderWithModelsResponse(BaseModel):
def __init__(self, **data) -> None:
super().__init__(**data)
url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
url_prefix = (
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
)
if self.icon_small is not None:
self.icon_small = I18nObject(
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
@ -115,10 +121,14 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
Simple provider entity response.
"""
tenant_id: str
def __init__(self, **data) -> None:
super().__init__(**data)
url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
url_prefix = (
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
)
if self.icon_small is not None:
self.icon_small = I18nObject(
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
@ -150,5 +160,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity):
provider: SimpleProviderEntityResponse
def __init__(self, model: ModelWithProviderEntity) -> None:
super().__init__(**model.model_dump())
def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None:
dump_model = model.model_dump()
dump_model["provider"]["tenant_id"] = tenant_id
super().__init__(**dump_model)

View File

@ -47,6 +47,7 @@ class ModelProviderService:
continue
provider_response = ProviderResponse(
tenant_id=tenant_id,
provider=provider_configuration.provider.provider,
label=provider_configuration.provider.label,
description=provider_configuration.provider.description,
@ -90,7 +91,8 @@ class ModelProviderService:
# Get provider available models
return [
ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider)
ModelWithProviderEntityResponse(tenant_id=tenant_id, model=model)
for model in provider_configurations.get_models(provider=provider)
]
def get_provider_credentials(self, tenant_id: str, provider: str) -> Optional[dict]:
@ -303,6 +305,7 @@ class ModelProviderService:
providers_with_models.append(
ProviderWithModelsResponse(
tenant_id=tenant_id,
provider=provider,
label=first_model.provider.label,
icon_small=first_model.provider.icon_small,
@ -373,6 +376,7 @@ class ModelProviderService:
model=result.model,
model_type=result.model_type,
provider=SimpleProviderEntityResponse(
tenant_id=tenant_id,
provider=result.provider.provider,
label=result.provider.label,
icon_small=result.provider.icon_small,