mirror of
https://github.com/langgenius/dify.git
synced 2025-12-25 01:00:42 -05:00
feat: implement MCP specification 2025-06-18 (#25766)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,86 +1,118 @@
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sqlalchemy import or_
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
from core.tools.utils.encryption import ProviderConfigEncrypter
|
||||
from extensions.ext_database import db
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
|
||||
CLIENT_NAME = "Dify"
|
||||
EMPTY_TOOLS_JSON = "[]"
|
||||
EMPTY_CREDENTIALS_JSON = "{}"
|
||||
|
||||
|
||||
class OAuthDataType(StrEnum):
|
||||
"""Types of OAuth data that can be saved."""
|
||||
|
||||
TOKENS = "tokens"
|
||||
CLIENT_INFO = "client_info"
|
||||
CODE_VERIFIER = "code_verifier"
|
||||
MIXED = "mixed"
|
||||
|
||||
|
||||
class ReconnectResult(BaseModel):
|
||||
"""Result of reconnecting to an MCP provider"""
|
||||
|
||||
authed: bool = Field(description="Whether the provider is authenticated")
|
||||
tools: str = Field(description="JSON string of tool list")
|
||||
encrypted_credentials: str = Field(description="JSON string of encrypted credentials")
|
||||
|
||||
|
||||
class ServerUrlValidationResult(BaseModel):
|
||||
"""Result of server URL validation check"""
|
||||
|
||||
needs_validation: bool
|
||||
validation_passed: bool = False
|
||||
reconnect_result: ReconnectResult | None = None
|
||||
encrypted_server_url: str | None = None
|
||||
server_url_hash: str | None = None
|
||||
|
||||
@property
|
||||
def should_update_server_url(self) -> bool:
|
||||
"""Check if server URL should be updated based on validation result"""
|
||||
return self.needs_validation and self.validation_passed and self.reconnect_result is not None
|
||||
|
||||
|
||||
class MCPToolManageService:
|
||||
"""
|
||||
Service class for managing mcp tools.
|
||||
"""
|
||||
"""Service class for managing MCP tools and providers."""
|
||||
|
||||
@staticmethod
|
||||
def _encrypt_headers(headers: dict[str, str], tenant_id: str) -> dict[str, str]:
|
||||
def __init__(self, session: Session):
|
||||
self._session = session
|
||||
|
||||
# ========== Provider CRUD Operations ==========
|
||||
|
||||
def get_provider(
|
||||
self, *, provider_id: str | None = None, server_identifier: str | None = None, tenant_id: str
|
||||
) -> MCPToolProvider:
|
||||
"""
|
||||
Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT.
|
||||
Get MCP provider by ID or server identifier.
|
||||
|
||||
Args:
|
||||
headers: Dictionary of headers to encrypt
|
||||
tenant_id: Tenant ID for encryption
|
||||
provider_id: Provider ID (UUID)
|
||||
server_identifier: Server identifier
|
||||
tenant_id: Tenant ID
|
||||
|
||||
Returns:
|
||||
Dictionary with all headers encrypted
|
||||
MCPToolProvider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider not found
|
||||
"""
|
||||
if not headers:
|
||||
return {}
|
||||
if server_identifier:
|
||||
stmt = select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier
|
||||
)
|
||||
else:
|
||||
stmt = select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id
|
||||
)
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
|
||||
# Create dynamic config for all headers as SECRET_INPUT
|
||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers]
|
||||
|
||||
encrypter_instance, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
return encrypter_instance.encrypt(headers)
|
||||
|
||||
@staticmethod
|
||||
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
|
||||
res = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
|
||||
.first()
|
||||
)
|
||||
if not res:
|
||||
provider = self._session.scalar(stmt)
|
||||
if not provider:
|
||||
raise ValueError("MCP tool not found")
|
||||
return res
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider:
|
||||
res = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier)
|
||||
.first()
|
||||
)
|
||||
if not res:
|
||||
raise ValueError("MCP tool not found")
|
||||
return res
|
||||
def get_provider_entity(self, provider_id: str, tenant_id: str, by_server_id: bool = False) -> MCPProviderEntity:
|
||||
"""Get provider entity by ID or server identifier."""
|
||||
if by_server_id:
|
||||
db_provider = self.get_provider(server_identifier=provider_id, tenant_id=tenant_id)
|
||||
else:
|
||||
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
return db_provider.to_entity()
|
||||
|
||||
@staticmethod
|
||||
def create_mcp_provider(
|
||||
def create_provider(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
name: str,
|
||||
server_url: str,
|
||||
@@ -89,37 +121,30 @@ class MCPToolManageService:
|
||||
icon_type: str,
|
||||
icon_background: str,
|
||||
server_identifier: str,
|
||||
timeout: float,
|
||||
sse_read_timeout: float,
|
||||
configuration: MCPConfiguration,
|
||||
authentication: MCPAuthentication | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> ToolProviderApiEntity:
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
existing_provider = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.where(
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
or_(
|
||||
MCPToolProvider.name == name,
|
||||
MCPToolProvider.server_url_hash == server_url_hash,
|
||||
MCPToolProvider.server_identifier == server_identifier,
|
||||
),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing_provider:
|
||||
if existing_provider.name == name:
|
||||
raise ValueError(f"MCP tool {name} already exists")
|
||||
if existing_provider.server_url_hash == server_url_hash:
|
||||
raise ValueError(f"MCP tool {server_url} already exists")
|
||||
if existing_provider.server_identifier == server_identifier:
|
||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||
# Encrypt headers
|
||||
encrypted_headers = None
|
||||
if headers:
|
||||
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
|
||||
encrypted_headers = json.dumps(encrypted_headers_dict)
|
||||
"""Create a new MCP provider."""
|
||||
# Validate URL format
|
||||
if not self._is_valid_url(server_url):
|
||||
raise ValueError("Server URL is not valid.")
|
||||
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
|
||||
# Check for existing provider
|
||||
self._check_provider_exists(tenant_id, name, server_url_hash, server_identifier)
|
||||
|
||||
# Encrypt sensitive data
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||
encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None
|
||||
encrypted_credentials = None
|
||||
if authentication is not None and authentication.client_id:
|
||||
encrypted_credentials = self._build_and_encrypt_credentials(
|
||||
authentication.client_id, authentication.client_secret, tenant_id
|
||||
)
|
||||
|
||||
# Create provider
|
||||
mcp_tool = MCPToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
@@ -127,93 +152,23 @@ class MCPToolManageService:
|
||||
server_url_hash=server_url_hash,
|
||||
user_id=user_id,
|
||||
authed=False,
|
||||
tools="[]",
|
||||
icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
|
||||
tools=EMPTY_TOOLS_JSON,
|
||||
icon=self._prepare_icon(icon, icon_type, icon_background),
|
||||
server_identifier=server_identifier,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
timeout=configuration.timeout,
|
||||
sse_read_timeout=configuration.sse_read_timeout,
|
||||
encrypted_headers=encrypted_headers,
|
||||
)
|
||||
db.session.add(mcp_tool)
|
||||
db.session.commit()
|
||||
return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
|
||||
|
||||
@staticmethod
|
||||
def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
|
||||
mcp_providers = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.where(MCPToolProvider.tenant_id == tenant_id)
|
||||
.order_by(MCPToolProvider.name)
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
ToolTransformService.mcp_provider_to_user_provider(mcp_provider, for_list=for_list)
|
||||
for mcp_provider in mcp_providers
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
server_url = mcp_provider.decrypted_server_url
|
||||
authed = mcp_provider.authed
|
||||
headers = mcp_provider.decrypted_headers
|
||||
timeout = mcp_provider.timeout
|
||||
sse_read_timeout = mcp_provider.sse_read_timeout
|
||||
|
||||
try:
|
||||
with MCPClient(
|
||||
server_url,
|
||||
provider_id,
|
||||
tenant_id,
|
||||
authed=authed,
|
||||
for_list=True,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
except MCPAuthError:
|
||||
raise ValueError("Please auth the tool first")
|
||||
except MCPError as e:
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}")
|
||||
|
||||
try:
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
||||
mcp_provider.authed = True
|
||||
mcp_provider.updated_at = datetime.now()
|
||||
db.session.commit()
|
||||
except Exception:
|
||||
db.session.rollback()
|
||||
raise
|
||||
|
||||
user = mcp_provider.load_user()
|
||||
if not mcp_provider.icon:
|
||||
raise ValueError("MCP provider icon is required")
|
||||
return ToolProviderApiEntity(
|
||||
id=mcp_provider.id,
|
||||
name=mcp_provider.name,
|
||||
tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools),
|
||||
type=ToolProviderType.MCP,
|
||||
icon=mcp_provider.icon,
|
||||
author=user.name if user else "Anonymous",
|
||||
server_url=mcp_provider.masked_server_url,
|
||||
updated_at=int(mcp_provider.updated_at.timestamp()),
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
|
||||
plugin_unique_identifier=mcp_provider.server_identifier,
|
||||
encrypted_credentials=encrypted_credentials,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def delete_mcp_tool(cls, tenant_id: str, provider_id: str):
|
||||
mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
self._session.add(mcp_tool)
|
||||
self._session.flush()
|
||||
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
|
||||
return mcp_providers
|
||||
|
||||
db.session.delete(mcp_tool)
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def update_mcp_provider(
|
||||
cls,
|
||||
def update_provider(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
provider_id: str,
|
||||
name: str,
|
||||
@@ -222,129 +177,546 @@ class MCPToolManageService:
|
||||
icon_type: str,
|
||||
icon_background: str,
|
||||
server_identifier: str,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
):
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
configuration: MCPConfiguration,
|
||||
authentication: MCPAuthentication | None = None,
|
||||
validation_result: ServerUrlValidationResult | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Update an MCP provider.
|
||||
|
||||
reconnect_result = None
|
||||
Args:
|
||||
validation_result: Pre-validation result from validate_server_url_change.
|
||||
If provided and contains reconnect_result, it will be used
|
||||
instead of performing network operations.
|
||||
"""
|
||||
mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
||||
# Check for duplicate name (excluding current provider)
|
||||
if name != mcp_provider.name:
|
||||
stmt = select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
MCPToolProvider.name == name,
|
||||
MCPToolProvider.id != provider_id,
|
||||
)
|
||||
existing_provider = self._session.scalar(stmt)
|
||||
if existing_provider:
|
||||
raise ValueError(f"MCP tool {name} already exists")
|
||||
|
||||
# Get URL update data from validation result
|
||||
encrypted_server_url = None
|
||||
server_url_hash = None
|
||||
reconnect_result = None
|
||||
|
||||
if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
|
||||
if server_url_hash != mcp_provider.server_url_hash:
|
||||
reconnect_result = cls._re_connect_mcp_provider(server_url, provider_id, tenant_id)
|
||||
if validation_result and validation_result.encrypted_server_url:
|
||||
# Use all data from validation result
|
||||
encrypted_server_url = validation_result.encrypted_server_url
|
||||
server_url_hash = validation_result.server_url_hash
|
||||
reconnect_result = validation_result.reconnect_result
|
||||
|
||||
try:
|
||||
# Update basic fields
|
||||
mcp_provider.updated_at = datetime.now()
|
||||
mcp_provider.name = name
|
||||
mcp_provider.icon = (
|
||||
json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
|
||||
)
|
||||
mcp_provider.icon = self._prepare_icon(icon, icon_type, icon_background)
|
||||
mcp_provider.server_identifier = server_identifier
|
||||
|
||||
if encrypted_server_url is not None and server_url_hash is not None:
|
||||
# Update server URL if changed
|
||||
if encrypted_server_url and server_url_hash:
|
||||
mcp_provider.server_url = encrypted_server_url
|
||||
mcp_provider.server_url_hash = server_url_hash
|
||||
|
||||
if reconnect_result:
|
||||
mcp_provider.authed = reconnect_result["authed"]
|
||||
mcp_provider.tools = reconnect_result["tools"]
|
||||
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
|
||||
mcp_provider.authed = reconnect_result.authed
|
||||
mcp_provider.tools = reconnect_result.tools
|
||||
mcp_provider.encrypted_credentials = reconnect_result.encrypted_credentials
|
||||
|
||||
if timeout is not None:
|
||||
mcp_provider.timeout = timeout
|
||||
if sse_read_timeout is not None:
|
||||
mcp_provider.sse_read_timeout = sse_read_timeout
|
||||
# Update optional configuration fields
|
||||
self._update_optional_fields(mcp_provider, configuration)
|
||||
|
||||
# Update headers if provided
|
||||
if headers is not None:
|
||||
# Merge masked headers from frontend with existing real values
|
||||
if headers:
|
||||
# existing decrypted and masked headers
|
||||
existing_decrypted = mcp_provider.decrypted_headers
|
||||
existing_masked = mcp_provider.masked_headers
|
||||
mcp_provider.encrypted_headers = self._process_headers(headers, mcp_provider, tenant_id)
|
||||
|
||||
# Build final headers: if value equals masked existing, keep original decrypted value
|
||||
final_headers: dict[str, str] = {}
|
||||
for key, incoming_value in headers.items():
|
||||
if (
|
||||
key in existing_masked
|
||||
and key in existing_decrypted
|
||||
and isinstance(incoming_value, str)
|
||||
and incoming_value == existing_masked.get(key)
|
||||
):
|
||||
# unchanged, use original decrypted value
|
||||
final_headers[key] = str(existing_decrypted[key])
|
||||
else:
|
||||
final_headers[key] = incoming_value
|
||||
# Update credentials if provided
|
||||
if authentication and authentication.client_id:
|
||||
mcp_provider.encrypted_credentials = self._process_credentials(authentication, mcp_provider, tenant_id)
|
||||
|
||||
encrypted_headers_dict = MCPToolManageService._encrypt_headers(final_headers, tenant_id)
|
||||
mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
|
||||
else:
|
||||
# Explicitly clear headers if empty dict passed
|
||||
mcp_provider.encrypted_headers = None
|
||||
db.session.commit()
|
||||
# Flush changes to database
|
||||
self._session.flush()
|
||||
except IntegrityError as e:
|
||||
db.session.rollback()
|
||||
error_msg = str(e.orig)
|
||||
if "unique_mcp_provider_name" in error_msg:
|
||||
raise ValueError(f"MCP tool {name} already exists")
|
||||
if "unique_mcp_provider_server_url" in error_msg:
|
||||
raise ValueError(f"MCP tool {server_url} already exists")
|
||||
if "unique_mcp_provider_server_identifier" in error_msg:
|
||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||
raise
|
||||
except Exception:
|
||||
db.session.rollback()
|
||||
raise
|
||||
self._handle_integrity_error(e, name, server_url, server_identifier)
|
||||
|
||||
@classmethod
|
||||
def update_mcp_provider_credentials(
|
||||
cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
|
||||
):
|
||||
provider_controller = MCPToolProviderController.from_db(mcp_provider)
|
||||
def delete_provider(self, *, tenant_id: str, provider_id: str) -> None:
|
||||
"""Delete an MCP provider."""
|
||||
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
self._session.delete(mcp_tool)
|
||||
|
||||
def list_providers(
|
||||
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
|
||||
) -> list[ToolProviderApiEntity]:
|
||||
"""List all MCP providers for a tenant.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
for_list: If True, return provider ID; if False, return server identifier
|
||||
include_sensitive: If False, skip expensive decryption operations (default: True for backward compatibility)
|
||||
"""
|
||||
from models.account import Account
|
||||
|
||||
stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
|
||||
mcp_providers = self._session.scalars(stmt).all()
|
||||
|
||||
if not mcp_providers:
|
||||
return []
|
||||
|
||||
# Batch query all users to avoid N+1 problem
|
||||
user_ids = {provider.user_id for provider in mcp_providers}
|
||||
users = self._session.query(Account).where(Account.id.in_(user_ids)).all()
|
||||
user_name_map = {user.id: user.name for user in users}
|
||||
|
||||
return [
|
||||
ToolTransformService.mcp_provider_to_user_provider(
|
||||
provider,
|
||||
for_list=for_list,
|
||||
user_name=user_name_map.get(provider.user_id),
|
||||
include_sensitive=include_sensitive,
|
||||
)
|
||||
for provider in mcp_providers
|
||||
]
|
||||
|
||||
# ========== Tool Operations ==========
|
||||
|
||||
def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
|
||||
"""List tools from remote MCP server."""
|
||||
# Load provider and convert to entity
|
||||
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
provider_entity = db_provider.to_entity()
|
||||
|
||||
# Verify authentication
|
||||
if not provider_entity.authed:
|
||||
raise ValueError("Please auth the tool first")
|
||||
|
||||
# Prepare headers with auth token
|
||||
headers = self._prepare_auth_headers(provider_entity)
|
||||
|
||||
# Retrieve tools from remote server
|
||||
server_url = provider_entity.decrypt_server_url()
|
||||
try:
|
||||
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
|
||||
except MCPError as e:
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}")
|
||||
|
||||
# Update database with retrieved tools
|
||||
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
||||
db_provider.authed = True
|
||||
db_provider.updated_at = datetime.now()
|
||||
self._session.flush()
|
||||
|
||||
# Build API response
|
||||
return self._build_tool_provider_response(db_provider, provider_entity, tools)
|
||||
|
||||
# ========== OAuth and Credentials Operations ==========
|
||||
|
||||
def update_provider_credentials(
|
||||
self, *, provider_id: str, tenant_id: str, credentials: dict[str, Any], authed: bool | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Update provider credentials with encryption.
|
||||
|
||||
Args:
|
||||
provider_id: Provider ID
|
||||
tenant_id: Tenant ID
|
||||
credentials: Credentials to save
|
||||
authed: Whether provider is authenticated (None means keep current state)
|
||||
"""
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
|
||||
# Get provider from current session
|
||||
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
||||
# Encrypt new credentials
|
||||
provider_controller = MCPToolProviderController.from_db(provider)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=mcp_provider.tenant_id,
|
||||
tenant_id=provider.tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()),
|
||||
provider_config_cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
credentials = tool_configuration.encrypt(credentials)
|
||||
mcp_provider.updated_at = datetime.now()
|
||||
mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials})
|
||||
mcp_provider.authed = authed
|
||||
if not authed:
|
||||
mcp_provider.tools = "[]"
|
||||
db.session.commit()
|
||||
encrypted_credentials = tool_configuration.encrypt(credentials)
|
||||
|
||||
@classmethod
|
||||
def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str):
|
||||
# Get the existing provider to access headers and timeout settings
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
headers = mcp_provider.decrypted_headers
|
||||
timeout = mcp_provider.timeout
|
||||
sse_read_timeout = mcp_provider.sse_read_timeout
|
||||
# Update provider
|
||||
provider.updated_at = datetime.now()
|
||||
provider.encrypted_credentials = json.dumps({**provider.credentials, **encrypted_credentials})
|
||||
|
||||
if authed is not None:
|
||||
provider.authed = authed
|
||||
if not authed:
|
||||
provider.tools = EMPTY_TOOLS_JSON
|
||||
|
||||
# Flush changes to database
|
||||
self._session.flush()
|
||||
|
||||
def save_oauth_data(
|
||||
self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: OAuthDataType = OAuthDataType.MIXED
|
||||
) -> None:
|
||||
"""
|
||||
Save OAuth-related data (tokens, client info, code verifier).
|
||||
|
||||
Args:
|
||||
provider_id: Provider ID
|
||||
tenant_id: Tenant ID
|
||||
data: Data to save (tokens, client info, or code verifier)
|
||||
data_type: Type of OAuth data to save
|
||||
"""
|
||||
# Determine if this makes the provider authenticated
|
||||
authed = (
|
||||
data_type == OAuthDataType.TOKENS or (data_type == OAuthDataType.MIXED and "access_token" in data) or None
|
||||
)
|
||||
|
||||
# update_provider_credentials will validate provider existence
|
||||
self.update_provider_credentials(provider_id=provider_id, tenant_id=tenant_id, credentials=data, authed=authed)
|
||||
|
||||
def clear_provider_credentials(self, *, provider_id: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Clear all credentials for a provider.
|
||||
|
||||
Args:
|
||||
provider_id: Provider ID
|
||||
tenant_id: Tenant ID
|
||||
"""
|
||||
# Get provider from current session
|
||||
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
||||
provider.tools = EMPTY_TOOLS_JSON
|
||||
provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON
|
||||
provider.updated_at = datetime.now()
|
||||
provider.authed = False
|
||||
|
||||
# ========== Private Helper Methods ==========
|
||||
|
||||
def _check_provider_exists(self, tenant_id: str, name: str, server_url_hash: str, server_identifier: str) -> None:
|
||||
"""Check if provider with same attributes already exists."""
|
||||
stmt = select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
or_(
|
||||
MCPToolProvider.name == name,
|
||||
MCPToolProvider.server_url_hash == server_url_hash,
|
||||
MCPToolProvider.server_identifier == server_identifier,
|
||||
),
|
||||
)
|
||||
existing_provider = self._session.scalar(stmt)
|
||||
|
||||
if existing_provider:
|
||||
if existing_provider.name == name:
|
||||
raise ValueError(f"MCP tool {name} already exists")
|
||||
if existing_provider.server_url_hash == server_url_hash:
|
||||
raise ValueError("MCP tool with this server URL already exists")
|
||||
if existing_provider.server_identifier == server_identifier:
|
||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||
|
||||
def _prepare_icon(self, icon: str, icon_type: str, icon_background: str) -> str:
|
||||
"""Prepare icon data for storage."""
|
||||
if icon_type == "emoji":
|
||||
return json.dumps({"content": icon, "background": icon_background})
|
||||
return icon
|
||||
|
||||
def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> dict[str, str]:
|
||||
"""Encrypt specified fields in a dictionary.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing data to encrypt
|
||||
secret_fields: List of field names to encrypt
|
||||
tenant_id: Tenant ID for encryption
|
||||
|
||||
Returns:
|
||||
JSON string of encrypted data
|
||||
"""
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
|
||||
# Create config for secret fields
|
||||
config = [
|
||||
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=field) for field in secret_fields
|
||||
]
|
||||
|
||||
encrypter_instance, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
encrypted_data = encrypter_instance.encrypt(data)
|
||||
return encrypted_data
|
||||
|
||||
def _prepare_encrypted_dict(self, headers: dict[str, str], tenant_id: str) -> str:
|
||||
"""Encrypt headers and prepare for storage."""
|
||||
# All headers are treated as secret
|
||||
return json.dumps(self._encrypt_dict_fields(headers, list(headers.keys()), tenant_id))
|
||||
|
||||
def _prepare_auth_headers(self, provider_entity: MCPProviderEntity) -> dict[str, str]:
|
||||
"""Prepare headers with OAuth token if available."""
|
||||
headers = provider_entity.decrypt_headers()
|
||||
tokens = provider_entity.retrieve_tokens()
|
||||
if tokens:
|
||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||
return headers
|
||||
|
||||
def _retrieve_remote_mcp_tools(
|
||||
self,
|
||||
server_url: str,
|
||||
headers: dict[str, str],
|
||||
provider_entity: MCPProviderEntity,
|
||||
):
|
||||
"""Retrieve tools from remote MCP server."""
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=provider_entity.timeout,
|
||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
) as mcp_client:
|
||||
return mcp_client.list_tools()
|
||||
|
||||
def execute_auth_actions(self, auth_result: Any) -> dict[str, str]:
|
||||
"""
|
||||
Execute the actions returned by the auth function.
|
||||
|
||||
This method processes the AuthResult and performs the necessary database operations.
|
||||
|
||||
Args:
|
||||
auth_result: The result from the auth function
|
||||
|
||||
Returns:
|
||||
The response from the auth result
|
||||
"""
|
||||
from core.mcp.entities import AuthAction, AuthActionType
|
||||
|
||||
action: AuthAction
|
||||
for action in auth_result.actions:
|
||||
if action.provider_id is None or action.tenant_id is None:
|
||||
continue
|
||||
|
||||
if action.action_type == AuthActionType.SAVE_CLIENT_INFO:
|
||||
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CLIENT_INFO)
|
||||
elif action.action_type == AuthActionType.SAVE_TOKENS:
|
||||
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.TOKENS)
|
||||
elif action.action_type == AuthActionType.SAVE_CODE_VERIFIER:
|
||||
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CODE_VERIFIER)
|
||||
|
||||
return auth_result.response
|
||||
|
||||
def auth_with_actions(
|
||||
self, provider_entity: MCPProviderEntity, authorization_code: str | None = None
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Perform authentication and execute all resulting actions.
|
||||
|
||||
This method is used by MCPClientWithAuthRetry for automatic re-authentication.
|
||||
|
||||
Args:
|
||||
provider_entity: The MCP provider entity
|
||||
authorization_code: Optional authorization code
|
||||
|
||||
Returns:
|
||||
Response dictionary from auth result
|
||||
"""
|
||||
auth_result = auth(provider_entity, authorization_code)
|
||||
return self.execute_auth_actions(auth_result)
|
||||
|
||||
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult:
|
||||
"""Attempt to reconnect to MCP provider with new server URL."""
|
||||
provider_entity = provider.to_entity()
|
||||
headers = provider_entity.headers
|
||||
|
||||
try:
|
||||
with MCPClient(
|
||||
server_url,
|
||||
provider_id,
|
||||
tenant_id,
|
||||
authed=False,
|
||||
for_list=True,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
return {
|
||||
"authed": True,
|
||||
"tools": json.dumps([tool.model_dump() for tool in tools]),
|
||||
"encrypted_credentials": "{}",
|
||||
}
|
||||
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
|
||||
return ReconnectResult(
|
||||
authed=True,
|
||||
tools=json.dumps([tool.model_dump() for tool in tools]),
|
||||
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
||||
)
|
||||
except MCPAuthError:
|
||||
return {"authed": False, "tools": "[]", "encrypted_credentials": "{}"}
|
||||
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
|
||||
except MCPError as e:
|
||||
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
|
||||
|
||||
def validate_server_url_change(
|
||||
self, *, tenant_id: str, provider_id: str, new_server_url: str
|
||||
) -> ServerUrlValidationResult:
|
||||
"""
|
||||
Validate server URL change by attempting to connect to the new server.
|
||||
This method should be called BEFORE update_provider to perform network operations
|
||||
outside of the database transaction.
|
||||
|
||||
Returns:
|
||||
ServerUrlValidationResult: Validation result with connection status and tools if successful
|
||||
"""
|
||||
# Handle hidden/unchanged URL
|
||||
if UNCHANGED_SERVER_URL_PLACEHOLDER in new_server_url:
|
||||
return ServerUrlValidationResult(needs_validation=False)
|
||||
|
||||
# Validate URL format
|
||||
if not self._is_valid_url(new_server_url):
|
||||
raise ValueError("Server URL is not valid.")
|
||||
|
||||
# Always encrypt and hash the URL
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
|
||||
new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()
|
||||
|
||||
# Get current provider
|
||||
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
||||
# Check if URL is actually different
|
||||
if new_server_url_hash == provider.server_url_hash:
|
||||
# URL hasn't changed, but still return the encrypted data
|
||||
return ServerUrlValidationResult(
|
||||
needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash
|
||||
)
|
||||
|
||||
# Perform validation by attempting to connect
|
||||
reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider)
|
||||
return ServerUrlValidationResult(
|
||||
needs_validation=True,
|
||||
validation_passed=True,
|
||||
reconnect_result=reconnect_result,
|
||||
encrypted_server_url=encrypted_server_url,
|
||||
server_url_hash=new_server_url_hash,
|
||||
)
|
||||
|
||||
def _build_tool_provider_response(
|
||||
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
|
||||
) -> ToolProviderApiEntity:
|
||||
"""Build API response for tool provider."""
|
||||
user = db_provider.load_user()
|
||||
response = provider_entity.to_api_response(
|
||||
user_name=user.name if user else None,
|
||||
)
|
||||
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, tools)
|
||||
response["plugin_unique_identifier"] = provider_entity.provider_id
|
||||
return ToolProviderApiEntity(**response)
|
||||
|
||||
def _handle_integrity_error(
|
||||
self, error: IntegrityError, name: str, server_url: str, server_identifier: str
|
||||
) -> None:
|
||||
"""Handle database integrity errors with user-friendly messages."""
|
||||
error_msg = str(error.orig)
|
||||
if "unique_mcp_provider_name" in error_msg:
|
||||
raise ValueError(f"MCP tool {name} already exists")
|
||||
if "unique_mcp_provider_server_url" in error_msg:
|
||||
raise ValueError(f"MCP tool {server_url} already exists")
|
||||
if "unique_mcp_provider_server_identifier" in error_msg:
|
||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||
raise
|
||||
|
||||
def _is_valid_url(self, url: str) -> bool:
|
||||
"""Validate URL format."""
|
||||
if not url:
|
||||
return False
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
def _update_optional_fields(self, mcp_provider: MCPToolProvider, configuration: MCPConfiguration) -> None:
|
||||
"""Update optional configuration fields using setattr for cleaner code."""
|
||||
field_mapping = {"timeout": configuration.timeout, "sse_read_timeout": configuration.sse_read_timeout}
|
||||
|
||||
for field, value in field_mapping.items():
|
||||
if value is not None:
|
||||
setattr(mcp_provider, field, value)
|
||||
|
||||
def _process_headers(self, headers: dict[str, str], mcp_provider: MCPToolProvider, tenant_id: str) -> str | None:
|
||||
"""Process headers update, handling empty dict to clear headers."""
|
||||
if not headers:
|
||||
return None
|
||||
|
||||
# Merge with existing headers to preserve masked values
|
||||
final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider)
|
||||
return self._prepare_encrypted_dict(final_headers, tenant_id)
|
||||
|
||||
def _process_credentials(
|
||||
self, authentication: MCPAuthentication, mcp_provider: MCPToolProvider, tenant_id: str
|
||||
) -> str:
|
||||
"""Process credentials update, handling masked values."""
|
||||
# Merge with existing credentials
|
||||
final_client_id, final_client_secret = self._merge_credentials_with_masked(
|
||||
authentication.client_id, authentication.client_secret, mcp_provider
|
||||
)
|
||||
|
||||
# Build and encrypt
|
||||
return self._build_and_encrypt_credentials(final_client_id, final_client_secret, tenant_id)
|
||||
|
||||
def _merge_headers_with_masked(
|
||||
self, incoming_headers: dict[str, str], mcp_provider: MCPToolProvider
|
||||
) -> dict[str, str]:
|
||||
"""Merge incoming headers with existing ones, preserving unchanged masked values.
|
||||
|
||||
Args:
|
||||
incoming_headers: Headers from frontend (may contain masked values)
|
||||
mcp_provider: The MCP provider instance
|
||||
|
||||
Returns:
|
||||
Final headers dict with proper values (original for unchanged masked, new for changed)
|
||||
"""
|
||||
mcp_provider_entity = mcp_provider.to_entity()
|
||||
existing_decrypted = mcp_provider_entity.decrypt_headers()
|
||||
existing_masked = mcp_provider_entity.masked_headers()
|
||||
|
||||
return {
|
||||
key: (str(existing_decrypted[key]) if key in existing_masked and value == existing_masked[key] else value)
|
||||
for key, value in incoming_headers.items()
|
||||
if key in existing_decrypted or value != existing_masked.get(key)
|
||||
}
|
||||
|
||||
def _merge_credentials_with_masked(
|
||||
self,
|
||||
client_id: str,
|
||||
client_secret: str | None,
|
||||
mcp_provider: MCPToolProvider,
|
||||
) -> tuple[
|
||||
str,
|
||||
str | None,
|
||||
]:
|
||||
"""Merge incoming credentials with existing ones, preserving unchanged masked values.
|
||||
|
||||
Args:
|
||||
client_id: Client ID from frontend (may be masked)
|
||||
client_secret: Client secret from frontend (may be masked)
|
||||
mcp_provider: The MCP provider instance
|
||||
|
||||
Returns:
|
||||
Tuple of (final_client_id, final_client_secret)
|
||||
"""
|
||||
mcp_provider_entity = mcp_provider.to_entity()
|
||||
existing_decrypted = mcp_provider_entity.decrypt_credentials()
|
||||
existing_masked = mcp_provider_entity.masked_credentials()
|
||||
|
||||
# Check if client_id is masked and unchanged
|
||||
final_client_id = client_id
|
||||
if existing_masked.get("client_id") and client_id == existing_masked["client_id"]:
|
||||
# Use existing decrypted value
|
||||
final_client_id = existing_decrypted.get("client_id", client_id)
|
||||
|
||||
# Check if client_secret is masked and unchanged
|
||||
final_client_secret = client_secret
|
||||
if existing_masked.get("client_secret") and client_secret == existing_masked["client_secret"]:
|
||||
# Use existing decrypted value
|
||||
final_client_secret = existing_decrypted.get("client_secret", client_secret)
|
||||
|
||||
return final_client_id, final_client_secret
|
||||
|
||||
def _build_and_encrypt_credentials(self, client_id: str, client_secret: str | None, tenant_id: str) -> str:
|
||||
"""Build credentials and encrypt sensitive fields."""
|
||||
# Create a flat structure with all credential data
|
||||
credentials_data = {
|
||||
"client_id": client_id,
|
||||
"client_name": CLIENT_NAME,
|
||||
"is_dynamic_registration": False,
|
||||
}
|
||||
secret_fields = []
|
||||
if client_secret is not None:
|
||||
credentials_data["encrypted_client_secret"] = client_secret
|
||||
secret_fields = ["encrypted_client_secret"]
|
||||
client_info = self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id)
|
||||
return json.dumps({"client_information": client_info})
|
||||
|
||||
@@ -3,9 +3,11 @@ import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import ValidationError
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.mcp_provider import MCPConfiguration
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.mcp.types import Tool as MCPTool
|
||||
from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
|
||||
@@ -232,40 +234,57 @@ class ToolTransformService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity:
|
||||
user = db_provider.load_user()
|
||||
return ToolProviderApiEntity(
|
||||
id=db_provider.server_identifier if not for_list else db_provider.id,
|
||||
author=user.name if user else "Anonymous",
|
||||
name=db_provider.name,
|
||||
icon=db_provider.provider_icon,
|
||||
type=ToolProviderType.MCP,
|
||||
is_team_authorization=db_provider.authed,
|
||||
server_url=db_provider.masked_server_url,
|
||||
tools=ToolTransformService.mcp_tool_to_user_tool(
|
||||
db_provider, [MCPTool.model_validate(tool) for tool in json.loads(db_provider.tools)]
|
||||
),
|
||||
updated_at=int(db_provider.updated_at.timestamp()),
|
||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
server_identifier=db_provider.server_identifier,
|
||||
timeout=db_provider.timeout,
|
||||
sse_read_timeout=db_provider.sse_read_timeout,
|
||||
masked_headers=db_provider.masked_headers,
|
||||
original_headers=db_provider.decrypted_headers,
|
||||
)
|
||||
def mcp_provider_to_user_provider(
|
||||
db_provider: MCPToolProvider,
|
||||
for_list: bool = False,
|
||||
user_name: str | None = None,
|
||||
include_sensitive: bool = True,
|
||||
) -> ToolProviderApiEntity:
|
||||
# Use provided user_name to avoid N+1 query, fallback to load_user() if not provided
|
||||
if user_name is None:
|
||||
user = db_provider.load_user()
|
||||
user_name = user.name if user else None
|
||||
|
||||
# Convert to entity and use its API response method
|
||||
provider_entity = db_provider.to_entity()
|
||||
|
||||
response = provider_entity.to_api_response(user_name=user_name, include_sensitive=include_sensitive)
|
||||
try:
|
||||
mcp_tools = [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
|
||||
except (ValidationError, json.JSONDecodeError):
|
||||
mcp_tools = []
|
||||
# Add additional fields specific to the transform
|
||||
response["id"] = db_provider.server_identifier if not for_list else db_provider.id
|
||||
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, mcp_tools, user_name=user_name)
|
||||
response["server_identifier"] = db_provider.server_identifier
|
||||
|
||||
# Convert configuration dict to MCPConfiguration object
|
||||
if "configuration" in response and isinstance(response["configuration"], dict):
|
||||
response["configuration"] = MCPConfiguration(
|
||||
timeout=float(response["configuration"]["timeout"]),
|
||||
sse_read_timeout=float(response["configuration"]["sse_read_timeout"]),
|
||||
)
|
||||
|
||||
return ToolProviderApiEntity(**response)
|
||||
|
||||
@staticmethod
|
||||
def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]:
|
||||
user = mcp_provider.load_user()
|
||||
def mcp_tool_to_user_tool(
|
||||
mcp_provider: MCPToolProvider, tools: list[MCPTool], user_name: str | None = None
|
||||
) -> list[ToolApiEntity]:
|
||||
# Use provided user_name to avoid N+1 query, fallback to load_user() if not provided
|
||||
if user_name is None:
|
||||
user = mcp_provider.load_user()
|
||||
user_name = user.name if user else "Anonymous"
|
||||
|
||||
return [
|
||||
ToolApiEntity(
|
||||
author=user.name if user else "Anonymous",
|
||||
author=user_name or "Anonymous",
|
||||
name=tool.name,
|
||||
label=I18nObject(en_US=tool.name, zh_Hans=tool.name),
|
||||
description=I18nObject(en_US=tool.description or "", zh_Hans=tool.description or ""),
|
||||
parameters=ToolTransformService.convert_mcp_schema_to_parameter(tool.inputSchema),
|
||||
labels=[],
|
||||
output_schema=tool.outputSchema or {},
|
||||
)
|
||||
for tool in tools
|
||||
]
|
||||
@@ -412,7 +431,7 @@ class ToolTransformService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
|
||||
def convert_mcp_schema_to_parameter(schema: dict[str, Any]) -> list["ToolParameter"]:
|
||||
"""
|
||||
Convert MCP JSON schema to tool parameters
|
||||
|
||||
@@ -421,7 +440,7 @@ class ToolTransformService:
|
||||
"""
|
||||
|
||||
def create_parameter(
|
||||
name: str, description: str, param_type: str, required: bool, input_schema: dict | None = None
|
||||
name: str, description: str, param_type: str, required: bool, input_schema: dict[str, Any] | None = None
|
||||
) -> ToolParameter:
|
||||
"""Create a ToolParameter instance with given attributes"""
|
||||
input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {}
|
||||
@@ -436,7 +455,9 @@ class ToolTransformService:
|
||||
**input_schema_dict,
|
||||
)
|
||||
|
||||
def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]:
|
||||
def process_properties(
|
||||
props: dict[str, dict[str, Any]], required: list[str], prefix: str = ""
|
||||
) -> list[ToolParameter]:
|
||||
"""Process properties recursively"""
|
||||
TYPE_MAPPING = {"integer": "number", "float": "number"}
|
||||
COMPLEX_TYPES = ["array", "object"]
|
||||
|
||||
Reference in New Issue
Block a user