mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-15 09:25:58 +08:00
Merge branch 'feat/mcp' into deploy/dev
This commit is contained in:
commit
ec6c0e52aa
@ -1,5 +1,6 @@
|
|||||||
import io
|
import io
|
||||||
|
|
||||||
|
import validators
|
||||||
from flask import send_file
|
from flask import send_file
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restful import Resource, reqparse
|
from flask_restful import Resource, reqparse
|
||||||
@ -631,6 +632,8 @@ class ToolProviderMCPApi(Resource):
|
|||||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
|
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
user = current_user
|
user = current_user
|
||||||
|
if not validators.url(args["server_url"]):
|
||||||
|
raise ValueError("Server URL is not valid.")
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
MCPToolManageService.create_mcp_provider(
|
MCPToolManageService.create_mcp_provider(
|
||||||
tenant_id=user.current_tenant_id,
|
tenant_id=user.current_tenant_id,
|
||||||
@ -655,6 +658,8 @@ class ToolProviderMCPApi(Resource):
|
|||||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
if not validators.url(args["server_url"]):
|
||||||
|
raise ValueError("Server URL is not valid.")
|
||||||
MCPToolManageService.update_mcp_provider(
|
MCPToolManageService.update_mcp_provider(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
@ -691,9 +696,10 @@ class ToolMCPAuthApi(Resource):
|
|||||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||||
if not provider:
|
if not provider:
|
||||||
raise ValueError("provider not found")
|
raise ValueError("provider not found")
|
||||||
|
server_url = MCPToolManageService.get_mcp_provider_server_url(tenant_id, provider_id)
|
||||||
try:
|
try:
|
||||||
with MCPClient(
|
with MCPClient(
|
||||||
provider.server_url,
|
server_url,
|
||||||
provider_id,
|
provider_id,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
authed=False,
|
authed=False,
|
||||||
@ -702,14 +708,14 @@ class ToolMCPAuthApi(Resource):
|
|||||||
MCPToolManageService.update_mcp_provider_credentials(
|
MCPToolManageService.update_mcp_provider_credentials(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
credentials={},
|
credentials=MCPToolManageService.get_mcp_provider_decrypted_credentials(tenant_id, provider_id),
|
||||||
authed=True,
|
authed=True,
|
||||||
)
|
)
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
except MCPAuthError:
|
except MCPAuthError:
|
||||||
auth_provider = OAuthClientProvider(provider_id, tenant_id)
|
auth_provider = OAuthClientProvider(provider_id, tenant_id)
|
||||||
return auth(auth_provider, provider.server_url, args["authorization_code"])
|
return auth(auth_provider, server_url, args["authorization_code"])
|
||||||
|
|
||||||
|
|
||||||
class ToolMCPDetailApi(Resource):
|
class ToolMCPDetailApi(Resource):
|
||||||
@ -761,6 +767,9 @@ class ToolMCPTokenApi(Resource):
|
|||||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="args")
|
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="args")
|
||||||
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="args")
|
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
server_url = MCPToolManageService.get_mcp_provider_server_url(
|
||||||
|
current_user.current_tenant_id, args["provider_id"]
|
||||||
|
)
|
||||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(
|
provider = MCPToolManageService.get_mcp_provider_by_provider_id(
|
||||||
args["provider_id"], current_user.current_tenant_id
|
args["provider_id"], current_user.current_tenant_id
|
||||||
)
|
)
|
||||||
@ -768,7 +777,7 @@ class ToolMCPTokenApi(Resource):
|
|||||||
raise ValueError("provider not found")
|
raise ValueError("provider not found")
|
||||||
return auth(
|
return auth(
|
||||||
OAuthClientProvider(args["provider_id"], current_user.current_tenant_id),
|
OAuthClientProvider(args["provider_id"], current_user.current_tenant_id),
|
||||||
provider.server_url,
|
server_url,
|
||||||
authorization_code=args["authorization_code"],
|
authorization_code=args["authorization_code"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -42,7 +42,9 @@ class OAuthClientProvider:
|
|||||||
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
|
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
|
||||||
if not mcp_provider:
|
if not mcp_provider:
|
||||||
return None
|
return None
|
||||||
client_information = mcp_provider.credentials.get("client_information", {})
|
client_information = MCPToolManageService.get_mcp_provider_decrypted_credentials(
|
||||||
|
self.tenant_id, self.provider_id
|
||||||
|
).get("client_information", {})
|
||||||
if not client_information:
|
if not client_information:
|
||||||
return None
|
return None
|
||||||
return OAuthClientInformation.model_validate(client_information)
|
return OAuthClientInformation.model_validate(client_information)
|
||||||
@ -58,13 +60,13 @@ class OAuthClientProvider:
|
|||||||
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
|
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
|
||||||
if not mcp_provider:
|
if not mcp_provider:
|
||||||
return None
|
return None
|
||||||
credentials = mcp_provider.credentials
|
credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id)
|
||||||
if not credentials:
|
if not credentials:
|
||||||
return None
|
return None
|
||||||
return OAuthTokens(
|
return OAuthTokens(
|
||||||
access_token=credentials.get("access_token", ""),
|
access_token=credentials.get("access_token", ""),
|
||||||
token_type=credentials.get("token_type", "Bearer"),
|
token_type=credentials.get("token_type", "Bearer"),
|
||||||
expires_in=credentials.get("expires_in", 3600),
|
expires_in=int(credentials.get("expires_in", "3600")),
|
||||||
refresh_token=credentials.get("refresh_token", ""),
|
refresh_token=credentials.get("refresh_token", ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -87,4 +89,5 @@ class OAuthClientProvider:
|
|||||||
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
|
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
|
||||||
if not mcp_provider:
|
if not mcp_provider:
|
||||||
return ""
|
return ""
|
||||||
return mcp_provider.credentials.get("code_verifier", "")
|
credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id)
|
||||||
|
return credentials.get("code_verifier", "")
|
||||||
|
@ -40,6 +40,8 @@ class MCPToolProviderController(ToolProviderController):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController":
|
def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController":
|
||||||
|
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from db provider
|
from db provider
|
||||||
"""
|
"""
|
||||||
@ -84,7 +86,7 @@ class MCPToolProviderController(ToolProviderController):
|
|||||||
),
|
),
|
||||||
provider_id=db_provider.id or "",
|
provider_id=db_provider.id or "",
|
||||||
tenant_id=db_provider.tenant_id or "",
|
tenant_id=db_provider.tenant_id or "",
|
||||||
server_url=db_provider.server_url,
|
server_url=MCPToolManageService.get_masked_mcp_provider_server_url(db_provider.tenant_id, db_provider.id),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||||
|
@ -72,12 +72,13 @@ class ProviderConfigEncrypter(BaseModel):
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def decrypt(self, data: dict[str, str]) -> dict[str, str]:
|
def decrypt(self, data: dict[str, str], use_cache: bool = True) -> dict[str, str]:
|
||||||
"""
|
"""
|
||||||
decrypt tool credentials with tenant id
|
decrypt tool credentials with tenant id
|
||||||
|
|
||||||
return a deep copy of credentials with decrypted values
|
return a deep copy of credentials with decrypted values
|
||||||
"""
|
"""
|
||||||
|
if use_cache:
|
||||||
cache = ToolProviderCredentialsCache(
|
cache = ToolProviderCredentialsCache(
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||||
@ -104,6 +105,7 @@ class ProviderConfigEncrypter(BaseModel):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
cache.set(data)
|
cache.set(data)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -1,41 +0,0 @@
|
|||||||
"""add app mcp server
|
|
||||||
|
|
||||||
Revision ID: ca4c4abcc347
|
|
||||||
Revises: 1e67f2654a08
|
|
||||||
Create Date: 2025-05-22 16:23:44.206102
|
|
||||||
|
|
||||||
"""
|
|
||||||
from alembic import op
|
|
||||||
import models as models
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision = 'ca4c4abcc347'
|
|
||||||
down_revision = '1e67f2654a08'
|
|
||||||
branch_labels = None
|
|
||||||
depends_on = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.create_table('app_mcp_servers',
|
|
||||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
|
||||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
|
||||||
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
|
||||||
sa.Column('name', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('description', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('server_code', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False),
|
|
||||||
sa.Column('parameters', sa.Text(), nullable=False),
|
|
||||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
|
||||||
sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey')
|
|
||||||
)
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.drop_table('app_mcp_servers')
|
|
||||||
# ### end Alembic commands ###
|
|
@ -1,8 +1,8 @@
|
|||||||
"""add mcp provider
|
"""add app mcp client and server
|
||||||
|
|
||||||
Revision ID: 1e67f2654a08
|
Revision ID: de71f8771550
|
||||||
Revises: 6a9f914f656c
|
Revises: 2adcbe1f5dfb
|
||||||
Create Date: 2025-05-07 17:40:58.448440
|
Create Date: 2025-05-29 17:07:40.037945
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from alembic import op
|
from alembic import op
|
||||||
@ -11,7 +11,7 @@ import sqlalchemy as sa
|
|||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision = '1e67f2654a08'
|
revision = 'de71f8771550'
|
||||||
down_revision = 'b35c3db83d09'
|
down_revision = 'b35c3db83d09'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
@ -19,10 +19,24 @@ depends_on = None
|
|||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('app_mcp_servers',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||||
|
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('name', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('description', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('server_code', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False),
|
||||||
|
sa.Column('parameters', sa.Text(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey')
|
||||||
|
)
|
||||||
op.create_table('tool_mcp_providers',
|
op.create_table('tool_mcp_providers',
|
||||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||||
sa.Column('name', sa.String(length=40), nullable=False),
|
sa.Column('name', sa.String(length=40), nullable=False),
|
||||||
sa.Column('server_url', sa.String(length=255), nullable=False),
|
sa.Column('server_url', sa.String(length=512), nullable=False),
|
||||||
|
sa.Column('server_url_hash', sa.String(length=64), nullable=False),
|
||||||
sa.Column('icon', sa.String(length=255), nullable=True),
|
sa.Column('icon', sa.String(length=255), nullable=True),
|
||||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||||
sa.Column('user_id', models.types.StringUUID(), nullable=False),
|
sa.Column('user_id', models.types.StringUUID(), nullable=False),
|
||||||
@ -32,12 +46,16 @@ def upgrade():
|
|||||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||||
sa.PrimaryKeyConstraint('id', name='tool_mcp_provider_pkey'),
|
sa.PrimaryKeyConstraint('id', name='tool_mcp_provider_pkey'),
|
||||||
sa.UniqueConstraint('name', 'tenant_id', name='unique_mcp_tool_provider')
|
sa.UniqueConstraint('name', 'tenant_id', name='unique_mcp_tool_provider'),
|
||||||
|
sa.UniqueConstraint('server_url_hash', name='unique_mcp_tool_provider_server_url_hash')
|
||||||
)
|
)
|
||||||
|
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.drop_table('tool_mcp_providers')
|
op.drop_table('tool_mcp_providers')
|
||||||
|
op.drop_table('app_mcp_servers')
|
||||||
|
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
@ -233,14 +233,16 @@ class MCPToolProvider(Base):
|
|||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"),
|
db.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"),
|
||||||
db.UniqueConstraint("name", "tenant_id", name="unique_mcp_tool_provider"),
|
db.UniqueConstraint("name", "tenant_id", name="unique_mcp_tool_provider"),
|
||||||
db.UniqueConstraint("server_url", name="unique_mcp_tool_provider_server_url"),
|
db.UniqueConstraint("server_url_hash", name="unique_mcp_tool_provider_server_url_hash"),
|
||||||
)
|
)
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||||
# name of the mcp provider
|
# name of the mcp provider
|
||||||
name: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
name: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
||||||
# url of the mcp provider
|
# encrypted url of the mcp provider
|
||||||
server_url: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
server_url: Mapped[str] = mapped_column(db.String(512), nullable=False)
|
||||||
|
# hash of server_url for uniqueness check
|
||||||
|
server_url_hash: Mapped[str] = mapped_column(db.String(64), nullable=False)
|
||||||
# icon of the mcp provider
|
# icon of the mcp provider
|
||||||
icon: Mapped[str] = mapped_column(db.String(255), nullable=True)
|
icon: Mapped[str] = mapped_column(db.String(255), nullable=True)
|
||||||
# tenant id
|
# tenant id
|
||||||
|
@ -1,17 +1,35 @@
|
|||||||
|
import hashlib
|
||||||
import json
|
import json
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from sqlalchemy import or_
|
from sqlalchemy import or_
|
||||||
|
|
||||||
|
from core.helper import encrypter
|
||||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||||
from core.mcp.mcp_client import MCPClient
|
from core.mcp.mcp_client import MCPClient
|
||||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_entities import ToolProviderType
|
from core.tools.entities.tool_entities import ToolProviderType
|
||||||
|
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||||
|
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.tools import MCPToolProvider
|
from models.tools import MCPToolProvider
|
||||||
from services.tools.tools_transform_service import ToolTransformService
|
from services.tools.tools_transform_service import ToolTransformService
|
||||||
|
|
||||||
|
|
||||||
|
def mask_url(url: str, mask_char: str = "*"):
|
||||||
|
"""
|
||||||
|
mask the url to a simple string
|
||||||
|
"""
|
||||||
|
parsed = urlparse(url)
|
||||||
|
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||||
|
|
||||||
|
if parsed.path and parsed.path != "/":
|
||||||
|
return f"{base_url}/{mask_char * 6}"
|
||||||
|
else:
|
||||||
|
return base_url
|
||||||
|
|
||||||
|
|
||||||
class MCPToolManageService:
|
class MCPToolManageService:
|
||||||
"""
|
"""
|
||||||
Service class for managing mcp tools.
|
Service class for managing mcp tools.
|
||||||
@ -32,11 +50,15 @@ class MCPToolManageService:
|
|||||||
def create_mcp_provider(
|
def create_mcp_provider(
|
||||||
tenant_id: str, name: str, server_url: str, user_id: str, icon: str, icon_type: str, icon_background: str
|
tenant_id: str, name: str, server_url: str, user_id: str, icon: str, icon_type: str, icon_background: str
|
||||||
) -> ToolProviderApiEntity:
|
) -> ToolProviderApiEntity:
|
||||||
|
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||||
existing_provider = (
|
existing_provider = (
|
||||||
db.session.query(MCPToolProvider)
|
db.session.query(MCPToolProvider)
|
||||||
.filter(
|
.filter(
|
||||||
MCPToolProvider.tenant_id == tenant_id,
|
MCPToolProvider.tenant_id == tenant_id,
|
||||||
or_(MCPToolProvider.name == name, MCPToolProvider.server_url == server_url),
|
or_(
|
||||||
|
MCPToolProvider.name == name,
|
||||||
|
MCPToolProvider.server_url_hash == server_url_hash,
|
||||||
|
),
|
||||||
MCPToolProvider.tenant_id == tenant_id,
|
MCPToolProvider.tenant_id == tenant_id,
|
||||||
)
|
)
|
||||||
.first()
|
.first()
|
||||||
@ -46,11 +68,12 @@ class MCPToolManageService:
|
|||||||
raise ValueError(f"MCP tool {name} already exists")
|
raise ValueError(f"MCP tool {name} already exists")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"MCP tool {server_url} already exists")
|
raise ValueError(f"MCP tool {server_url} already exists")
|
||||||
|
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||||
mcp_tool = MCPToolProvider(
|
mcp_tool = MCPToolProvider(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
name=name,
|
name=name,
|
||||||
server_url=server_url,
|
server_url=encrypted_server_url,
|
||||||
|
server_url_hash=server_url_hash,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
authed=False,
|
authed=False,
|
||||||
tools="[]",
|
tools="[]",
|
||||||
@ -68,10 +91,11 @@ class MCPToolManageService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str):
|
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str):
|
||||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||||
|
server_url = cls.get_mcp_provider_server_url(tenant_id, provider_id)
|
||||||
if mcp_provider is None:
|
if mcp_provider is None:
|
||||||
raise ValueError("MCP tool not found")
|
raise ValueError("MCP tool not found")
|
||||||
try:
|
try:
|
||||||
with MCPClient(mcp_provider.server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client:
|
with MCPClient(server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client:
|
||||||
tools = mcp_client.list_tools()
|
tools = mcp_client.list_tools()
|
||||||
except MCPAuthError as e:
|
except MCPAuthError as e:
|
||||||
raise ValueError("Please auth the tool first")
|
raise ValueError("Please auth the tool first")
|
||||||
@ -87,7 +111,7 @@ class MCPToolManageService:
|
|||||||
type=ToolProviderType.MCP,
|
type=ToolProviderType.MCP,
|
||||||
icon=mcp_provider.icon,
|
icon=mcp_provider.icon,
|
||||||
author=mcp_provider.user.name if mcp_provider.user else "Anonymous",
|
author=mcp_provider.user.name if mcp_provider.user else "Anonymous",
|
||||||
server_url=mcp_provider.server_url,
|
server_url=cls.get_masked_mcp_provider_server_url(tenant_id, provider_id),
|
||||||
updated_at=int(mcp_provider.updated_at.timestamp()),
|
updated_at=int(mcp_provider.updated_at.timestamp()),
|
||||||
description=I18nObject(en_US="", zh_Hans=""),
|
description=I18nObject(en_US="", zh_Hans=""),
|
||||||
label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
|
label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
|
||||||
@ -107,7 +131,6 @@ class MCPToolManageService:
|
|||||||
raise ValueError("MCP tool not found")
|
raise ValueError("MCP tool not found")
|
||||||
db.session.delete(mcp_tool)
|
db.session.delete(mcp_tool)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return {"result": "success"}
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def update_mcp_provider(
|
def update_mcp_provider(
|
||||||
@ -123,27 +146,54 @@ class MCPToolManageService:
|
|||||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||||
if mcp_provider is None:
|
if mcp_provider is None:
|
||||||
raise ValueError("MCP tool not found")
|
raise ValueError("MCP tool not found")
|
||||||
|
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||||
mcp_provider.name = name
|
mcp_provider.name = name
|
||||||
mcp_provider.server_url = server_url
|
mcp_provider.server_url = encrypted_server_url
|
||||||
|
mcp_provider.server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||||
mcp_provider.icon = (
|
mcp_provider.icon = (
|
||||||
json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
|
json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
|
||||||
)
|
)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return {"result": "success"}
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def update_mcp_provider_credentials(cls, tenant_id: str, provider_id: str, credentials: dict, authed: bool = False):
|
def update_mcp_provider_credentials(cls, tenant_id: str, provider_id: str, credentials: dict, authed: bool = False):
|
||||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||||
if mcp_provider is None:
|
if mcp_provider is None:
|
||||||
raise ValueError("MCP tool not found")
|
raise ValueError("MCP tool not found")
|
||||||
|
provider_controller = MCPToolProviderController._from_db(mcp_provider)
|
||||||
|
tool_configuration = ProviderConfigEncrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=list(provider_controller.get_credentials_schema()),
|
||||||
|
provider_type=provider_controller.provider_type.value,
|
||||||
|
provider_identity=provider_controller.provider_id,
|
||||||
|
)
|
||||||
|
credentials = tool_configuration.encrypt(credentials)
|
||||||
mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials})
|
mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials})
|
||||||
mcp_provider.authed = authed
|
mcp_provider.authed = authed
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return {"result": "success"}
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_mcp_token(cls, provider_id: str, tenant_id: str):
|
def get_mcp_provider_decrypted_credentials(cls, tenant_id: str, provider_id: str):
|
||||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||||
if mcp_provider is None:
|
if mcp_provider is None:
|
||||||
raise ValueError("MCP provider not found")
|
raise ValueError("MCP tool not found")
|
||||||
return mcp_provider.credentials.get("access_token", None)
|
provider_controller = MCPToolProviderController._from_db(mcp_provider)
|
||||||
|
tool_configuration = ProviderConfigEncrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=list(provider_controller.get_credentials_schema()),
|
||||||
|
provider_type=provider_controller.provider_type.value,
|
||||||
|
provider_identity=provider_controller.provider_id,
|
||||||
|
)
|
||||||
|
return tool_configuration.decrypt(mcp_provider.credentials, use_cache=False)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_mcp_provider_server_url(cls, tenant_id: str, provider_id: str):
|
||||||
|
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||||
|
if mcp_provider is None:
|
||||||
|
raise ValueError("MCP tool not found")
|
||||||
|
return encrypter.decrypt_token(tenant_id, mcp_provider.server_url)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_masked_mcp_provider_server_url(cls, tenant_id: str, provider_id: str):
|
||||||
|
server_url = cls.get_mcp_provider_server_url(tenant_id, provider_id)
|
||||||
|
return mask_url(server_url)
|
||||||
|
@ -190,6 +190,8 @@ class ToolTransformService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def mcp_provider_to_user_provider(db_provider: MCPToolProvider) -> ToolProviderApiEntity:
|
def mcp_provider_to_user_provider(db_provider: MCPToolProvider) -> ToolProviderApiEntity:
|
||||||
|
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
||||||
|
|
||||||
return ToolProviderApiEntity(
|
return ToolProviderApiEntity(
|
||||||
id=db_provider.id,
|
id=db_provider.id,
|
||||||
author=db_provider.user.name if db_provider.user else "Anonymous",
|
author=db_provider.user.name if db_provider.user else "Anonymous",
|
||||||
@ -197,7 +199,7 @@ class ToolTransformService:
|
|||||||
icon=db_provider.provider_icon,
|
icon=db_provider.provider_icon,
|
||||||
type=ToolProviderType.MCP,
|
type=ToolProviderType.MCP,
|
||||||
is_team_authorization=db_provider.authed,
|
is_team_authorization=db_provider.authed,
|
||||||
server_url=db_provider.server_url,
|
server_url=MCPToolManageService.get_masked_mcp_provider_server_url(db_provider.tenant_id, db_provider.id),
|
||||||
tools=ToolTransformService.mcp_tool_to_user_tool(
|
tools=ToolTransformService.mcp_tool_to_user_tool(
|
||||||
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
|
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
|
||||||
),
|
),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user