Merge branch 'feat/mcp' into deploy/dev

This commit is contained in:
Novice 2025-05-30 09:23:05 +08:00
commit ec6c0e52aa
9 changed files with 130 additions and 83 deletions

View File

@ -1,5 +1,6 @@
import io
import validators
from flask import send_file
from flask_login import current_user
from flask_restful import Resource, reqparse
@ -631,6 +632,8 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
args = parser.parse_args()
user = current_user
if not validators.url(args["server_url"]):
raise ValueError("Server URL is not valid.")
return jsonable_encoder(
MCPToolManageService.create_mcp_provider(
tenant_id=user.current_tenant_id,
@ -655,6 +658,8 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
if not validators.url(args["server_url"]):
raise ValueError("Server URL is not valid.")
MCPToolManageService.update_mcp_provider(
tenant_id=current_user.current_tenant_id,
name=args["name"],
@ -691,9 +696,10 @@ class ToolMCPAuthApi(Resource):
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if not provider:
raise ValueError("provider not found")
server_url = MCPToolManageService.get_mcp_provider_server_url(tenant_id, provider_id)
try:
with MCPClient(
provider.server_url,
server_url,
provider_id,
tenant_id,
authed=False,
@ -702,14 +708,14 @@ class ToolMCPAuthApi(Resource):
MCPToolManageService.update_mcp_provider_credentials(
tenant_id=tenant_id,
provider_id=provider_id,
credentials={},
credentials=MCPToolManageService.get_mcp_provider_decrypted_credentials(tenant_id, provider_id),
authed=True,
)
return {"result": "success"}
except MCPAuthError:
auth_provider = OAuthClientProvider(provider_id, tenant_id)
return auth(auth_provider, provider.server_url, args["authorization_code"])
return auth(auth_provider, server_url, args["authorization_code"])
class ToolMCPDetailApi(Resource):
@ -761,6 +767,9 @@ class ToolMCPTokenApi(Resource):
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="args")
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="args")
args = parser.parse_args()
server_url = MCPToolManageService.get_mcp_provider_server_url(
current_user.current_tenant_id, args["provider_id"]
)
provider = MCPToolManageService.get_mcp_provider_by_provider_id(
args["provider_id"], current_user.current_tenant_id
)
@ -768,7 +777,7 @@ class ToolMCPTokenApi(Resource):
raise ValueError("provider not found")
return auth(
OAuthClientProvider(args["provider_id"], current_user.current_tenant_id),
provider.server_url,
server_url,
authorization_code=args["authorization_code"],
)

View File

@ -42,7 +42,9 @@ class OAuthClientProvider:
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
if not mcp_provider:
return None
client_information = mcp_provider.credentials.get("client_information", {})
client_information = MCPToolManageService.get_mcp_provider_decrypted_credentials(
self.tenant_id, self.provider_id
).get("client_information", {})
if not client_information:
return None
return OAuthClientInformation.model_validate(client_information)
@ -58,13 +60,13 @@ class OAuthClientProvider:
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
if not mcp_provider:
return None
credentials = mcp_provider.credentials
credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id)
if not credentials:
return None
return OAuthTokens(
access_token=credentials.get("access_token", ""),
token_type=credentials.get("token_type", "Bearer"),
expires_in=credentials.get("expires_in", 3600),
expires_in=int(credentials.get("expires_in", "3600")),
refresh_token=credentials.get("refresh_token", ""),
)
@ -87,4 +89,5 @@ class OAuthClientProvider:
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
if not mcp_provider:
return ""
return mcp_provider.credentials.get("code_verifier", "")
credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id)
return credentials.get("code_verifier", "")

View File

@ -40,6 +40,8 @@ class MCPToolProviderController(ToolProviderController):
@classmethod
def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController":
from services.tools.mcp_tools_mange_service import MCPToolManageService
"""
from db provider
"""
@ -84,7 +86,7 @@ class MCPToolProviderController(ToolProviderController):
),
provider_id=db_provider.id or "",
tenant_id=db_provider.tenant_id or "",
server_url=db_provider.server_url,
server_url=MCPToolManageService.get_masked_mcp_provider_server_url(db_provider.tenant_id, db_provider.id),
)
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:

View File

@ -72,20 +72,21 @@ class ProviderConfigEncrypter(BaseModel):
return data
def decrypt(self, data: dict[str, str]) -> dict[str, str]:
def decrypt(self, data: dict[str, str], use_cache: bool = True) -> dict[str, str]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cached_credentials = cache.get()
if cached_credentials:
return cached_credentials
if use_cache:
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cached_credentials = cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
@ -104,7 +105,8 @@ class ProviderConfigEncrypter(BaseModel):
except Exception:
pass
cache.set(data)
if use_cache:
cache.set(data)
return data
def delete_tool_credentials_cache(self):

View File

@ -1,41 +0,0 @@
"""add app mcp server
Revision ID: ca4c4abcc347
Revises: 1e67f2654a08
Create Date: 2025-05-22 16:23:44.206102
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'ca4c4abcc347'
down_revision = '1e67f2654a08'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('app_mcp_servers',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('description', sa.String(length=255), nullable=False),
sa.Column('server_code', sa.String(length=255), nullable=False),
sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False),
sa.Column('parameters', sa.Text(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey')
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('app_mcp_servers')
# ### end Alembic commands ###

View File

@ -1,8 +1,8 @@
"""add mcp provider
"""add app mcp client and server
Revision ID: 1e67f2654a08
Revises: 6a9f914f656c
Create Date: 2025-05-07 17:40:58.448440
Revision ID: de71f8771550
Revises: 2adcbe1f5dfb
Create Date: 2025-05-29 17:07:40.037945
"""
from alembic import op
@ -11,7 +11,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '1e67f2654a08'
revision = 'de71f8771550'
down_revision = 'b35c3db83d09'
branch_labels = None
depends_on = None
@ -19,10 +19,24 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('app_mcp_servers',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('description', sa.String(length=255), nullable=False),
sa.Column('server_code', sa.String(length=255), nullable=False),
sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False),
sa.Column('parameters', sa.Text(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey')
)
op.create_table('tool_mcp_providers',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('name', sa.String(length=40), nullable=False),
sa.Column('server_url', sa.String(length=255), nullable=False),
sa.Column('server_url', sa.String(length=512), nullable=False),
sa.Column('server_url_hash', sa.String(length=64), nullable=False),
sa.Column('icon', sa.String(length=255), nullable=True),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('user_id', models.types.StringUUID(), nullable=False),
@ -32,12 +46,16 @@ def upgrade():
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_mcp_provider_pkey'),
sa.UniqueConstraint('name', 'tenant_id', name='unique_mcp_tool_provider')
sa.UniqueConstraint('name', 'tenant_id', name='unique_mcp_tool_provider'),
sa.UniqueConstraint('server_url_hash', name='unique_mcp_tool_provider_server_url_hash')
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('tool_mcp_providers')
op.drop_table('app_mcp_servers')
# ### end Alembic commands ###

View File

@ -233,14 +233,16 @@ class MCPToolProvider(Base):
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"),
db.UniqueConstraint("name", "tenant_id", name="unique_mcp_tool_provider"),
db.UniqueConstraint("server_url", name="unique_mcp_tool_provider_server_url"),
db.UniqueConstraint("server_url_hash", name="unique_mcp_tool_provider_server_url_hash"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# name of the mcp provider
name: Mapped[str] = mapped_column(db.String(40), nullable=False)
# url of the mcp provider
server_url: Mapped[str] = mapped_column(db.String(255), nullable=False)
# encrypted url of the mcp provider
server_url: Mapped[str] = mapped_column(db.String(512), nullable=False)
# hash of server_url for uniqueness check
server_url_hash: Mapped[str] = mapped_column(db.String(64), nullable=False)
# icon of the mcp provider
icon: Mapped[str] = mapped_column(db.String(255), nullable=True)
# tenant id

View File

@ -1,17 +1,35 @@
import hashlib
import json
from urllib.parse import urlparse
from sqlalchemy import or_
from core.helper import encrypter
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.mcp_client import MCPClient
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
from extensions.ext_database import db
from models.tools import MCPToolProvider
from services.tools.tools_transform_service import ToolTransformService
def mask_url(url: str, mask_char: str = "*"):
"""
mask the url to a simple string
"""
parsed = urlparse(url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
if parsed.path and parsed.path != "/":
return f"{base_url}/{mask_char * 6}"
else:
return base_url
class MCPToolManageService:
"""
Service class for managing mcp tools.
@ -32,11 +50,15 @@ class MCPToolManageService:
def create_mcp_provider(
tenant_id: str, name: str, server_url: str, user_id: str, icon: str, icon_type: str, icon_background: str
) -> ToolProviderApiEntity:
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
existing_provider = (
db.session.query(MCPToolProvider)
.filter(
MCPToolProvider.tenant_id == tenant_id,
or_(MCPToolProvider.name == name, MCPToolProvider.server_url == server_url),
or_(
MCPToolProvider.name == name,
MCPToolProvider.server_url_hash == server_url_hash,
),
MCPToolProvider.tenant_id == tenant_id,
)
.first()
@ -46,11 +68,12 @@ class MCPToolManageService:
raise ValueError(f"MCP tool {name} already exists")
else:
raise ValueError(f"MCP tool {server_url} already exists")
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
mcp_tool = MCPToolProvider(
tenant_id=tenant_id,
name=name,
server_url=server_url,
server_url=encrypted_server_url,
server_url_hash=server_url_hash,
user_id=user_id,
authed=False,
tools="[]",
@ -68,10 +91,11 @@ class MCPToolManageService:
@classmethod
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
server_url = cls.get_mcp_provider_server_url(tenant_id, provider_id)
if mcp_provider is None:
raise ValueError("MCP tool not found")
try:
with MCPClient(mcp_provider.server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client:
with MCPClient(server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client:
tools = mcp_client.list_tools()
except MCPAuthError as e:
raise ValueError("Please auth the tool first")
@ -87,7 +111,7 @@ class MCPToolManageService:
type=ToolProviderType.MCP,
icon=mcp_provider.icon,
author=mcp_provider.user.name if mcp_provider.user else "Anonymous",
server_url=mcp_provider.server_url,
server_url=cls.get_masked_mcp_provider_server_url(tenant_id, provider_id),
updated_at=int(mcp_provider.updated_at.timestamp()),
description=I18nObject(en_US="", zh_Hans=""),
label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
@ -107,7 +131,6 @@ class MCPToolManageService:
raise ValueError("MCP tool not found")
db.session.delete(mcp_tool)
db.session.commit()
return {"result": "success"}
@classmethod
def update_mcp_provider(
@ -123,27 +146,54 @@ class MCPToolManageService:
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_provider is None:
raise ValueError("MCP tool not found")
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
mcp_provider.name = name
mcp_provider.server_url = server_url
mcp_provider.server_url = encrypted_server_url
mcp_provider.server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
mcp_provider.icon = (
json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
)
db.session.commit()
return {"result": "success"}
@classmethod
def update_mcp_provider_credentials(cls, tenant_id: str, provider_id: str, credentials: dict, authed: bool = False):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_provider is None:
raise ValueError("MCP tool not found")
provider_controller = MCPToolProviderController._from_db(mcp_provider)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.provider_id,
)
credentials = tool_configuration.encrypt(credentials)
mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials})
mcp_provider.authed = authed
db.session.commit()
return {"result": "success"}
@classmethod
def get_mcp_token(cls, provider_id: str, tenant_id: str):
def get_mcp_provider_decrypted_credentials(cls, tenant_id: str, provider_id: str):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_provider is None:
raise ValueError("MCP provider not found")
return mcp_provider.credentials.get("access_token", None)
raise ValueError("MCP tool not found")
provider_controller = MCPToolProviderController._from_db(mcp_provider)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.provider_id,
)
return tool_configuration.decrypt(mcp_provider.credentials, use_cache=False)
@classmethod
def get_mcp_provider_server_url(cls, tenant_id: str, provider_id: str):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_provider is None:
raise ValueError("MCP tool not found")
return encrypter.decrypt_token(tenant_id, mcp_provider.server_url)
@classmethod
def get_masked_mcp_provider_server_url(cls, tenant_id: str, provider_id: str):
server_url = cls.get_mcp_provider_server_url(tenant_id, provider_id)
return mask_url(server_url)

View File

@ -190,6 +190,8 @@ class ToolTransformService:
@staticmethod
def mcp_provider_to_user_provider(db_provider: MCPToolProvider) -> ToolProviderApiEntity:
from services.tools.mcp_tools_mange_service import MCPToolManageService
return ToolProviderApiEntity(
id=db_provider.id,
author=db_provider.user.name if db_provider.user else "Anonymous",
@ -197,7 +199,7 @@ class ToolTransformService:
icon=db_provider.provider_icon,
type=ToolProviderType.MCP,
is_team_authorization=db_provider.authed,
server_url=db_provider.server_url,
server_url=MCPToolManageService.get_masked_mcp_provider_server_url(db_provider.tenant_id, db_provider.id),
tools=ToolTransformService.mcp_tool_to_user_tool(
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
),