Merge branch 'feat/mcp' into deploy/dev

This commit is contained in:
Novice 2025-05-30 09:23:05 +08:00
commit ec6c0e52aa
9 changed files with 130 additions and 83 deletions

View File

@ -1,5 +1,6 @@
import io import io
import validators
from flask import send_file from flask import send_file
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
@ -631,6 +632,8 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
args = parser.parse_args() args = parser.parse_args()
user = current_user user = current_user
if not validators.url(args["server_url"]):
raise ValueError("Server URL is not valid.")
return jsonable_encoder( return jsonable_encoder(
MCPToolManageService.create_mcp_provider( MCPToolManageService.create_mcp_provider(
tenant_id=user.current_tenant_id, tenant_id=user.current_tenant_id,
@ -655,6 +658,8 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
if not validators.url(args["server_url"]):
raise ValueError("Server URL is not valid.")
MCPToolManageService.update_mcp_provider( MCPToolManageService.update_mcp_provider(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
name=args["name"], name=args["name"],
@ -691,9 +696,10 @@ class ToolMCPAuthApi(Resource):
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id) provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if not provider: if not provider:
raise ValueError("provider not found") raise ValueError("provider not found")
server_url = MCPToolManageService.get_mcp_provider_server_url(tenant_id, provider_id)
try: try:
with MCPClient( with MCPClient(
provider.server_url, server_url,
provider_id, provider_id,
tenant_id, tenant_id,
authed=False, authed=False,
@ -702,14 +708,14 @@ class ToolMCPAuthApi(Resource):
MCPToolManageService.update_mcp_provider_credentials( MCPToolManageService.update_mcp_provider_credentials(
tenant_id=tenant_id, tenant_id=tenant_id,
provider_id=provider_id, provider_id=provider_id,
credentials={}, credentials=MCPToolManageService.get_mcp_provider_decrypted_credentials(tenant_id, provider_id),
authed=True, authed=True,
) )
return {"result": "success"} return {"result": "success"}
except MCPAuthError: except MCPAuthError:
auth_provider = OAuthClientProvider(provider_id, tenant_id) auth_provider = OAuthClientProvider(provider_id, tenant_id)
return auth(auth_provider, provider.server_url, args["authorization_code"]) return auth(auth_provider, server_url, args["authorization_code"])
class ToolMCPDetailApi(Resource): class ToolMCPDetailApi(Resource):
@ -761,6 +767,9 @@ class ToolMCPTokenApi(Resource):
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="args") parser.add_argument("provider_id", type=str, required=True, nullable=False, location="args")
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="args") parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="args")
args = parser.parse_args() args = parser.parse_args()
server_url = MCPToolManageService.get_mcp_provider_server_url(
current_user.current_tenant_id, args["provider_id"]
)
provider = MCPToolManageService.get_mcp_provider_by_provider_id( provider = MCPToolManageService.get_mcp_provider_by_provider_id(
args["provider_id"], current_user.current_tenant_id args["provider_id"], current_user.current_tenant_id
) )
@ -768,7 +777,7 @@ class ToolMCPTokenApi(Resource):
raise ValueError("provider not found") raise ValueError("provider not found")
return auth( return auth(
OAuthClientProvider(args["provider_id"], current_user.current_tenant_id), OAuthClientProvider(args["provider_id"], current_user.current_tenant_id),
provider.server_url, server_url,
authorization_code=args["authorization_code"], authorization_code=args["authorization_code"],
) )

View File

@ -42,7 +42,9 @@ class OAuthClientProvider:
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id) mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
if not mcp_provider: if not mcp_provider:
return None return None
client_information = mcp_provider.credentials.get("client_information", {}) client_information = MCPToolManageService.get_mcp_provider_decrypted_credentials(
self.tenant_id, self.provider_id
).get("client_information", {})
if not client_information: if not client_information:
return None return None
return OAuthClientInformation.model_validate(client_information) return OAuthClientInformation.model_validate(client_information)
@ -58,13 +60,13 @@ class OAuthClientProvider:
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id) mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
if not mcp_provider: if not mcp_provider:
return None return None
credentials = mcp_provider.credentials credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id)
if not credentials: if not credentials:
return None return None
return OAuthTokens( return OAuthTokens(
access_token=credentials.get("access_token", ""), access_token=credentials.get("access_token", ""),
token_type=credentials.get("token_type", "Bearer"), token_type=credentials.get("token_type", "Bearer"),
expires_in=credentials.get("expires_in", 3600), expires_in=int(credentials.get("expires_in", "3600")),
refresh_token=credentials.get("refresh_token", ""), refresh_token=credentials.get("refresh_token", ""),
) )
@ -87,4 +89,5 @@ class OAuthClientProvider:
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id) mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
if not mcp_provider: if not mcp_provider:
return "" return ""
return mcp_provider.credentials.get("code_verifier", "") credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id)
return credentials.get("code_verifier", "")

View File

@ -40,6 +40,8 @@ class MCPToolProviderController(ToolProviderController):
@classmethod @classmethod
def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController": def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController":
from services.tools.mcp_tools_mange_service import MCPToolManageService
""" """
from db provider from db provider
""" """
@ -84,7 +86,7 @@ class MCPToolProviderController(ToolProviderController):
), ),
provider_id=db_provider.id or "", provider_id=db_provider.id or "",
tenant_id=db_provider.tenant_id or "", tenant_id=db_provider.tenant_id or "",
server_url=db_provider.server_url, server_url=MCPToolManageService.get_masked_mcp_provider_server_url(db_provider.tenant_id, db_provider.id),
) )
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:

View File

@ -72,20 +72,21 @@ class ProviderConfigEncrypter(BaseModel):
return data return data
def decrypt(self, data: dict[str, str]) -> dict[str, str]: def decrypt(self, data: dict[str, str], use_cache: bool = True) -> dict[str, str]:
""" """
decrypt tool credentials with tenant id decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values return a deep copy of credentials with decrypted values
""" """
cache = ToolProviderCredentialsCache( if use_cache:
tenant_id=self.tenant_id, cache = ToolProviderCredentialsCache(
identity_id=f"{self.provider_type}.{self.provider_identity}", tenant_id=self.tenant_id,
cache_type=ToolProviderCredentialsCacheType.PROVIDER, identity_id=f"{self.provider_type}.{self.provider_identity}",
) cache_type=ToolProviderCredentialsCacheType.PROVIDER,
cached_credentials = cache.get() )
if cached_credentials: cached_credentials = cache.get()
return cached_credentials if cached_credentials:
return cached_credentials
data = self._deep_copy(data) data = self._deep_copy(data)
# get fields need to be decrypted # get fields need to be decrypted
fields = dict[str, BasicProviderConfig]() fields = dict[str, BasicProviderConfig]()
@ -104,7 +105,8 @@ class ProviderConfigEncrypter(BaseModel):
except Exception: except Exception:
pass pass
cache.set(data) if use_cache:
cache.set(data)
return data return data
def delete_tool_credentials_cache(self): def delete_tool_credentials_cache(self):

View File

@ -1,41 +0,0 @@
"""add app mcp server
Revision ID: ca4c4abcc347
Revises: 1e67f2654a08
Create Date: 2025-05-22 16:23:44.206102
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'ca4c4abcc347'
down_revision = '1e67f2654a08'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('app_mcp_servers',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('description', sa.String(length=255), nullable=False),
sa.Column('server_code', sa.String(length=255), nullable=False),
sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False),
sa.Column('parameters', sa.Text(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey')
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('app_mcp_servers')
# ### end Alembic commands ###

View File

@ -1,8 +1,8 @@
"""add mcp provider """add app mcp client and server
Revision ID: 1e67f2654a08 Revision ID: de71f8771550
Revises: 6a9f914f656c Revises: 2adcbe1f5dfb
Create Date: 2025-05-07 17:40:58.448440 Create Date: 2025-05-29 17:07:40.037945
""" """
from alembic import op from alembic import op
@ -11,7 +11,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '1e67f2654a08' revision = 'de71f8771550'
down_revision = 'b35c3db83d09' down_revision = 'b35c3db83d09'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@ -19,10 +19,24 @@ depends_on = None
def upgrade(): def upgrade():
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.create_table('app_mcp_servers',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('description', sa.String(length=255), nullable=False),
sa.Column('server_code', sa.String(length=255), nullable=False),
sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False),
sa.Column('parameters', sa.Text(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey')
)
op.create_table('tool_mcp_providers', op.create_table('tool_mcp_providers',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('name', sa.String(length=40), nullable=False), sa.Column('name', sa.String(length=40), nullable=False),
sa.Column('server_url', sa.String(length=255), nullable=False), sa.Column('server_url', sa.String(length=512), nullable=False),
sa.Column('server_url_hash', sa.String(length=64), nullable=False),
sa.Column('icon', sa.String(length=255), nullable=True), sa.Column('icon', sa.String(length=255), nullable=True),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False), sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('user_id', models.types.StringUUID(), nullable=False), sa.Column('user_id', models.types.StringUUID(), nullable=False),
@ -32,12 +46,16 @@ def upgrade():
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_mcp_provider_pkey'), sa.PrimaryKeyConstraint('id', name='tool_mcp_provider_pkey'),
sa.UniqueConstraint('name', 'tenant_id', name='unique_mcp_tool_provider') sa.UniqueConstraint('name', 'tenant_id', name='unique_mcp_tool_provider'),
sa.UniqueConstraint('server_url_hash', name='unique_mcp_tool_provider_server_url_hash')
) )
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade(): def downgrade():
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_table('tool_mcp_providers') op.drop_table('tool_mcp_providers')
op.drop_table('app_mcp_servers')
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@ -233,14 +233,16 @@ class MCPToolProvider(Base):
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"), db.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"),
db.UniqueConstraint("name", "tenant_id", name="unique_mcp_tool_provider"), db.UniqueConstraint("name", "tenant_id", name="unique_mcp_tool_provider"),
db.UniqueConstraint("server_url", name="unique_mcp_tool_provider_server_url"), db.UniqueConstraint("server_url_hash", name="unique_mcp_tool_provider_server_url_hash"),
) )
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# name of the mcp provider # name of the mcp provider
name: Mapped[str] = mapped_column(db.String(40), nullable=False) name: Mapped[str] = mapped_column(db.String(40), nullable=False)
# url of the mcp provider # encrypted url of the mcp provider
server_url: Mapped[str] = mapped_column(db.String(255), nullable=False) server_url: Mapped[str] = mapped_column(db.String(512), nullable=False)
# hash of server_url for uniqueness check
server_url_hash: Mapped[str] = mapped_column(db.String(64), nullable=False)
# icon of the mcp provider # icon of the mcp provider
icon: Mapped[str] = mapped_column(db.String(255), nullable=True) icon: Mapped[str] = mapped_column(db.String(255), nullable=True)
# tenant id # tenant id

View File

@ -1,17 +1,35 @@
import hashlib
import json import json
from urllib.parse import urlparse
from sqlalchemy import or_ from sqlalchemy import or_
from core.helper import encrypter
from core.mcp.error import MCPAuthError, MCPConnectionError from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.mcp_client import MCPClient from core.mcp.mcp_client import MCPClient
from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType from core.tools.entities.tool_entities import ToolProviderType
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import MCPToolProvider from models.tools import MCPToolProvider
from services.tools.tools_transform_service import ToolTransformService from services.tools.tools_transform_service import ToolTransformService
def mask_url(url: str, mask_char: str = "*"):
"""
mask the url to a simple string
"""
parsed = urlparse(url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
if parsed.path and parsed.path != "/":
return f"{base_url}/{mask_char * 6}"
else:
return base_url
class MCPToolManageService: class MCPToolManageService:
""" """
Service class for managing mcp tools. Service class for managing mcp tools.
@ -32,11 +50,15 @@ class MCPToolManageService:
def create_mcp_provider( def create_mcp_provider(
tenant_id: str, name: str, server_url: str, user_id: str, icon: str, icon_type: str, icon_background: str tenant_id: str, name: str, server_url: str, user_id: str, icon: str, icon_type: str, icon_background: str
) -> ToolProviderApiEntity: ) -> ToolProviderApiEntity:
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
existing_provider = ( existing_provider = (
db.session.query(MCPToolProvider) db.session.query(MCPToolProvider)
.filter( .filter(
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.tenant_id == tenant_id,
or_(MCPToolProvider.name == name, MCPToolProvider.server_url == server_url), or_(
MCPToolProvider.name == name,
MCPToolProvider.server_url_hash == server_url_hash,
),
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.tenant_id == tenant_id,
) )
.first() .first()
@ -46,11 +68,12 @@ class MCPToolManageService:
raise ValueError(f"MCP tool {name} already exists") raise ValueError(f"MCP tool {name} already exists")
else: else:
raise ValueError(f"MCP tool {server_url} already exists") raise ValueError(f"MCP tool {server_url} already exists")
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
mcp_tool = MCPToolProvider( mcp_tool = MCPToolProvider(
tenant_id=tenant_id, tenant_id=tenant_id,
name=name, name=name,
server_url=server_url, server_url=encrypted_server_url,
server_url_hash=server_url_hash,
user_id=user_id, user_id=user_id,
authed=False, authed=False,
tools="[]", tools="[]",
@ -68,10 +91,11 @@ class MCPToolManageService:
@classmethod @classmethod
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str): def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
server_url = cls.get_mcp_provider_server_url(tenant_id, provider_id)
if mcp_provider is None: if mcp_provider is None:
raise ValueError("MCP tool not found") raise ValueError("MCP tool not found")
try: try:
with MCPClient(mcp_provider.server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client: with MCPClient(server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client:
tools = mcp_client.list_tools() tools = mcp_client.list_tools()
except MCPAuthError as e: except MCPAuthError as e:
raise ValueError("Please auth the tool first") raise ValueError("Please auth the tool first")
@ -87,7 +111,7 @@ class MCPToolManageService:
type=ToolProviderType.MCP, type=ToolProviderType.MCP,
icon=mcp_provider.icon, icon=mcp_provider.icon,
author=mcp_provider.user.name if mcp_provider.user else "Anonymous", author=mcp_provider.user.name if mcp_provider.user else "Anonymous",
server_url=mcp_provider.server_url, server_url=cls.get_masked_mcp_provider_server_url(tenant_id, provider_id),
updated_at=int(mcp_provider.updated_at.timestamp()), updated_at=int(mcp_provider.updated_at.timestamp()),
description=I18nObject(en_US="", zh_Hans=""), description=I18nObject(en_US="", zh_Hans=""),
label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name), label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
@ -107,7 +131,6 @@ class MCPToolManageService:
raise ValueError("MCP tool not found") raise ValueError("MCP tool not found")
db.session.delete(mcp_tool) db.session.delete(mcp_tool)
db.session.commit() db.session.commit()
return {"result": "success"}
@classmethod @classmethod
def update_mcp_provider( def update_mcp_provider(
@ -123,27 +146,54 @@ class MCPToolManageService:
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_provider is None: if mcp_provider is None:
raise ValueError("MCP tool not found") raise ValueError("MCP tool not found")
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
mcp_provider.name = name mcp_provider.name = name
mcp_provider.server_url = server_url mcp_provider.server_url = encrypted_server_url
mcp_provider.server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
mcp_provider.icon = ( mcp_provider.icon = (
json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
) )
db.session.commit() db.session.commit()
return {"result": "success"}
@classmethod @classmethod
def update_mcp_provider_credentials(cls, tenant_id: str, provider_id: str, credentials: dict, authed: bool = False): def update_mcp_provider_credentials(cls, tenant_id: str, provider_id: str, credentials: dict, authed: bool = False):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_provider is None: if mcp_provider is None:
raise ValueError("MCP tool not found") raise ValueError("MCP tool not found")
provider_controller = MCPToolProviderController._from_db(mcp_provider)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.provider_id,
)
credentials = tool_configuration.encrypt(credentials)
mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials}) mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials})
mcp_provider.authed = authed mcp_provider.authed = authed
db.session.commit() db.session.commit()
return {"result": "success"}
@classmethod @classmethod
def get_mcp_token(cls, provider_id: str, tenant_id: str): def get_mcp_provider_decrypted_credentials(cls, tenant_id: str, provider_id: str):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_provider is None: if mcp_provider is None:
raise ValueError("MCP provider not found") raise ValueError("MCP tool not found")
return mcp_provider.credentials.get("access_token", None) provider_controller = MCPToolProviderController._from_db(mcp_provider)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.provider_id,
)
return tool_configuration.decrypt(mcp_provider.credentials, use_cache=False)
@classmethod
def get_mcp_provider_server_url(cls, tenant_id: str, provider_id: str):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_provider is None:
raise ValueError("MCP tool not found")
return encrypter.decrypt_token(tenant_id, mcp_provider.server_url)
@classmethod
def get_masked_mcp_provider_server_url(cls, tenant_id: str, provider_id: str):
server_url = cls.get_mcp_provider_server_url(tenant_id, provider_id)
return mask_url(server_url)

View File

@ -190,6 +190,8 @@ class ToolTransformService:
@staticmethod @staticmethod
def mcp_provider_to_user_provider(db_provider: MCPToolProvider) -> ToolProviderApiEntity: def mcp_provider_to_user_provider(db_provider: MCPToolProvider) -> ToolProviderApiEntity:
from services.tools.mcp_tools_mange_service import MCPToolManageService
return ToolProviderApiEntity( return ToolProviderApiEntity(
id=db_provider.id, id=db_provider.id,
author=db_provider.user.name if db_provider.user else "Anonymous", author=db_provider.user.name if db_provider.user else "Anonymous",
@ -197,7 +199,7 @@ class ToolTransformService:
icon=db_provider.provider_icon, icon=db_provider.provider_icon,
type=ToolProviderType.MCP, type=ToolProviderType.MCP,
is_team_authorization=db_provider.authed, is_team_authorization=db_provider.authed,
server_url=db_provider.server_url, server_url=MCPToolManageService.get_masked_mcp_provider_server_url(db_provider.tenant_id, db_provider.id),
tools=ToolTransformService.mcp_tool_to_user_tool( tools=ToolTransformService.mcp_tool_to_user_tool(
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)] db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
), ),