diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index ad035e80e3..0869a29add 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,5 +1,6 @@ import io +import validators from flask import send_file from flask_login import current_user from flask_restful import Resource, reqparse @@ -631,6 +632,8 @@ class ToolProviderMCPApi(Resource): parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") args = parser.parse_args() user = current_user + if not validators.url(args["server_url"]): + raise ValueError("Server URL is not valid.") return jsonable_encoder( MCPToolManageService.create_mcp_provider( tenant_id=user.current_tenant_id, @@ -655,6 +658,8 @@ class ToolProviderMCPApi(Resource): parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() + if not validators.url(args["server_url"]): + raise ValueError("Server URL is not valid.") MCPToolManageService.update_mcp_provider( tenant_id=current_user.current_tenant_id, name=args["name"], @@ -691,9 +696,10 @@ class ToolMCPAuthApi(Resource): provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id) if not provider: raise ValueError("provider not found") + server_url = MCPToolManageService.get_mcp_provider_server_url(tenant_id, provider_id) try: with MCPClient( - provider.server_url, + server_url, provider_id, tenant_id, authed=False, @@ -702,14 +708,14 @@ class ToolMCPAuthApi(Resource): MCPToolManageService.update_mcp_provider_credentials( tenant_id=tenant_id, provider_id=provider_id, - credentials={}, + credentials=MCPToolManageService.get_mcp_provider_decrypted_credentials(tenant_id, provider_id), authed=True, ) return {"result": "success"} except MCPAuthError: auth_provider = OAuthClientProvider(provider_id, tenant_id) - return auth(auth_provider, provider.server_url, args["authorization_code"]) + return auth(auth_provider, server_url, args["authorization_code"]) class ToolMCPDetailApi(Resource): @@ -761,6 +767,9 @@ class ToolMCPTokenApi(Resource): parser.add_argument("provider_id", type=str, required=True, nullable=False, location="args") parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="args") args = parser.parse_args() + server_url = MCPToolManageService.get_mcp_provider_server_url( + current_user.current_tenant_id, args["provider_id"] + ) provider = MCPToolManageService.get_mcp_provider_by_provider_id( args["provider_id"], current_user.current_tenant_id ) @@ -768,7 +777,7 @@ class ToolMCPTokenApi(Resource): raise ValueError("provider not found") return auth( OAuthClientProvider(args["provider_id"], current_user.current_tenant_id), - provider.server_url, + server_url, authorization_code=args["authorization_code"], ) diff --git a/api/core/mcp/auth/auth_provider.py b/api/core/mcp/auth/auth_provider.py index 3d14724845..5c7d9e4333 100644 --- a/api/core/mcp/auth/auth_provider.py +++ b/api/core/mcp/auth/auth_provider.py @@ -42,7 +42,9 @@ class OAuthClientProvider: mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id) if not mcp_provider: return None - client_information = mcp_provider.credentials.get("client_information", {}) + client_information = MCPToolManageService.get_mcp_provider_decrypted_credentials( + self.tenant_id, self.provider_id + ).get("client_information", {}) if not client_information: return None return OAuthClientInformation.model_validate(client_information) @@ -58,13 +60,13 @@ class OAuthClientProvider: mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id) if not mcp_provider: return None - credentials = mcp_provider.credentials + credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id) if not credentials: return None return OAuthTokens( access_token=credentials.get("access_token", ""), token_type=credentials.get("token_type", "Bearer"), - expires_in=credentials.get("expires_in", 3600), + expires_in=int(credentials.get("expires_in", "3600")), refresh_token=credentials.get("refresh_token", ""), ) @@ -87,4 +89,5 @@ class OAuthClientProvider: mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id) if not mcp_provider: return "" - return mcp_provider.credentials.get("code_verifier", "") + credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id) + return credentials.get("code_verifier", "") diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index 0bc588e9a8..bdb93e6034 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -40,6 +40,8 @@ class MCPToolProviderController(ToolProviderController): @classmethod def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController": + from services.tools.mcp_tools_mange_service import MCPToolManageService + """ from db provider """ @@ -84,7 +86,7 @@ class MCPToolProviderController(ToolProviderController): ), provider_id=db_provider.id or "", tenant_id=db_provider.tenant_id or "", - server_url=db_provider.server_url, + server_url=MCPToolManageService.get_masked_mcp_provider_server_url(db_provider.tenant_id, db_provider.id), ) def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 6a5fba65bd..251fedf56e 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -72,20 +72,21 @@ class ProviderConfigEncrypter(BaseModel): return data - def decrypt(self, data: dict[str, str]) -> dict[str, str]: + def decrypt(self, data: dict[str, str], use_cache: bool = True) -> dict[str, str]: """ decrypt tool credentials with tenant id return a deep copy of credentials with decrypted values """ - cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f"{self.provider_type}.{self.provider_identity}", - cache_type=ToolProviderCredentialsCacheType.PROVIDER, - ) - cached_credentials = cache.get() - if cached_credentials: - return cached_credentials + if use_cache: + cache = ToolProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=f"{self.provider_type}.{self.provider_identity}", + cache_type=ToolProviderCredentialsCacheType.PROVIDER, + ) + cached_credentials = cache.get() + if cached_credentials: + return cached_credentials data = self._deep_copy(data) # get fields need to be decrypted fields = dict[str, BasicProviderConfig]() @@ -104,7 +105,8 @@ class ProviderConfigEncrypter(BaseModel): except Exception: pass - cache.set(data) + if use_cache: + cache.set(data) return data def delete_tool_credentials_cache(self): diff --git a/api/migrations/versions/2025_05_23_1623-ca4c4abcc347_add_app_mcp_server.py b/api/migrations/versions/2025_05_23_1623-ca4c4abcc347_add_app_mcp_server.py deleted file mode 100644 index cb3508afed..0000000000 --- a/api/migrations/versions/2025_05_23_1623-ca4c4abcc347_add_app_mcp_server.py +++ /dev/null @@ -1,41 +0,0 @@ -"""add app mcp server - -Revision ID: ca4c4abcc347 -Revises: 1e67f2654a08 -Create Date: 2025-05-22 16:23:44.206102 - -""" -from alembic import op -import models as models -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = 'ca4c4abcc347' -down_revision = '1e67f2654a08' -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.create_table('app_mcp_servers', - sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), - sa.Column('tenant_id', models.types.StringUUID(), nullable=False), - sa.Column('app_id', models.types.StringUUID(), nullable=False), - sa.Column('name', sa.String(length=255), nullable=False), - sa.Column('description', sa.String(length=255), nullable=False), - sa.Column('server_code', sa.String(length=255), nullable=False), - sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False), - sa.Column('parameters', sa.Text(), nullable=False), - sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), - sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), - sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey') - ) - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('app_mcp_servers') - # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_05_23_1740-1e67f2654a08_add_mcp_provider.py b/api/migrations/versions/2025_05_29_1707-de71f8771550_add_app_mcp_client_and_server.py similarity index 50% rename from api/migrations/versions/2025_05_23_1740-1e67f2654a08_add_mcp_provider.py rename to api/migrations/versions/2025_05_29_1707-de71f8771550_add_app_mcp_client_and_server.py index a20e0561b8..bb669474e0 100644 --- a/api/migrations/versions/2025_05_23_1740-1e67f2654a08_add_mcp_provider.py +++ b/api/migrations/versions/2025_05_29_1707-de71f8771550_add_app_mcp_client_and_server.py @@ -1,8 +1,8 @@ -"""add mcp provider +"""add app mcp client and server -Revision ID: 1e67f2654a08 -Revises: 6a9f914f656c -Create Date: 2025-05-07 17:40:58.448440 +Revision ID: de71f8771550 +Revises: 2adcbe1f5dfb +Create Date: 2025-05-29 17:07:40.037945 """ from alembic import op @@ -11,7 +11,7 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision = '1e67f2654a08' +revision = 'de71f8771550' down_revision = 'b35c3db83d09' branch_labels = None depends_on = None @@ -19,10 +19,24 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### + op.create_table('app_mcp_servers', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('description', sa.String(length=255), nullable=False), + sa.Column('server_code', sa.String(length=255), nullable=False), + sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False), + sa.Column('parameters', sa.Text(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey') + ) op.create_table('tool_mcp_providers', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), sa.Column('name', sa.String(length=40), nullable=False), - sa.Column('server_url', sa.String(length=255), nullable=False), + sa.Column('server_url', sa.String(length=512), nullable=False), + sa.Column('server_url_hash', sa.String(length=64), nullable=False), sa.Column('icon', sa.String(length=255), nullable=True), sa.Column('tenant_id', models.types.StringUUID(), nullable=False), sa.Column('user_id', models.types.StringUUID(), nullable=False), @@ -32,12 +46,16 @@ def upgrade(): sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), sa.PrimaryKeyConstraint('id', name='tool_mcp_provider_pkey'), - sa.UniqueConstraint('name', 'tenant_id', name='unique_mcp_tool_provider') + sa.UniqueConstraint('name', 'tenant_id', name='unique_mcp_tool_provider'), + sa.UniqueConstraint('server_url_hash', name='unique_mcp_tool_provider_server_url_hash') ) + # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### op.drop_table('tool_mcp_providers') + op.drop_table('app_mcp_servers') + # ### end Alembic commands ### diff --git a/api/models/tools.py b/api/models/tools.py index d9cdba41ca..8a8d254227 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -233,14 +233,16 @@ class MCPToolProvider(Base): __table_args__ = ( db.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"), db.UniqueConstraint("name", "tenant_id", name="unique_mcp_tool_provider"), - db.UniqueConstraint("server_url", name="unique_mcp_tool_provider_server_url"), + db.UniqueConstraint("server_url_hash", name="unique_mcp_tool_provider_server_url_hash"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the mcp provider name: Mapped[str] = mapped_column(db.String(40), nullable=False) - # url of the mcp provider - server_url: Mapped[str] = mapped_column(db.String(255), nullable=False) + # encrypted url of the mcp provider + server_url: Mapped[str] = mapped_column(db.String(512), nullable=False) + # hash of server_url for uniqueness check + server_url_hash: Mapped[str] = mapped_column(db.String(64), nullable=False) # icon of the mcp provider icon: Mapped[str] = mapped_column(db.String(255), nullable=True) # tenant id diff --git a/api/services/tools/mcp_tools_mange_service.py b/api/services/tools/mcp_tools_mange_service.py index 1434236f03..578cfea95e 100644 --- a/api/services/tools/mcp_tools_mange_service.py +++ b/api/services/tools/mcp_tools_mange_service.py @@ -1,17 +1,35 @@ +import hashlib import json +from urllib.parse import urlparse from sqlalchemy import or_ +from core.helper import encrypter from core.mcp.error import MCPAuthError, MCPConnectionError from core.mcp.mcp_client import MCPClient from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType +from core.tools.mcp_tool.provider import MCPToolProviderController +from core.tools.utils.configuration import ProviderConfigEncrypter from extensions.ext_database import db from models.tools import MCPToolProvider from services.tools.tools_transform_service import ToolTransformService +def mask_url(url: str, mask_char: str = "*"): + """ + mask the url to a simple string + """ + parsed = urlparse(url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + + if parsed.path and parsed.path != "/": + return f"{base_url}/{mask_char * 6}" + else: + return base_url + + class MCPToolManageService: """ Service class for managing mcp tools. @@ -32,11 +50,15 @@ class MCPToolManageService: def create_mcp_provider( tenant_id: str, name: str, server_url: str, user_id: str, icon: str, icon_type: str, icon_background: str ) -> ToolProviderApiEntity: + server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() existing_provider = ( db.session.query(MCPToolProvider) .filter( MCPToolProvider.tenant_id == tenant_id, - or_(MCPToolProvider.name == name, MCPToolProvider.server_url == server_url), + or_( + MCPToolProvider.name == name, + MCPToolProvider.server_url_hash == server_url_hash, + ), MCPToolProvider.tenant_id == tenant_id, ) .first() @@ -46,11 +68,12 @@ class MCPToolManageService: raise ValueError(f"MCP tool {name} already exists") else: raise ValueError(f"MCP tool {server_url} already exists") - + encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) mcp_tool = MCPToolProvider( tenant_id=tenant_id, name=name, - server_url=server_url, + server_url=encrypted_server_url, + server_url_hash=server_url_hash, user_id=user_id, authed=False, tools="[]", @@ -68,10 +91,11 @@ class MCPToolManageService: @classmethod def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str): mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) + server_url = cls.get_mcp_provider_server_url(tenant_id, provider_id) if mcp_provider is None: raise ValueError("MCP tool not found") try: - with MCPClient(mcp_provider.server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client: + with MCPClient(server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client: tools = mcp_client.list_tools() except MCPAuthError as e: raise ValueError("Please auth the tool first") @@ -87,7 +111,7 @@ class MCPToolManageService: type=ToolProviderType.MCP, icon=mcp_provider.icon, author=mcp_provider.user.name if mcp_provider.user else "Anonymous", - server_url=mcp_provider.server_url, + server_url=cls.get_masked_mcp_provider_server_url(tenant_id, provider_id), updated_at=int(mcp_provider.updated_at.timestamp()), description=I18nObject(en_US="", zh_Hans=""), label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name), @@ -107,7 +131,6 @@ class MCPToolManageService: raise ValueError("MCP tool not found") db.session.delete(mcp_tool) db.session.commit() - return {"result": "success"} @classmethod def update_mcp_provider( @@ -123,27 +146,54 @@ class MCPToolManageService: mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) if mcp_provider is None: raise ValueError("MCP tool not found") + encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) mcp_provider.name = name - mcp_provider.server_url = server_url + mcp_provider.server_url = encrypted_server_url + mcp_provider.server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() mcp_provider.icon = ( json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon ) db.session.commit() - return {"result": "success"} @classmethod def update_mcp_provider_credentials(cls, tenant_id: str, provider_id: str, credentials: dict, authed: bool = False): mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) if mcp_provider is None: raise ValueError("MCP tool not found") + provider_controller = MCPToolProviderController._from_db(mcp_provider) + tool_configuration = ProviderConfigEncrypter( + tenant_id=tenant_id, + config=list(provider_controller.get_credentials_schema()), + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.provider_id, + ) + credentials = tool_configuration.encrypt(credentials) mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials}) mcp_provider.authed = authed db.session.commit() - return {"result": "success"} @classmethod - def get_mcp_token(cls, provider_id: str, tenant_id: str): + def get_mcp_provider_decrypted_credentials(cls, tenant_id: str, provider_id: str): mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) if mcp_provider is None: - raise ValueError("MCP provider not found") - return mcp_provider.credentials.get("access_token", None) + raise ValueError("MCP tool not found") + provider_controller = MCPToolProviderController._from_db(mcp_provider) + tool_configuration = ProviderConfigEncrypter( + tenant_id=tenant_id, + config=list(provider_controller.get_credentials_schema()), + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.provider_id, + ) + return tool_configuration.decrypt(mcp_provider.credentials, use_cache=False) + + @classmethod + def get_mcp_provider_server_url(cls, tenant_id: str, provider_id: str): + mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) + if mcp_provider is None: + raise ValueError("MCP tool not found") + return encrypter.decrypt_token(tenant_id, mcp_provider.server_url) + + @classmethod + def get_masked_mcp_provider_server_url(cls, tenant_id: str, provider_id: str): + server_url = cls.get_mcp_provider_server_url(tenant_id, provider_id) + return mask_url(server_url) diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 54f65d02a7..06e042f139 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -190,6 +190,8 @@ class ToolTransformService: @staticmethod def mcp_provider_to_user_provider(db_provider: MCPToolProvider) -> ToolProviderApiEntity: + from services.tools.mcp_tools_mange_service import MCPToolManageService + return ToolProviderApiEntity( id=db_provider.id, author=db_provider.user.name if db_provider.user else "Anonymous", @@ -197,7 +199,7 @@ class ToolTransformService: icon=db_provider.provider_icon, type=ToolProviderType.MCP, is_team_authorization=db_provider.authed, - server_url=db_provider.server_url, + server_url=MCPToolManageService.get_masked_mcp_provider_server_url(db_provider.tenant_id, db_provider.id), tools=ToolTransformService.mcp_tool_to_user_tool( db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)] ),