feat: implement MCP specification 2025-06-18 (#25766)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Novice
2025-10-27 17:07:51 +08:00
committed by GitHub
parent b6e0abadab
commit 0ded6303c1
33 changed files with 4863 additions and 1128 deletions

View File

@@ -6,6 +6,7 @@ from flask_restx import (
Resource,
reqparse,
)
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from configs import dify_config
@@ -15,20 +16,21 @@ from controllers.console.wraps import (
enterprise_license_required,
setup_required,
)
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from extensions.ext_database import db
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import current_account_with_tenant, login_required
from models.provider_ids import ToolProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.mcp_tools_manage_service import MCPToolManageService
from services.tools.mcp_tools_manage_service import MCPToolManageService, OAuthDataType
from services.tools.tool_labels_service import ToolLabelsService
from services.tools.tools_manage_service import ToolCommonService
from services.tools.tools_transform_service import ToolTransformService
@@ -42,7 +44,9 @@ def is_valid_url(url: str) -> bool:
try:
parsed = urlparse(url)
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
except Exception:
except (ValueError, TypeError):
# ValueError: Invalid URL format
# TypeError: url is not a string
return False
@@ -886,29 +890,34 @@ class ToolProviderMCPApi(Resource):
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
.add_argument("sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300)
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
)
args = parser.parse_args()
user, tenant_id = current_account_with_tenant()
if not is_valid_url(args["server_url"]):
raise ValueError("Server URL is not valid.")
return jsonable_encoder(
MCPToolManageService.create_mcp_provider(
# Parse and validate models
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
# Create provider
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
result = service.create_provider(
tenant_id=tenant_id,
user_id=user.id,
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
user_id=user.id,
server_identifier=args["server_identifier"],
timeout=args["timeout"],
sse_read_timeout=args["sse_read_timeout"],
headers=args["headers"],
configuration=configuration,
authentication=authentication,
)
)
return jsonable_encoder(result)
@setup_required
@login_required
@@ -923,31 +932,43 @@ class ToolProviderMCPApi(Resource):
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
.add_argument("timeout", type=float, required=False, nullable=True, location="json")
.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
.add_argument("headers", type=dict, required=False, nullable=True, location="json")
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
)
args = parser.parse_args()
if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]:
pass
else:
raise ValueError("Server URL is not valid.")
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
_, current_tenant_id = current_account_with_tenant()
MCPToolManageService.update_mcp_provider(
tenant_id=current_tenant_id,
provider_id=args["provider_id"],
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
timeout=args.get("timeout"),
sse_read_timeout=args.get("sse_read_timeout"),
headers=args.get("headers"),
)
return {"result": "success"}
# Step 1: Validate server URL change if needed (includes URL format validation and network operation)
validation_result = None
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
validation_result = service.validate_server_url_change(
tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
)
# No need to check for errors here, exceptions will be raised directly
# Step 2: Perform database update in a transaction
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider(
tenant_id=current_tenant_id,
provider_id=args["provider_id"],
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
headers=args["headers"],
configuration=configuration,
authentication=authentication,
validation_result=validation_result,
)
return {"result": "success"}
@setup_required
@login_required
@@ -958,8 +979,11 @@ class ToolProviderMCPApi(Resource):
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
MCPToolManageService.delete_mcp_tool(tenant_id=current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"}
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"}
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
@@ -976,37 +1000,53 @@ class ToolMCPAuthApi(Resource):
args = parser.parse_args()
provider_id = args["provider_id"]
_, tenant_id = current_account_with_tenant()
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if not provider:
raise ValueError("provider not found")
try:
with MCPClient(
provider.decrypted_server_url,
provider_id,
tenant_id,
authed=False,
authorization_code=args["authorization_code"],
for_list=True,
headers=provider.decrypted_headers,
timeout=provider.timeout,
sse_read_timeout=provider.sse_read_timeout,
):
MCPToolManageService.update_mcp_provider_credentials(
mcp_provider=provider,
credentials=provider.decrypted_credentials,
authed=True,
)
return {"result": "success"}
except MCPAuthError:
auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"])
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
if not db_provider:
raise ValueError("provider not found")
# Convert to entity
provider_entity = db_provider.to_entity()
server_url = provider_entity.decrypt_server_url()
headers = provider_entity.decrypt_authentication()
# Try to connect without active transaction
try:
# Use MCPClientWithAuthRetry to handle authentication automatically
with MCPClient(
server_url=server_url,
headers=headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
):
# Update credentials in new transaction
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider_credentials(
provider_id=provider_id,
tenant_id=tenant_id,
credentials=provider_entity.credentials,
authed=True,
)
return {"result": "success"}
except MCPAuthError as e:
try:
auth_result = auth(provider_entity, args.get("authorization_code"))
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
response = service.execute_auth_actions(auth_result)
return response
except MCPRefreshTokenError as e:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except MCPError as e:
MCPToolManageService.update_mcp_provider_credentials(
mcp_provider=provider,
credentials={},
authed=False,
)
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
@@ -1017,8 +1057,10 @@ class ToolMCPDetailApi(Resource):
@account_initialization_required
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@console_ns.route("/workspaces/current/tools/mcp")
@@ -1029,9 +1071,12 @@ class ToolMCPListAllApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
# Skip sensitive data decryption for list view to improve performance
tools = service.list_providers(tenant_id=tenant_id, include_sensitive=False)
return [tool.to_dict() for tool in tools]
return [tool.to_dict() for tool in tools]
@console_ns.route("/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
@@ -1041,11 +1086,13 @@ class ToolMCPUpdateApi(Resource):
@account_initialization_required
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
tenant_id=tenant_id,
provider_id=provider_id,
)
return jsonable_encoder(tools)
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
tools = service.list_provider_tools(
tenant_id=tenant_id,
provider_id=provider_id,
)
return jsonable_encoder(tools)
@console_ns.route("/mcp/oauth/callback")
@@ -1059,5 +1106,15 @@ class ToolMCPCallbackApi(Resource):
args = parser.parse_args()
state_key = args["state"]
authorization_code = args["code"]
handle_callback(state_key, authorization_code)
# Create service instance for handle_callback
with Session(db.engine) as session, session.begin():
mcp_service = MCPToolManageService(session=session)
# handle_callback now returns state data and tokens
state_data, tokens = handle_callback(state_key, authorization_code)
# Save tokens using the service layer
mcp_service.save_oauth_data(
state_data.provider_id, state_data.tenant_id, tokens.model_dump(), OAuthDataType.TOKENS
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")

View File

@@ -0,0 +1,328 @@
import json
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
from pydantic import BaseModel
from configs import dify_config
from core.entities.provider_entities import BasicProviderConfig
from core.file import helpers as file_helpers
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.encryption import create_provider_encrypter
if TYPE_CHECKING:
from models.tools import MCPToolProvider
# Constants
CLIENT_NAME = "Dify"
CLIENT_URI = "https://github.com/langgenius/dify"
DEFAULT_TOKEN_TYPE = "Bearer"
DEFAULT_EXPIRES_IN = 3600
MASK_CHAR = "*"
MIN_UNMASK_LENGTH = 6
class MCPSupportGrantType(StrEnum):
"""The supported grant types for MCP"""
AUTHORIZATION_CODE = "authorization_code"
CLIENT_CREDENTIALS = "client_credentials"
REFRESH_TOKEN = "refresh_token"
class MCPAuthentication(BaseModel):
client_id: str
client_secret: str | None = None
class MCPConfiguration(BaseModel):
timeout: float = 30
sse_read_timeout: float = 300
class MCPProviderEntity(BaseModel):
"""MCP Provider domain entity for business logic operations"""
# Basic identification
id: str
provider_id: str # server_identifier
name: str
tenant_id: str
user_id: str
# Server connection info
server_url: str # encrypted URL
headers: dict[str, str] # encrypted headers
timeout: float
sse_read_timeout: float
# Authentication related
authed: bool
credentials: dict[str, Any] # encrypted credentials
code_verifier: str | None = None # for OAuth
# Tools and display info
tools: list[dict[str, Any]] # parsed tools list
icon: str | dict[str, str] # parsed icon
# Timestamps
created_at: datetime
updated_at: datetime
@classmethod
def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
"""Create entity from database model with decryption"""
return cls(
id=db_provider.id,
provider_id=db_provider.server_identifier,
name=db_provider.name,
tenant_id=db_provider.tenant_id,
user_id=db_provider.user_id,
server_url=db_provider.server_url,
headers=db_provider.headers,
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
authed=db_provider.authed,
credentials=db_provider.credentials,
tools=db_provider.tool_dict,
icon=db_provider.icon or "",
created_at=db_provider.created_at,
updated_at=db_provider.updated_at,
)
@property
def redirect_url(self) -> str:
"""OAuth redirect URL"""
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
@property
def client_metadata(self) -> OAuthClientMetadata:
"""Metadata about this OAuth client."""
# Get grant type from credentials
credentials = self.decrypt_credentials()
# Try to get grant_type from different locations
grant_type = credentials.get("grant_type", MCPSupportGrantType.AUTHORIZATION_CODE)
# For nested structure, check if client_information has grant_types
if "client_information" in credentials and isinstance(credentials["client_information"], dict):
client_info = credentials["client_information"]
# If grant_types is specified in client_information, use it to determine grant_type
if "grant_types" in client_info and isinstance(client_info["grant_types"], list):
if "client_credentials" in client_info["grant_types"]:
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS
elif "authorization_code" in client_info["grant_types"]:
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE
# Configure based on grant type
is_client_credentials = grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS
grant_types = ["refresh_token"]
grant_types.append("client_credentials" if is_client_credentials else "authorization_code")
response_types = [] if is_client_credentials else ["code"]
redirect_uris = [] if is_client_credentials else [self.redirect_url]
return OAuthClientMetadata(
redirect_uris=redirect_uris,
token_endpoint_auth_method="none",
grant_types=grant_types,
response_types=response_types,
client_name=CLIENT_NAME,
client_uri=CLIENT_URI,
)
@property
def provider_icon(self) -> dict[str, str] | str:
"""Get provider icon, handling both dict and string formats"""
if isinstance(self.icon, dict):
return self.icon
try:
return json.loads(self.icon)
except (json.JSONDecodeError, TypeError):
# If not JSON, assume it's a file path
return file_helpers.get_signed_file_url(self.icon)
def to_api_response(self, user_name: str | None = None, include_sensitive: bool = True) -> dict[str, Any]:
"""Convert to API response format
Args:
user_name: User name to display
include_sensitive: If False, skip expensive decryption operations (for list view optimization)
"""
response = {
"id": self.id,
"author": user_name or "Anonymous",
"name": self.name,
"icon": self.provider_icon,
"type": ToolProviderType.MCP.value,
"is_team_authorization": self.authed,
"server_url": self.masked_server_url(),
"server_identifier": self.provider_id,
"updated_at": int(self.updated_at.timestamp()),
"label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(),
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
}
# Add configuration
response["configuration"] = {
"timeout": str(self.timeout),
"sse_read_timeout": str(self.sse_read_timeout),
}
# Skip expensive operations when sensitive data is not needed (e.g., list view)
if not include_sensitive:
response["masked_headers"] = {}
response["is_dynamic_registration"] = True
else:
# Add masked headers
response["masked_headers"] = self.masked_headers()
# Add authentication info if available
masked_creds = self.masked_credentials()
if masked_creds:
response["authentication"] = masked_creds
response["is_dynamic_registration"] = self.credentials.get("client_information", {}).get(
"is_dynamic_registration", True
)
return response
def retrieve_client_information(self) -> OAuthClientInformation | None:
"""OAuth client information if available"""
credentials = self.decrypt_credentials()
if not credentials:
return None
# Check if we have nested client_information structure
if "client_information" not in credentials:
return None
client_info_data = credentials["client_information"]
if isinstance(client_info_data, dict):
if "encrypted_client_secret" in client_info_data:
client_info_data["client_secret"] = encrypter.decrypt_token(
self.tenant_id, client_info_data["encrypted_client_secret"]
)
return OAuthClientInformation.model_validate(client_info_data)
return None
def retrieve_tokens(self) -> OAuthTokens | None:
"""OAuth tokens if available"""
if not self.credentials:
return None
credentials = self.decrypt_credentials()
return OAuthTokens(
access_token=credentials.get("access_token", ""),
token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE),
expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN),
refresh_token=credentials.get("refresh_token", ""),
)
def masked_server_url(self) -> str:
"""Masked server URL for display"""
parsed = urlparse(self.decrypt_server_url())
if parsed.path and parsed.path != "/":
masked = parsed._replace(path="/******")
return masked.geturl()
return parsed.geturl()
def _mask_value(self, value: str) -> str:
"""Mask a sensitive value for display"""
if len(value) > MIN_UNMASK_LENGTH:
return value[:2] + MASK_CHAR * (len(value) - 4) + value[-2:]
else:
return MASK_CHAR * len(value)
def masked_headers(self) -> dict[str, str]:
"""Masked headers for display"""
return {key: self._mask_value(value) for key, value in self.decrypt_headers().items()}
def masked_credentials(self) -> dict[str, str]:
"""Masked credentials for display"""
credentials = self.decrypt_credentials()
if not credentials:
return {}
masked = {}
if "client_information" not in credentials or not isinstance(credentials["client_information"], dict):
return {}
client_info = credentials["client_information"]
# Mask sensitive fields from nested structure
if client_info.get("client_id"):
masked["client_id"] = self._mask_value(client_info["client_id"])
if client_info.get("encrypted_client_secret"):
masked["client_secret"] = self._mask_value(
encrypter.decrypt_token(self.tenant_id, client_info["encrypted_client_secret"])
)
if client_info.get("client_secret"):
masked["client_secret"] = self._mask_value(client_info["client_secret"])
return masked
def decrypt_server_url(self) -> str:
"""Decrypt server URL"""
return encrypter.decrypt_token(self.tenant_id, self.server_url)
def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]:
"""Generic method to decrypt dictionary fields"""
if not data:
return {}
# Only decrypt fields that are actually encrypted
# For nested structures, client_information is not encrypted as a whole
encrypted_fields = []
for key, value in data.items():
# Skip nested objects - they are not encrypted
if isinstance(value, dict):
continue
# Only process string values that might be encrypted
if isinstance(value, str) and value:
encrypted_fields.append(key)
if not encrypted_fields:
return data
# Create dynamic config only for encrypted fields
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in encrypted_fields]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
# Decrypt only the encrypted fields
decrypted_data = encrypter_instance.decrypt({k: data[k] for k in encrypted_fields})
# Merge decrypted data with original data (preserving non-encrypted fields)
result = data.copy()
result.update(decrypted_data)
return result
def decrypt_headers(self) -> dict[str, Any]:
"""Decrypt headers"""
return self._decrypt_dict(self.headers)
def decrypt_credentials(self) -> dict[str, Any]:
"""Decrypt credentials"""
return self._decrypt_dict(self.credentials)
def decrypt_authentication(self) -> dict[str, Any]:
"""Decrypt authentication"""
# Option 1: if headers is provided, use it and don't need to get token
headers = self.decrypt_headers()
# Option 2: Add OAuth token if authed and no headers provided
if not self.headers and self.authed:
token = self.retrieve_tokens()
if token:
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
return headers

View File

@@ -6,11 +6,15 @@ import secrets
import urllib.parse
from urllib.parse import urljoin, urlparse
import httpx
from pydantic import BaseModel, ValidationError
from httpx import ConnectError, HTTPStatusError, RequestError
from pydantic import ValidationError
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
from core.helper import ssrf_proxy
from core.mcp.entities import AuthAction, AuthActionType, AuthResult, OAuthCallbackState
from core.mcp.error import MCPRefreshTokenError
from core.mcp.types import (
LATEST_PROTOCOL_VERSION,
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
@@ -19,21 +23,10 @@ from core.mcp.types import (
)
from extensions.ext_redis import redis_client
LATEST_PROTOCOL_VERSION = "1.0"
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
class OAuthCallbackState(BaseModel):
provider_id: str
tenant_id: str
server_url: str
metadata: OAuthMetadata | None = None
client_information: OAuthClientInformation
code_verifier: str
redirect_uri: str
def generate_pkce_challenge() -> tuple[str, str]:
"""Generate PKCE challenge and verifier."""
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
@@ -80,8 +73,13 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
raise ValueError(f"Invalid state parameter: {str(e)}")
def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
"""Handle the callback from the OAuth provider."""
def handle_callback(state_key: str, authorization_code: str) -> tuple[OAuthCallbackState, OAuthTokens]:
"""
Handle the callback from the OAuth provider.
Returns:
A tuple of (callback_state, tokens) that can be used by the caller to save data.
"""
# Retrieve state data from Redis (state is automatically deleted after retrieval)
full_state_data = _retrieve_redis_state(state_key)
@@ -93,30 +91,32 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta
full_state_data.code_verifier,
full_state_data.redirect_uri,
)
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
provider.save_tokens(tokens)
return full_state_data
return full_state_data, tokens
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
"""Check if the server supports OAuth 2.0 Resource Discovery."""
b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True)
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}"
b_scheme, b_netloc, _, _, b_query, b_fragment = urlparse(server_url, "", True)
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource"
if b_query:
url_for_resource_discovery += f"?{b_query}"
if b_fragment:
url_for_resource_discovery += f"#{b_fragment}"
try:
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
response = httpx.get(url_for_resource_discovery, headers=headers)
response = ssrf_proxy.get(url_for_resource_discovery, headers=headers)
if 200 <= response.status_code < 300:
body = response.json()
if "authorization_server_url" in body:
# Support both singular and plural forms
if body.get("authorization_servers"):
return True, body["authorization_servers"][0]
elif body.get("authorization_server_url"):
return True, body["authorization_server_url"][0]
else:
return False, ""
return False, ""
except httpx.RequestError:
except RequestError:
# Not support resource discovery, fall back to well-known OAuth metadata
return False, ""
@@ -126,27 +126,37 @@ def discover_oauth_metadata(server_url: str, protocol_version: str | None = None
# First check if the server supports OAuth 2.0 Resource Discovery
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
if support_resource_discovery:
url = oauth_discovery_url
# The oauth_discovery_url is the authorization server base URL
# Try OpenID Connect discovery first (more common), then OAuth 2.0
urls_to_try = [
urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"),
urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"),
]
else:
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")]
try:
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
response = httpx.get(url, headers=headers)
if response.status_code == 404:
return None
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
except httpx.RequestError as e:
if isinstance(e, httpx.ConnectError):
response = httpx.get(url)
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
for url in urls_to_try:
try:
response = ssrf_proxy.get(url, headers=headers)
if response.status_code == 404:
return None
continue
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
response.raise_for_status()
return OAuthMetadata.model_validate(response.json())
raise
except (RequestError, HTTPStatusError) as e:
if isinstance(e, ConnectError):
response = ssrf_proxy.get(url)
if response.status_code == 404:
continue # Try next URL
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
# For other errors, try next URL
continue
return None # No metadata found
def start_authorization(
@@ -213,7 +223,7 @@ def exchange_authorization(
redirect_uri: str,
) -> OAuthTokens:
"""Exchanges an authorization code for an access token."""
grant_type = "authorization_code"
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
if metadata:
token_url = metadata.token_endpoint
@@ -233,7 +243,7 @@ def exchange_authorization(
if client_information.client_secret:
params["client_secret"] = client_information.client_secret
response = httpx.post(token_url, data=params)
response = ssrf_proxy.post(token_url, data=params)
if not response.is_success:
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
return OAuthTokens.model_validate(response.json())
@@ -246,7 +256,7 @@ def refresh_authorization(
refresh_token: str,
) -> OAuthTokens:
"""Exchange a refresh token for an updated access token."""
grant_type = "refresh_token"
grant_type = MCPSupportGrantType.REFRESH_TOKEN.value
if metadata:
token_url = metadata.token_endpoint
@@ -263,10 +273,55 @@ def refresh_authorization(
if client_information.client_secret:
params["client_secret"] = client_information.client_secret
response = httpx.post(token_url, data=params)
try:
response = ssrf_proxy.post(token_url, data=params)
except ssrf_proxy.MaxRetriesExceededError as e:
raise MCPRefreshTokenError(e) from e
if not response.is_success:
raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
raise MCPRefreshTokenError(response.text)
return OAuthTokens.model_validate(response.json())
def client_credentials_flow(
server_url: str,
metadata: OAuthMetadata | None,
client_information: OAuthClientInformation,
scope: str | None = None,
) -> OAuthTokens:
"""Execute Client Credentials Flow to get access token."""
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
if metadata:
token_url = metadata.token_endpoint
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
else:
token_url = urljoin(server_url, "/token")
# Support both Basic Auth and body parameters for client authentication
headers = {"Content-Type": "application/x-www-form-urlencoded"}
data = {"grant_type": grant_type}
if scope:
data["scope"] = scope
# If client_secret is provided, use Basic Auth (preferred method)
if client_information.client_secret:
credentials = f"{client_information.client_id}:{client_information.client_secret}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()
headers["Authorization"] = f"Basic {encoded_credentials}"
else:
# Fall back to including credentials in the body
data["client_id"] = client_information.client_id
if client_information.client_secret:
data["client_secret"] = client_information.client_secret
response = ssrf_proxy.post(token_url, headers=headers, data=data)
if not response.is_success:
raise ValueError(
f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
)
return OAuthTokens.model_validate(response.json())
@@ -283,7 +338,7 @@ def register_client(
else:
registration_url = urljoin(server_url, "/register")
response = httpx.post(
response = ssrf_proxy.post(
registration_url,
json=client_metadata.model_dump(),
headers={"Content-Type": "application/json"},
@@ -294,28 +349,111 @@ def register_client(
def auth(
provider: OAuthClientProvider,
server_url: str,
provider: MCPProviderEntity,
authorization_code: str | None = None,
state_param: str | None = None,
for_list: bool = False,
) -> dict[str, str]:
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
metadata = discover_oauth_metadata(server_url)
) -> AuthResult:
"""
Orchestrates the full auth flow with a server using secure Redis state storage.
This function performs only network operations and returns actions that need
to be performed by the caller (such as saving data to database).
Args:
provider: The MCP provider entity
authorization_code: Optional authorization code from OAuth callback
state_param: Optional state parameter from OAuth callback
Returns:
AuthResult containing actions to be performed and response data
"""
actions: list[AuthAction] = []
server_url = provider.decrypt_server_url()
server_metadata = discover_oauth_metadata(server_url)
client_metadata = provider.client_metadata
provider_id = provider.id
tenant_id = provider.tenant_id
client_information = provider.retrieve_client_information()
redirect_url = provider.redirect_url
# Determine grant type based on server metadata
if not server_metadata:
raise ValueError("Failed to discover OAuth metadata from server")
supported_grant_types = server_metadata.grant_types_supported or []
# Convert to lowercase for comparison
supported_grant_types_lower = [gt.lower() for gt in supported_grant_types]
# Determine which grant type to use
effective_grant_type = None
if MCPSupportGrantType.AUTHORIZATION_CODE.value in supported_grant_types_lower:
effective_grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
else:
effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
# Get stored credentials
credentials = provider.decrypt_credentials()
# Handle client registration if needed
client_information = provider.client_information()
if not client_information:
if authorization_code is not None:
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
# For client credentials flow, we don't need to register client dynamically
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
# Client should provide client_id and client_secret directly
raise ValueError("Client credentials flow requires client_id and client_secret to be provided")
try:
full_information = register_client(server_url, metadata, provider.client_metadata)
except httpx.RequestError as e:
full_information = register_client(server_url, server_metadata, client_metadata)
except RequestError as e:
raise ValueError(f"Could not register OAuth client: {e}")
provider.save_client_information(full_information)
# Return action to save client information
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_CLIENT_INFO,
data={"client_information": full_information.model_dump()},
provider_id=provider_id,
tenant_id=tenant_id,
)
)
client_information = full_information
# Exchange authorization code for tokens
# Handle client credentials flow
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
# Direct token request without user interaction
try:
scope = credentials.get("scope")
tokens = client_credentials_flow(
server_url,
server_metadata,
client_information,
scope,
)
# Return action to save tokens and grant type
token_data = tokens.model_dump()
token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_TOKENS,
data=token_data,
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return AuthResult(actions=actions, response={"result": "success"})
except (RequestError, ValueError, KeyError) as e:
# RequestError: HTTP request failed
# ValueError: Invalid response data
# KeyError: Missing required fields in response
raise ValueError(f"Client credentials flow failed: {e}")
# Exchange authorization code for tokens (Authorization Code flow)
if authorization_code is not None:
if not state_param:
raise ValueError("State parameter is required when exchanging authorization code")
@@ -335,35 +473,69 @@ def auth(
tokens = exchange_authorization(
server_url,
metadata,
server_metadata,
client_information,
authorization_code,
code_verifier,
redirect_uri,
)
provider.save_tokens(tokens)
return {"result": "success"}
provider_tokens = provider.tokens()
# Return action to save tokens
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_TOKENS,
data=tokens.model_dump(),
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return AuthResult(actions=actions, response={"result": "success"})
provider_tokens = provider.retrieve_tokens()
# Handle token refresh or new authorization
if provider_tokens and provider_tokens.refresh_token:
try:
new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
provider.save_tokens(new_tokens)
return {"result": "success"}
except Exception as e:
new_tokens = refresh_authorization(
server_url, server_metadata, client_information, provider_tokens.refresh_token
)
# Return action to save new tokens
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_TOKENS,
data=new_tokens.model_dump(),
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return AuthResult(actions=actions, response={"result": "success"})
except (RequestError, ValueError, KeyError) as e:
# RequestError: HTTP request failed
# ValueError: Invalid response data
# KeyError: Missing required fields in response
raise ValueError(f"Could not refresh OAuth tokens: {e}")
# Start new authorization flow
# Start new authorization flow (only for authorization code flow)
authorization_url, code_verifier = start_authorization(
server_url,
metadata,
server_metadata,
client_information,
provider.redirect_url,
provider.mcp_provider.id,
provider.mcp_provider.tenant_id,
redirect_url,
provider_id,
tenant_id,
)
provider.save_code_verifier(code_verifier)
return {"authorization_url": authorization_url}
# Return action to save code verifier
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_CODE_VERIFIER,
data={"code_verifier": code_verifier},
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return AuthResult(actions=actions, response={"authorization_url": authorization_url})

View File

@@ -1,77 +0,0 @@
from configs import dify_config
from core.mcp.types import (
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthTokens,
)
from models.tools import MCPToolProvider
from services.tools.mcp_tools_manage_service import MCPToolManageService
class OAuthClientProvider:
mcp_provider: MCPToolProvider
def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
if for_list:
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
else:
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
@property
def redirect_url(self) -> str:
"""The URL to redirect the user agent to after authorization."""
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
@property
def client_metadata(self) -> OAuthClientMetadata:
"""Metadata about this OAuth client."""
return OAuthClientMetadata(
redirect_uris=[self.redirect_url],
token_endpoint_auth_method="none",
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
client_name="Dify",
client_uri="https://github.com/langgenius/dify",
)
def client_information(self) -> OAuthClientInformation | None:
"""Loads information about this OAuth client."""
client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
if not client_information:
return None
return OAuthClientInformation.model_validate(client_information)
def save_client_information(self, client_information: OAuthClientInformationFull):
"""Saves client information after dynamic registration."""
MCPToolManageService.update_mcp_provider_credentials(
self.mcp_provider,
{"client_information": client_information.model_dump()},
)
def tokens(self) -> OAuthTokens | None:
"""Loads any existing OAuth tokens for the current session."""
credentials = self.mcp_provider.decrypted_credentials
if not credentials:
return None
return OAuthTokens(
access_token=credentials.get("access_token", ""),
token_type=credentials.get("token_type", "Bearer"),
expires_in=int(credentials.get("expires_in", "3600") or 3600),
refresh_token=credentials.get("refresh_token", ""),
)
def save_tokens(self, tokens: OAuthTokens):
"""Stores new OAuth tokens for the current session."""
# update mcp provider credentials
token_dict = tokens.model_dump()
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
def save_code_verifier(self, code_verifier: str):
"""Saves a PKCE code verifier for the current session."""
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
def code_verifier(self) -> str:
"""Loads the PKCE code verifier for the current session."""
# get code verifier from mcp provider credentials
return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))

191
api/core/mcp/auth_client.py Normal file
View File

@@ -0,0 +1,191 @@
"""
MCP Client with Authentication Retry Support
This module provides an enhanced MCPClient that automatically handles
authentication failures and retries operations after refreshing tokens.
"""
import logging
from collections.abc import Callable
from typing import Any
from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.error import MCPAuthError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import CallToolResult, Tool
from extensions.ext_database import db
logger = logging.getLogger(__name__)
class MCPClientWithAuthRetry(MCPClient):
"""
An enhanced MCPClient that provides automatic authentication retry.
This class extends MCPClient and intercepts MCPAuthError exceptions
to refresh authentication before retrying failed operations.
Note: This class uses lazy session creation - database sessions are only
created when authentication retry is actually needed, not on every request.
"""
def __init__(
self,
server_url: str,
headers: dict[str, str] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
provider_entity: MCPProviderEntity | None = None,
authorization_code: str | None = None,
by_server_id: bool = False,
):
"""
Initialize the MCP client with auth retry capability.
Args:
server_url: The MCP server URL
headers: Optional headers for requests
timeout: Request timeout
sse_read_timeout: SSE read timeout
provider_entity: Provider entity for authentication
authorization_code: Optional authorization code for initial auth
by_server_id: Whether to look up provider by server ID
"""
super().__init__(server_url, headers, timeout, sse_read_timeout)
self.provider_entity = provider_entity
self.authorization_code = authorization_code
self.by_server_id = by_server_id
self._has_retried = False
def _handle_auth_error(self, error: MCPAuthError) -> None:
"""
Handle authentication error by refreshing tokens.
This method creates a short-lived database session only when authentication
retry is needed, minimizing database connection hold time.
Args:
error: The authentication error
Raises:
MCPAuthError: If authentication fails or max retries reached
"""
if not self.provider_entity:
raise error
if self._has_retried:
raise error
self._has_retried = True
try:
# Create a temporary session only for auth retry
# This session is short-lived and only exists during the auth operation
from services.tools.mcp_tools_manage_service import MCPToolManageService
with Session(db.engine) as session, session.begin():
mcp_service = MCPToolManageService(session=session)
# Perform authentication using the service's auth method
mcp_service.auth_with_actions(self.provider_entity, self.authorization_code)
# Retrieve new tokens
self.provider_entity = mcp_service.get_provider_entity(
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
)
# Session is closed here, before we update headers
token = self.provider_entity.retrieve_tokens()
if not token:
raise MCPAuthError("Authentication failed - no token received")
# Update headers with new token
self.headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
# Clear authorization code after first use
self.authorization_code = None
except MCPAuthError:
# Re-raise MCPAuthError as is
raise
except Exception as e:
# Catch all exceptions during auth retry
logger.exception("Authentication retry failed")
raise MCPAuthError(f"Authentication retry failed: {e}") from e
def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
"""
Execute a function with authentication retry logic.
Args:
func: The function to execute
*args: Positional arguments for the function
**kwargs: Keyword arguments for the function
Returns:
The result of the function call
Raises:
MCPAuthError: If authentication fails after retries
Any other exceptions from the function
"""
try:
return func(*args, **kwargs)
except MCPAuthError as e:
self._handle_auth_error(e)
# Re-initialize the connection with new headers
if self._initialized:
# Clean up existing connection
self._exit_stack.close()
self._session = None
self._initialized = False
# Re-initialize with new headers
self._initialize()
self._initialized = True
return func(*args, **kwargs)
finally:
# Reset retry flag after operation completes
self._has_retried = False
def __enter__(self):
"""Enter the context manager with retry support."""
def initialize_with_retry():
super(MCPClientWithAuthRetry, self).__enter__()
return self
return self._execute_with_retry(initialize_with_retry)
def list_tools(self) -> list[Tool]:
"""
List available tools from the MCP server with auth retry.
Returns:
List of available tools
Raises:
MCPAuthError: If authentication fails after retries
"""
return self._execute_with_retry(super().list_tools)
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
"""
Invoke a tool on the MCP server with auth retry.
Args:
tool_name: Name of the tool to invoke
tool_args: Arguments for the tool
Returns:
Result of the tool invocation
Raises:
MCPAuthError: If authentication fails after retries
"""
return self._execute_with_retry(super().invoke_tool, tool_name, tool_args)

View File

View File

@@ -46,7 +46,7 @@ class SSETransport:
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 5 * 60,
sse_read_timeout: float = 1 * 60,
):
"""Initialize the SSE transport.
@@ -255,7 +255,7 @@ def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 5 * 60,
sse_read_timeout: float = 1 * 60,
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
"""
Client transport for SSE.
@@ -276,31 +276,34 @@ def sse_client(
read_queue: ReadQueue | None = None
write_queue: WriteQueue | None = None
with ThreadPoolExecutor() as executor:
try:
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
with ssrf_proxy_sse_connect(
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source:
event_source.response.raise_for_status()
executor = ThreadPoolExecutor()
try:
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
with ssrf_proxy_sse_connect(
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source:
event_source.response.raise_for_status()
read_queue, write_queue = transport.connect(executor, client, event_source)
read_queue, write_queue = transport.connect(executor, client, event_source)
yield read_queue, write_queue
yield read_queue, write_queue
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401:
raise MCPAuthError()
raise MCPConnectionError()
except Exception:
logger.exception("Error connecting to SSE endpoint")
raise
finally:
# Clean up queues
if read_queue:
read_queue.put(None)
if write_queue:
write_queue.put(None)
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401:
raise MCPAuthError()
raise MCPConnectionError()
except Exception:
logger.exception("Error connecting to SSE endpoint")
raise
finally:
# Clean up queues
if read_queue:
read_queue.put(None)
if write_queue:
write_queue.put(None)
# Shutdown executor without waiting to prevent hanging
executor.shutdown(wait=False)
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage):

View File

@@ -434,45 +434,48 @@ def streamablehttp_client(
server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client
client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server
with ThreadPoolExecutor(max_workers=2) as executor:
try:
with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client:
# Define callbacks that need access to thread pool
def start_get_stream():
"""Start a worker thread to handle server-initiated messages."""
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
executor = ThreadPoolExecutor(max_workers=2)
try:
with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client:
# Define callbacks that need access to thread pool
def start_get_stream():
"""Start a worker thread to handle server-initiated messages."""
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
# Start the post_writer worker thread
executor.submit(
transport.post_writer,
client,
client_to_server_queue, # Queue for messages FROM client TO server
server_to_client_queue, # Queue for messages FROM server TO client
start_get_stream,
)
# Start the post_writer worker thread
executor.submit(
transport.post_writer,
client,
client_to_server_queue, # Queue for messages FROM client TO server
server_to_client_queue, # Queue for messages FROM server TO client
start_get_stream,
)
try:
yield (
server_to_client_queue, # Queue for receiving messages FROM server
client_to_server_queue, # Queue for sending messages TO server
transport.get_session_id,
)
finally:
if transport.session_id and terminate_on_close:
transport.terminate_session(client)
# Signal threads to stop
client_to_server_queue.put(None)
finally:
# Clear any remaining items and add None sentinel to unblock any waiting threads
try:
while not client_to_server_queue.empty():
client_to_server_queue.get_nowait()
except queue.Empty:
pass
yield (
server_to_client_queue, # Queue for receiving messages FROM server
client_to_server_queue, # Queue for sending messages TO server
transport.get_session_id,
)
finally:
if transport.session_id and terminate_on_close:
transport.terminate_session(client)
client_to_server_queue.put(None)
server_to_client_queue.put(None)
# Signal threads to stop
client_to_server_queue.put(None)
finally:
# Clear any remaining items and add None sentinel to unblock any waiting threads
try:
while not client_to_server_queue.empty():
client_to_server_queue.get_nowait()
except queue.Empty:
pass
client_to_server_queue.put(None)
server_to_client_queue.put(None)
# Shutdown executor without waiting to prevent hanging
executor.shutdown(wait=False)

View File

@@ -1,10 +1,13 @@
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, Generic, TypeVar
from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
from pydantic import BaseModel
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", LATEST_PROTOCOL_VERSION]
from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAuthMetadata, RequestId, RequestParams
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
@@ -17,3 +20,41 @@ class RequestContext(Generic[SessionT, LifespanContextT]):
meta: RequestParams.Meta | None
session: SessionT
lifespan_context: LifespanContextT
class AuthActionType(StrEnum):
"""Types of actions that can be performed during auth flow."""
SAVE_CLIENT_INFO = "save_client_info"
SAVE_TOKENS = "save_tokens"
SAVE_CODE_VERIFIER = "save_code_verifier"
START_AUTHORIZATION = "start_authorization"
SUCCESS = "success"
class AuthAction(BaseModel):
"""Represents an action that needs to be performed as a result of auth flow."""
action_type: AuthActionType
data: dict[str, Any]
provider_id: str | None = None
tenant_id: str | None = None
class AuthResult(BaseModel):
"""Result of auth function containing actions to be performed and response data."""
actions: list[AuthAction]
response: dict[str, str]
class OAuthCallbackState(BaseModel):
"""State data stored in Redis during OAuth callback flow."""
provider_id: str
tenant_id: str
server_url: str
metadata: OAuthMetadata | None = None
client_information: OAuthClientInformation
code_verifier: str
redirect_uri: str

View File

@@ -8,3 +8,7 @@ class MCPConnectionError(MCPError):
class MCPAuthError(MCPConnectionError):
pass
class MCPRefreshTokenError(MCPError):
pass

View File

@@ -7,9 +7,9 @@ from urllib.parse import urlparse
from core.mcp.client.sse_client import sse_client
from core.mcp.client.streamable_client import streamablehttp_client
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.error import MCPConnectionError
from core.mcp.session.client_session import ClientSession
from core.mcp.types import Tool
from core.mcp.types import CallToolResult, Tool
logger = logging.getLogger(__name__)
@@ -18,40 +18,18 @@ class MCPClient:
def __init__(
self,
server_url: str,
provider_id: str,
tenant_id: str,
authed: bool = True,
authorization_code: str | None = None,
for_list: bool = False,
headers: dict[str, str] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
):
# Initialize info
self.provider_id = provider_id
self.tenant_id = tenant_id
self.client_type = "streamable"
self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
# Authentication info
self.authed = authed
self.authorization_code = authorization_code
if authed:
from core.mcp.auth.auth_provider import OAuthClientProvider
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
self.token = self.provider.tokens()
# Initialize session and client objects
self._session: ClientSession | None = None
self._streams_context: AbstractContextManager[Any] | None = None
self._session_context: ClientSession | None = None
self._exit_stack = ExitStack()
# Whether the client has been initialized
self._initialized = False
def __enter__(self):
@@ -85,61 +63,42 @@ class MCPClient:
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
self.connect_server(streamablehttp_client, "mcp")
def connect_server(
self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
):
from core.mcp.auth.auth_flow import auth
def connect_server(self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str) -> None:
"""
Connect to the MCP server using streamable http or sse.
Default to streamable http.
Args:
client_factory: The client factory to use(streamablehttp_client or sse_client).
method_name: The method name to use(mcp or sse).
"""
streams_context = client_factory(
url=self.server_url,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
try:
headers = (
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
if self.authed and self.token
else self.headers
)
self._streams_context = client_factory(
url=self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
if not self._streams_context:
raise MCPConnectionError("Failed to create connection context")
# Use exit_stack to manage context managers properly
if method_name == "mcp":
read_stream, write_stream, _ = self._exit_stack.enter_context(streams_context)
streams = (read_stream, write_stream)
else: # sse_client
streams = self._exit_stack.enter_context(streams_context)
# Use exit_stack to manage context managers properly
if method_name == "mcp":
read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
streams = (read_stream, write_stream)
else: # sse_client
streams = self._exit_stack.enter_context(self._streams_context)
self._session_context = ClientSession(*streams)
self._session = self._exit_stack.enter_context(self._session_context)
self._session.initialize()
return
except MCPAuthError:
if not self.authed:
raise
try:
auth(self.provider, self.server_url, self.authorization_code)
except Exception as e:
raise ValueError(f"Failed to authenticate: {e}")
self.token = self.provider.tokens()
if first_try:
return self.connect_server(client_factory, method_name, first_try=False)
session_context = ClientSession(*streams)
self._session = self._exit_stack.enter_context(session_context)
self._session.initialize()
def list_tools(self) -> list[Tool]:
"""Connect to an MCP server running with SSE transport"""
# List available tools to verify connection
if not self._initialized or not self._session:
"""List available tools from the MCP server"""
if not self._session:
raise ValueError("Session not initialized.")
response = self._session.list_tools()
tools = response.tools
return tools
return response.tools
def invoke_tool(self, tool_name: str, tool_args: dict):
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
"""Call a tool"""
if not self._initialized or not self._session:
if not self._session:
raise ValueError("Session not initialized.")
return self._session.call_tool(tool_name, tool_args)
@@ -153,6 +112,4 @@ class MCPClient:
raise ValueError(f"Error during cleanup: {e}")
finally:
self._session = None
self._session_context = None
self._streams_context = None
self._initialized = False

View File

@@ -201,11 +201,14 @@ class BaseSession(
self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds
except TimeoutError:
# If the receiver loop is still running after timeout, we'll force shutdown
pass
# Cancel the future to interrupt the receiver loop
self._receiver_future.cancel()
# Shutdown the executor
if self._executor:
self._executor.shutdown(wait=True)
# Use non-blocking shutdown to prevent hanging
# The receiver thread should have already exited due to the None message in the queue
self._executor.shutdown(wait=False)
def send_request(
self,

View File

@@ -284,7 +284,7 @@ class ClientSession(
def complete(
self,
ref: types.ResourceReference | types.PromptReference,
ref: types.ResourceTemplateReference | types.PromptReference,
argument: dict[str, str],
) -> types.CompleteResult:
"""Send a completion/complete request."""

View File

@@ -1,13 +1,6 @@
from collections.abc import Callable
from dataclasses import dataclass
from typing import (
Annotated,
Any,
Generic,
Literal,
TypeAlias,
TypeVar,
)
from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
from pydantic.networks import AnyUrl, UrlConstraints
@@ -33,6 +26,7 @@ for reference.
LATEST_PROTOCOL_VERSION = "2025-03-26"
# Server support 2024-11-05 to allow claude to use.
SERVER_LATEST_PROTOCOL_VERSION = "2024-11-05"
DEFAULT_NEGOTIATED_VERSION = "2025-03-26"
ProgressToken = str | int
Cursor = str
Role = Literal["user", "assistant"]
@@ -55,14 +49,22 @@ class RequestParams(BaseModel):
meta: Meta | None = Field(alias="_meta", default=None)
class PaginatedRequestParams(RequestParams):
cursor: Cursor | None = None
"""
An opaque token representing the current pagination position.
If provided, the server should return results starting after this cursor.
"""
class NotificationParams(BaseModel):
class Meta(BaseModel):
model_config = ConfigDict(extra="allow")
meta: Meta | None = Field(alias="_meta", default=None)
"""
This parameter name is reserved by MCP to allow clients and servers to attach
additional metadata to their notifications.
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
@@ -79,12 +81,11 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
model_config = ConfigDict(extra="allow")
class PaginatedRequest(Request[RequestParamsT, MethodT]):
cursor: Cursor | None = None
"""
An opaque token representing the current pagination position.
If provided, the server should return results starting after this cursor.
"""
class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]):
"""Base class for paginated requests,
matching the schema's PaginatedRequest interface."""
params: PaginatedRequestParams | None = None
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
@@ -98,13 +99,12 @@ class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
class Result(BaseModel):
"""Base class for JSON-RPC results."""
model_config = ConfigDict(extra="allow")
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
This result property is reserved by the protocol to allow clients and servers to
attach additional metadata to their responses.
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
class PaginatedResult(Result):
@@ -186,10 +186,26 @@ class EmptyResult(Result):
"""A response that indicates success but carries no data."""
class Implementation(BaseModel):
"""Describes the name and version of an MCP implementation."""
class BaseMetadata(BaseModel):
"""Base class for entities with name and optional title fields."""
name: str
"""The programmatic name of the entity."""
title: str | None = None
"""
Intended for UI and end-user contexts — optimized to be human-readable and easily understood,
even by those unfamiliar with domain-specific terminology.
If not provided, the name should be used for display (except for Tool,
where `annotations.title` should be given precedence over using `name`,
if present).
"""
class Implementation(BaseMetadata):
"""Describes the name and version of an MCP implementation."""
version: str
model_config = ConfigDict(extra="allow")
@@ -203,7 +219,7 @@ class RootsCapability(BaseModel):
class SamplingCapability(BaseModel):
"""Capability for logging operations."""
"""Capability for sampling operations."""
model_config = ConfigDict(extra="allow")
@@ -252,6 +268,12 @@ class LoggingCapability(BaseModel):
model_config = ConfigDict(extra="allow")
class CompletionsCapability(BaseModel):
"""Capability for completions operations."""
model_config = ConfigDict(extra="allow")
class ServerCapabilities(BaseModel):
"""Capabilities that a server may support."""
@@ -265,6 +287,8 @@ class ServerCapabilities(BaseModel):
"""Present if the server offers any resources to read."""
tools: ToolsCapability | None = None
"""Present if the server offers any tools to call."""
completions: CompletionsCapability | None = None
"""Present if the server offers autocompletion suggestions for prompts and resources."""
model_config = ConfigDict(extra="allow")
@@ -284,7 +308,7 @@ class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]])
to begin initialization.
"""
method: Literal["initialize"]
method: Literal["initialize"] = "initialize"
params: InitializeRequestParams
@@ -305,7 +329,7 @@ class InitializedNotification(Notification[NotificationParams | None, Literal["n
finished.
"""
method: Literal["notifications/initialized"]
method: Literal["notifications/initialized"] = "notifications/initialized"
params: NotificationParams | None = None
@@ -315,7 +339,7 @@ class PingRequest(Request[RequestParams | None, Literal["ping"]]):
still alive.
"""
method: Literal["ping"]
method: Literal["ping"] = "ping"
params: RequestParams | None = None
@@ -334,6 +358,11 @@ class ProgressNotificationParams(NotificationParams):
"""
total: float | None = None
"""Total number of items to process (or total progress required), if known."""
message: str | None = None
"""
Message related to progress. This should provide relevant human readable
progress information.
"""
model_config = ConfigDict(extra="allow")
@@ -343,15 +372,14 @@ class ProgressNotification(Notification[ProgressNotificationParams, Literal["not
long-running request.
"""
method: Literal["notifications/progress"]
method: Literal["notifications/progress"] = "notifications/progress"
params: ProgressNotificationParams
class ListResourcesRequest(PaginatedRequest[RequestParams | None, Literal["resources/list"]]):
class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]):
"""Sent from the client to request a list of resources the server has."""
method: Literal["resources/list"]
params: RequestParams | None = None
method: Literal["resources/list"] = "resources/list"
class Annotations(BaseModel):
@@ -360,13 +388,11 @@ class Annotations(BaseModel):
model_config = ConfigDict(extra="allow")
class Resource(BaseModel):
class Resource(BaseMetadata):
"""A known resource that the server is capable of reading."""
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
"""The URI of this resource."""
name: str
"""A human-readable name for this resource."""
description: str | None = None
"""A description of what this resource represents."""
mimeType: str | None = None
@@ -379,10 +405,15 @@ class Resource(BaseModel):
This can be used by Hosts to display file sizes and estimate context window usage.
"""
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
class ResourceTemplate(BaseModel):
class ResourceTemplate(BaseMetadata):
"""A template description for resources available on the server."""
uriTemplate: str
@@ -390,8 +421,6 @@ class ResourceTemplate(BaseModel):
A URI template (according to RFC 6570) that can be used to construct resource
URIs.
"""
name: str
"""A human-readable name for the type of resource this template refers to."""
description: str | None = None
"""A human-readable description of what this template is for."""
mimeType: str | None = None
@@ -400,6 +429,11 @@ class ResourceTemplate(BaseModel):
included if all resources matching this template have the same type.
"""
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@@ -409,11 +443,10 @@ class ListResourcesResult(PaginatedResult):
resources: list[Resource]
class ListResourceTemplatesRequest(PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]):
class ListResourceTemplatesRequest(PaginatedRequest[Literal["resources/templates/list"]]):
"""Sent from the client to request a list of resource templates the server has."""
method: Literal["resources/templates/list"]
params: RequestParams | None = None
method: Literal["resources/templates/list"] = "resources/templates/list"
class ListResourceTemplatesResult(PaginatedResult):
@@ -436,7 +469,7 @@ class ReadResourceRequestParams(RequestParams):
class ReadResourceRequest(Request[ReadResourceRequestParams, Literal["resources/read"]]):
"""Sent from the client to the server, to read a specific resource URI."""
method: Literal["resources/read"]
method: Literal["resources/read"] = "resources/read"
params: ReadResourceRequestParams
@@ -447,6 +480,11 @@ class ResourceContents(BaseModel):
"""The URI of this resource."""
mimeType: str | None = None
"""The MIME type of this resource, if known."""
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@@ -481,7 +519,7 @@ class ResourceListChangedNotification(
of resources it can read from has changed.
"""
method: Literal["notifications/resources/list_changed"]
method: Literal["notifications/resources/list_changed"] = "notifications/resources/list_changed"
params: NotificationParams | None = None
@@ -502,7 +540,7 @@ class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscr
whenever a particular resource changes.
"""
method: Literal["resources/subscribe"]
method: Literal["resources/subscribe"] = "resources/subscribe"
params: SubscribeRequestParams
@@ -520,7 +558,7 @@ class UnsubscribeRequest(Request[UnsubscribeRequestParams, Literal["resources/un
the server.
"""
method: Literal["resources/unsubscribe"]
method: Literal["resources/unsubscribe"] = "resources/unsubscribe"
params: UnsubscribeRequestParams
@@ -543,15 +581,14 @@ class ResourceUpdatedNotification(
changed and may need to be read again.
"""
method: Literal["notifications/resources/updated"]
method: Literal["notifications/resources/updated"] = "notifications/resources/updated"
params: ResourceUpdatedNotificationParams
class ListPromptsRequest(PaginatedRequest[RequestParams | None, Literal["prompts/list"]]):
class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]):
"""Sent from the client to request a list of prompts and prompt templates."""
method: Literal["prompts/list"]
params: RequestParams | None = None
method: Literal["prompts/list"] = "prompts/list"
class PromptArgument(BaseModel):
@@ -566,15 +603,18 @@ class PromptArgument(BaseModel):
model_config = ConfigDict(extra="allow")
class Prompt(BaseModel):
class Prompt(BaseMetadata):
"""A prompt or prompt template that the server offers."""
name: str
"""The name of the prompt or prompt template."""
description: str | None = None
"""An optional description of what this prompt provides."""
arguments: list[PromptArgument] | None = None
"""A list of arguments to use for templating the prompt."""
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@@ -597,7 +637,7 @@ class GetPromptRequestParams(RequestParams):
class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]):
"""Used by the client to get a prompt provided by the server."""
method: Literal["prompts/get"]
method: Literal["prompts/get"] = "prompts/get"
params: GetPromptRequestParams
@@ -608,6 +648,11 @@ class TextContent(BaseModel):
text: str
"""The text content of the message."""
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@@ -623,6 +668,31 @@ class ImageContent(BaseModel):
image types.
"""
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
class AudioContent(BaseModel):
"""Audio content for a message."""
type: Literal["audio"]
data: str
"""The base64-encoded audio data."""
mimeType: str
"""
The MIME type of the audio. Different providers may support different
audio types.
"""
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@@ -630,7 +700,7 @@ class SamplingMessage(BaseModel):
"""Describes a message issued to or received from an LLM API."""
role: Role
content: TextContent | ImageContent
content: TextContent | ImageContent | AudioContent
model_config = ConfigDict(extra="allow")
@@ -645,14 +715,36 @@ class EmbeddedResource(BaseModel):
type: Literal["resource"]
resource: TextResourceContents | BlobResourceContents
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
class ResourceLink(Resource):
"""
A resource that the server is capable of reading, included in a prompt or tool call result.
Note: resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests.
"""
type: Literal["resource_link"]
ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource
"""A content block that can be used in prompts and tool results."""
Content: TypeAlias = ContentBlock
# """DEPRECATED: Content is deprecated, you should use ContentBlock directly."""
class PromptMessage(BaseModel):
"""Describes a message returned as part of a prompt."""
role: Role
content: TextContent | ImageContent | EmbeddedResource
content: ContentBlock
model_config = ConfigDict(extra="allow")
@@ -672,15 +764,14 @@ class PromptListChangedNotification(
of prompts it offers has changed.
"""
method: Literal["notifications/prompts/list_changed"]
method: Literal["notifications/prompts/list_changed"] = "notifications/prompts/list_changed"
params: NotificationParams | None = None
class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]):
class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]):
"""Sent from the client to request a list of tools the server has."""
method: Literal["tools/list"]
params: RequestParams | None = None
method: Literal["tools/list"] = "tools/list"
class ToolAnnotations(BaseModel):
@@ -731,17 +822,25 @@ class ToolAnnotations(BaseModel):
model_config = ConfigDict(extra="allow")
class Tool(BaseModel):
class Tool(BaseMetadata):
"""Definition for a tool the client can call."""
name: str
"""The name of the tool."""
description: str | None = None
"""A human-readable description of the tool."""
inputSchema: dict[str, Any]
"""A JSON Schema object defining the expected parameters for the tool."""
outputSchema: dict[str, Any] | None = None
"""
An optional JSON Schema object defining the structure of the tool's output
returned in the structuredContent field of a CallToolResult.
"""
annotations: ToolAnnotations | None = None
"""Optional additional tool information."""
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@@ -762,14 +861,16 @@ class CallToolRequestParams(RequestParams):
class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
"""Used by the client to invoke a tool provided by the server."""
method: Literal["tools/call"]
method: Literal["tools/call"] = "tools/call"
params: CallToolRequestParams
class CallToolResult(Result):
"""The server's response to a tool call."""
content: list[TextContent | ImageContent | EmbeddedResource]
content: list[ContentBlock]
structuredContent: dict[str, Any] | None = None
"""An optional JSON object that represents the structured result of the tool call."""
isError: bool = False
@@ -779,7 +880,7 @@ class ToolListChangedNotification(Notification[NotificationParams | None, Litera
of tools it offers has changed.
"""
method: Literal["notifications/tools/list_changed"]
method: Literal["notifications/tools/list_changed"] = "notifications/tools/list_changed"
params: NotificationParams | None = None
@@ -797,7 +898,7 @@ class SetLevelRequestParams(RequestParams):
class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]):
"""A request from the client to the server, to enable or adjust logging."""
method: Literal["logging/setLevel"]
method: Literal["logging/setLevel"] = "logging/setLevel"
params: SetLevelRequestParams
@@ -808,7 +909,7 @@ class LoggingMessageNotificationParams(NotificationParams):
"""The severity of this log message."""
logger: str | None = None
"""An optional name of the logger issuing this message."""
data: Any = None
data: Any
"""
The data to be logged, such as a string message or an object. Any JSON serializable
type is allowed here.
@@ -819,7 +920,7 @@ class LoggingMessageNotificationParams(NotificationParams):
class LoggingMessageNotification(Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]):
"""Notification of a log message passed from server to client."""
method: Literal["notifications/message"]
method: Literal["notifications/message"] = "notifications/message"
params: LoggingMessageNotificationParams
@@ -914,7 +1015,7 @@ class CreateMessageRequestParams(RequestParams):
class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]):
"""A request from the server to sample an LLM via the client."""
method: Literal["sampling/createMessage"]
method: Literal["sampling/createMessage"] = "sampling/createMessage"
params: CreateMessageRequestParams
@@ -925,14 +1026,14 @@ class CreateMessageResult(Result):
"""The client's response to a sampling/create_message request from the server."""
role: Role
content: TextContent | ImageContent
content: TextContent | ImageContent | AudioContent
model: str
"""The name of the model that generated the message."""
stopReason: StopReason | None = None
"""The reason why sampling stopped, if known."""
class ResourceReference(BaseModel):
class ResourceTemplateReference(BaseModel):
"""A reference to a resource or resource template definition."""
type: Literal["ref/resource"]
@@ -960,18 +1061,28 @@ class CompletionArgument(BaseModel):
model_config = ConfigDict(extra="allow")
class CompletionContext(BaseModel):
"""Additional, optional context for completions."""
arguments: dict[str, str] | None = None
"""Previously-resolved variables in a URI template or prompt."""
model_config = ConfigDict(extra="allow")
class CompleteRequestParams(RequestParams):
"""Parameters for completion requests."""
ref: ResourceReference | PromptReference
ref: ResourceTemplateReference | PromptReference
argument: CompletionArgument
context: CompletionContext | None = None
"""Additional, optional context for completions"""
model_config = ConfigDict(extra="allow")
class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]):
"""A request from the client to the server, to ask for completion options."""
method: Literal["completion/complete"]
method: Literal["completion/complete"] = "completion/complete"
params: CompleteRequestParams
@@ -1010,7 +1121,7 @@ class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]):
structure or access specific locations that the client has permission to read from.
"""
method: Literal["roots/list"]
method: Literal["roots/list"] = "roots/list"
params: RequestParams | None = None
@@ -1029,6 +1140,11 @@ class Root(BaseModel):
identifier for the root, which may be useful for display purposes or for
referencing the root in other parts of the application.
"""
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@@ -1054,7 +1170,7 @@ class RootsListChangedNotification(
using the ListRootsRequest.
"""
method: Literal["notifications/roots/list_changed"]
method: Literal["notifications/roots/list_changed"] = "notifications/roots/list_changed"
params: NotificationParams | None = None
@@ -1074,7 +1190,7 @@ class CancelledNotification(Notification[CancelledNotificationParams, Literal["n
previously-issued request.
"""
method: Literal["notifications/cancelled"]
method: Literal["notifications/cancelled"] = "notifications/cancelled"
params: CancelledNotificationParams

View File

@@ -217,3 +217,16 @@ class Tool(ABC):
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
)
def create_variable_message(
self, variable_name: str, variable_value: Any, stream: bool = False
) -> ToolInvokeMessage:
"""
create a variable message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.VARIABLE,
message=ToolInvokeMessage.VariableMessage(
variable_name=variable_name, variable_value=variable_value, stream=stream
),
)

View File

@@ -4,6 +4,7 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
@@ -44,10 +45,14 @@ class ToolProviderApiEntity(BaseModel):
server_url: str | None = Field(default="", description="The server url of the tool")
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
server_identifier: str | None = Field(default="", description="The server identifier of the MCP tool")
timeout: float | None = Field(default=30.0, description="The timeout of the MCP tool")
sse_read_timeout: float | None = Field(default=300.0, description="The SSE read timeout of the MCP tool")
masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool")
original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool")
authentication: MCPAuthentication | None = Field(default=None, description="The OAuth config of the MCP tool")
is_dynamic_registration: bool = Field(default=True, description="Whether the MCP tool is dynamically registered")
configuration: MCPConfiguration | None = Field(
default=None, description="The timeout and sse_read_timeout of the MCP tool"
)
@field_validator("tools", mode="before")
@classmethod
@@ -70,8 +75,15 @@ class ToolProviderApiEntity(BaseModel):
if self.type == ToolProviderType.MCP:
optional_fields.update(self.optional_field("updated_at", self.updated_at))
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
optional_fields.update(self.optional_field("timeout", self.timeout))
optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout))
optional_fields.update(
self.optional_field(
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
)
)
optional_fields.update(
self.optional_field("authentication", self.authentication.model_dump() if self.authentication else None)
)
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
optional_fields.update(self.optional_field("original_headers", self.original_headers))
return {

View File

@@ -1,6 +1,6 @@
import json
from typing import Any, Self
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.types import Tool as RemoteMCPTool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
@@ -52,18 +52,25 @@ class MCPToolProviderController(ToolProviderController):
"""
from db provider
"""
tools = []
tools_data = json.loads(db_provider.tools)
remote_mcp_tools = [RemoteMCPTool.model_validate(tool) for tool in tools_data]
user = db_provider.load_user()
# Convert to entity first
provider_entity = db_provider.to_entity()
return cls.from_entity(provider_entity)
@classmethod
def from_entity(cls, entity: MCPProviderEntity) -> Self:
"""
create a MCPToolProviderController from a MCPProviderEntity
"""
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in entity.tools]
tools = [
ToolEntity(
identity=ToolIdentity(
author=user.name if user else "Anonymous",
author="Anonymous", # Tool level author is not stored
name=remote_mcp_tool.name,
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
provider=db_provider.server_identifier,
icon=db_provider.icon,
provider=entity.provider_id,
icon=entity.icon if isinstance(entity.icon, str) else "",
),
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
description=ToolDescription(
@@ -72,31 +79,32 @@ class MCPToolProviderController(ToolProviderController):
),
llm=remote_mcp_tool.description or "",
),
output_schema=remote_mcp_tool.outputSchema or {},
has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
)
for remote_mcp_tool in remote_mcp_tools
]
if not db_provider.icon:
if not entity.icon:
raise ValueError("Database provider icon is required")
return cls(
entity=ToolProviderEntityWithPlugin(
identity=ToolProviderIdentity(
author=user.name if user else "Anonymous",
name=db_provider.name,
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
author="Anonymous", # Provider level author is not stored in entity
name=entity.name,
label=I18nObject(en_US=entity.name, zh_Hans=entity.name),
description=I18nObject(en_US="", zh_Hans=""),
icon=db_provider.icon,
icon=entity.icon if isinstance(entity.icon, str) else "",
),
plugin_id=None,
credentials_schema=[],
tools=tools,
),
provider_id=db_provider.server_identifier or "",
tenant_id=db_provider.tenant_id or "",
server_url=db_provider.decrypted_server_url,
headers=db_provider.decrypted_headers or {},
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
provider_id=entity.provider_id,
tenant_id=entity.tenant_id,
server_url=entity.server_url,
headers=entity.headers,
timeout=entity.timeout,
sse_read_timeout=entity.sse_read_timeout,
)
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):

View File

@@ -3,12 +3,13 @@ import json
from collections.abc import Generator
from typing import Any
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import ImageContent, TextContent
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
from core.mcp.types import CallToolResult, ImageContent, TextContent
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError
class MCPTool(Tool):
@@ -44,40 +45,32 @@ class MCPTool(Tool):
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
from core.tools.errors import ToolInvokeError
try:
with MCPClient(
self.server_url,
self.provider_id,
self.tenant_id,
authed=True,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
) as mcp_client:
tool_parameters = self._handle_none_parameter(tool_parameters)
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPAuthError as e:
raise ToolInvokeError("Please auth the tool first") from e
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
result = self.invoke_remote_mcp_tool(tool_parameters)
# handle dify tool output
for content in result.content:
if isinstance(content, TextContent):
yield from self._process_text_content(content)
elif isinstance(content, ImageContent):
yield self._process_image_content(content)
# handle MCP structured output
if self.entity.output_schema and result.structuredContent:
for k, v in result.structuredContent.items():
yield self.create_variable_message(k, v)
def _process_text_content(self, content: TextContent) -> Generator[ToolInvokeMessage, None, None]:
"""Process text content and yield appropriate messages."""
try:
content_json = json.loads(content.text)
yield from self._process_json_content(content_json)
except json.JSONDecodeError:
yield self.create_text_message(content.text)
# Check if content looks like JSON before attempting to parse
text = content.text.strip()
if text and text[0] in ("{", "[") and text[-1] in ("}", "]"):
try:
content_json = json.loads(text)
yield from self._process_json_content(content_json)
return
except json.JSONDecodeError:
pass
# If not JSON or parsing failed, treat as plain text
yield self.create_text_message(content.text)
def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]:
"""Process JSON content based on its type."""
@@ -126,3 +119,44 @@ class MCPTool(Tool):
for key, value in parameter.items()
if value is not None and not (isinstance(value, str) and value.strip() == "")
}
def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolResult:
headers = self.headers.copy() if self.headers else {}
tool_parameters = self._handle_none_parameter(tool_parameters)
from sqlalchemy.orm import Session
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import MCPToolManageService
# Step 1: Load provider entity and credentials in a short-lived session
# This minimizes database connection hold time
with Session(db.engine, expire_on_commit=False) as session:
mcp_service = MCPToolManageService(session=session)
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
# Decrypt and prepare all credentials before closing session
server_url = provider_entity.decrypt_server_url()
headers = provider_entity.decrypt_headers()
# Try to get existing token and add to headers
if not headers:
tokens = provider_entity.retrieve_tokens()
if tokens and tokens.access_token:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
# Step 2: Session is now closed, perform network operations without holding database connection
# MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed
try:
with MCPClientWithAuthRetry(
server_url=server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e

View File

@@ -14,17 +14,32 @@ from sqlalchemy.orm import Session
from yarl import URL
import contexts
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.workflow.runtime.variable_pool import VariablePool
from extensions.ext_database import db
from models.provider_ids import ToolProviderID
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity
from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool import Tool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.builtin_tool.tool import BuiltinTool
@@ -40,21 +55,11 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.errors import ToolProviderNotFoundError
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.provider_ids import ToolProviderID
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
from services.tools.tools_transform_service import ToolTransformService
if TYPE_CHECKING:
@@ -719,7 +724,9 @@ class ToolManager:
)
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
if "mcp" in filters:
mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id, for_list=True)
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
for mcp_provider in mcp_providers:
result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
@@ -774,17 +781,12 @@ class ToolManager:
:return: the provider controller, the credentials
"""
provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider)
.where(
MCPToolProvider.server_identifier == provider_id,
MCPToolProvider.tenant_id == tenant_id,
)
.first()
)
if provider is None:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
try:
provider = mcp_service.get_provider(server_identifier=provider_id, tenant_id=tenant_id)
except ValueError:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
controller = MCPToolProviderController.from_db(provider)
@@ -922,16 +924,15 @@ class ToolManager:
@classmethod
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str:
try:
mcp_provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider)
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
.first()
)
if mcp_provider is None:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
return mcp_provider.provider_icon
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
try:
mcp_provider = mcp_service.get_provider_entity(
provider_id=provider_id, tenant_id=tenant_id, by_server_id=True
)
return mcp_provider.provider_icon
except ValueError:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}

View File

@@ -1,16 +1,13 @@
import json
from collections.abc import Mapping
from datetime import datetime
from decimal import Decimal
from typing import TYPE_CHECKING, Any, cast
from urllib.parse import urlparse
import sqlalchemy as sa
from deprecated import deprecated
from sqlalchemy import ForeignKey, String, func
from sqlalchemy.orm import Mapped, mapped_column
from core.helper import encrypter
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
@@ -21,7 +18,7 @@ from .model import Account, App, Tenant
from .types import StringUUID
if TYPE_CHECKING:
from core.mcp.types import Tool as MCPTool
from core.entities.mcp_provider import MCPProviderEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
@@ -331,126 +328,36 @@ class MCPToolProvider(TypeBase):
def load_user(self) -> Account | None:
return db.session.query(Account).where(Account.id == self.user_id).first()
@property
def tenant(self) -> Tenant | None:
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
@property
def credentials(self) -> dict[str, Any]:
if not self.encrypted_credentials:
return {}
try:
return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {}
except json.JSONDecodeError:
return {}
@property
def mcp_tools(self) -> list["MCPTool"]:
from core.mcp.types import Tool as MCPTool
return [MCPTool.model_validate(tool) for tool in json.loads(self.tools)]
@property
def provider_icon(self) -> Mapping[str, str] | str:
from core.file import helpers as file_helpers
assert self.icon
try:
return json.loads(self.icon)
except json.JSONDecodeError:
return file_helpers.get_signed_file_url(self.icon)
@property
def decrypted_server_url(self) -> str:
return encrypter.decrypt_token(self.tenant_id, self.server_url)
@property
def decrypted_headers(self) -> dict[str, Any]:
"""Get decrypted headers for MCP server requests."""
from core.entities.provider_entities import BasicProviderConfig
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.utils.encryption import create_provider_encrypter
try:
if not self.encrypted_headers:
return {}
headers_data = json.loads(self.encrypted_headers)
# Create dynamic config for all headers as SECRET_INPUT
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers_data]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
result = encrypter_instance.decrypt(headers_data)
return result
return json.loads(self.encrypted_credentials)
except Exception:
return {}
@property
def masked_headers(self) -> dict[str, Any]:
"""Get masked headers for frontend display."""
from core.entities.provider_entities import BasicProviderConfig
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.utils.encryption import create_provider_encrypter
def headers(self) -> dict[str, Any]:
if self.encrypted_headers is None:
return {}
try:
if not self.encrypted_headers:
return {}
headers_data = json.loads(self.encrypted_headers)
# Create dynamic config for all headers as SECRET_INPUT
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers_data]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
# First decrypt, then mask
decrypted_headers = encrypter_instance.decrypt(headers_data)
result = encrypter_instance.mask_tool_credentials(decrypted_headers)
return result
return json.loads(self.encrypted_headers)
except Exception:
return {}
@property
def masked_server_url(self) -> str:
def mask_url(url: str, mask_char: str = "*") -> str:
"""
mask the url to a simple string
"""
parsed = urlparse(url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
def tool_dict(self) -> list[dict[str, Any]]:
try:
return json.loads(self.tools) if self.tools else []
except (json.JSONDecodeError, TypeError):
return []
if parsed.path and parsed.path != "/":
return f"{base_url}/{mask_char * 6}"
else:
return base_url
def to_entity(self) -> "MCPProviderEntity":
"""Convert to domain entity"""
from core.entities.mcp_provider import MCPProviderEntity
return mask_url(self.decrypted_server_url)
@property
def decrypted_credentials(self) -> dict[str, Any]:
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.encryption import create_provider_encrypter
provider_controller = MCPToolProviderController.from_db(self)
encrypter, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
cache=NoOpProviderCredentialCache(),
)
return encrypter.decrypt(self.credentials)
return MCPProviderEntity.from_db_model(self)
class ToolModelInvoke(TypeBase):

View File

@@ -1,86 +1,118 @@
import hashlib
import json
import logging
from datetime import datetime
from enum import StrEnum
from typing import Any
from urllib.parse import urlparse
from sqlalchemy import or_
from pydantic import BaseModel, Field
from sqlalchemy import or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.auth.auth_flow import auth
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.encryption import ProviderConfigEncrypter
from extensions.ext_database import db
from models.tools import MCPToolProvider
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
# Constants
UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
CLIENT_NAME = "Dify"
EMPTY_TOOLS_JSON = "[]"
EMPTY_CREDENTIALS_JSON = "{}"
class OAuthDataType(StrEnum):
"""Types of OAuth data that can be saved."""
TOKENS = "tokens"
CLIENT_INFO = "client_info"
CODE_VERIFIER = "code_verifier"
MIXED = "mixed"
class ReconnectResult(BaseModel):
"""Result of reconnecting to an MCP provider"""
authed: bool = Field(description="Whether the provider is authenticated")
tools: str = Field(description="JSON string of tool list")
encrypted_credentials: str = Field(description="JSON string of encrypted credentials")
class ServerUrlValidationResult(BaseModel):
"""Result of server URL validation check"""
needs_validation: bool
validation_passed: bool = False
reconnect_result: ReconnectResult | None = None
encrypted_server_url: str | None = None
server_url_hash: str | None = None
@property
def should_update_server_url(self) -> bool:
"""Check if server URL should be updated based on validation result"""
return self.needs_validation and self.validation_passed and self.reconnect_result is not None
class MCPToolManageService:
"""
Service class for managing mcp tools.
"""
"""Service class for managing MCP tools and providers."""
@staticmethod
def _encrypt_headers(headers: dict[str, str], tenant_id: str) -> dict[str, str]:
def __init__(self, session: Session):
self._session = session
# ========== Provider CRUD Operations ==========
def get_provider(
self, *, provider_id: str | None = None, server_identifier: str | None = None, tenant_id: str
) -> MCPToolProvider:
"""
Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT.
Get MCP provider by ID or server identifier.
Args:
headers: Dictionary of headers to encrypt
tenant_id: Tenant ID for encryption
provider_id: Provider ID (UUID)
server_identifier: Server identifier
tenant_id: Tenant ID
Returns:
Dictionary with all headers encrypted
MCPToolProvider instance
Raises:
ValueError: If provider not found
"""
if not headers:
return {}
if server_identifier:
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier
)
else:
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id
)
from core.entities.provider_entities import BasicProviderConfig
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.utils.encryption import create_provider_encrypter
# Create dynamic config for all headers as SECRET_INPUT
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
return encrypter_instance.encrypt(headers)
@staticmethod
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
res = (
db.session.query(MCPToolProvider)
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
.first()
)
if not res:
provider = self._session.scalar(stmt)
if not provider:
raise ValueError("MCP tool not found")
return res
return provider
@staticmethod
def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider:
res = (
db.session.query(MCPToolProvider)
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier)
.first()
)
if not res:
raise ValueError("MCP tool not found")
return res
def get_provider_entity(self, provider_id: str, tenant_id: str, by_server_id: bool = False) -> MCPProviderEntity:
"""Get provider entity by ID or server identifier."""
if by_server_id:
db_provider = self.get_provider(server_identifier=provider_id, tenant_id=tenant_id)
else:
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
return db_provider.to_entity()
@staticmethod
def create_mcp_provider(
def create_provider(
self,
*,
tenant_id: str,
name: str,
server_url: str,
@@ -89,37 +121,30 @@ class MCPToolManageService:
icon_type: str,
icon_background: str,
server_identifier: str,
timeout: float,
sse_read_timeout: float,
configuration: MCPConfiguration,
authentication: MCPAuthentication | None = None,
headers: dict[str, str] | None = None,
) -> ToolProviderApiEntity:
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
existing_provider = (
db.session.query(MCPToolProvider)
.where(
MCPToolProvider.tenant_id == tenant_id,
or_(
MCPToolProvider.name == name,
MCPToolProvider.server_url_hash == server_url_hash,
MCPToolProvider.server_identifier == server_identifier,
),
)
.first()
)
if existing_provider:
if existing_provider.name == name:
raise ValueError(f"MCP tool {name} already exists")
if existing_provider.server_url_hash == server_url_hash:
raise ValueError(f"MCP tool {server_url} already exists")
if existing_provider.server_identifier == server_identifier:
raise ValueError(f"MCP tool {server_identifier} already exists")
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
# Encrypt headers
encrypted_headers = None
if headers:
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
encrypted_headers = json.dumps(encrypted_headers_dict)
"""Create a new MCP provider."""
# Validate URL format
if not self._is_valid_url(server_url):
raise ValueError("Server URL is not valid.")
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
# Check for existing provider
self._check_provider_exists(tenant_id, name, server_url_hash, server_identifier)
# Encrypt sensitive data
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None
encrypted_credentials = None
if authentication is not None and authentication.client_id:
encrypted_credentials = self._build_and_encrypt_credentials(
authentication.client_id, authentication.client_secret, tenant_id
)
# Create provider
mcp_tool = MCPToolProvider(
tenant_id=tenant_id,
name=name,
@@ -127,93 +152,23 @@ class MCPToolManageService:
server_url_hash=server_url_hash,
user_id=user_id,
authed=False,
tools="[]",
icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
tools=EMPTY_TOOLS_JSON,
icon=self._prepare_icon(icon, icon_type, icon_background),
server_identifier=server_identifier,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
timeout=configuration.timeout,
sse_read_timeout=configuration.sse_read_timeout,
encrypted_headers=encrypted_headers,
)
db.session.add(mcp_tool)
db.session.commit()
return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
@staticmethod
def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
mcp_providers = (
db.session.query(MCPToolProvider)
.where(MCPToolProvider.tenant_id == tenant_id)
.order_by(MCPToolProvider.name)
.all()
)
return [
ToolTransformService.mcp_provider_to_user_provider(mcp_provider, for_list=for_list)
for mcp_provider in mcp_providers
]
@classmethod
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
server_url = mcp_provider.decrypted_server_url
authed = mcp_provider.authed
headers = mcp_provider.decrypted_headers
timeout = mcp_provider.timeout
sse_read_timeout = mcp_provider.sse_read_timeout
try:
with MCPClient(
server_url,
provider_id,
tenant_id,
authed=authed,
for_list=True,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
except MCPAuthError:
raise ValueError("Please auth the tool first")
except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
try:
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
mcp_provider.authed = True
mcp_provider.updated_at = datetime.now()
db.session.commit()
except Exception:
db.session.rollback()
raise
user = mcp_provider.load_user()
if not mcp_provider.icon:
raise ValueError("MCP provider icon is required")
return ToolProviderApiEntity(
id=mcp_provider.id,
name=mcp_provider.name,
tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools),
type=ToolProviderType.MCP,
icon=mcp_provider.icon,
author=user.name if user else "Anonymous",
server_url=mcp_provider.masked_server_url,
updated_at=int(mcp_provider.updated_at.timestamp()),
description=I18nObject(en_US="", zh_Hans=""),
label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
plugin_unique_identifier=mcp_provider.server_identifier,
encrypted_credentials=encrypted_credentials,
)
@classmethod
def delete_mcp_tool(cls, tenant_id: str, provider_id: str):
mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
self._session.add(mcp_tool)
self._session.flush()
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
return mcp_providers
db.session.delete(mcp_tool)
db.session.commit()
@classmethod
def update_mcp_provider(
cls,
def update_provider(
self,
*,
tenant_id: str,
provider_id: str,
name: str,
@@ -222,129 +177,546 @@ class MCPToolManageService:
icon_type: str,
icon_background: str,
server_identifier: str,
timeout: float | None = None,
sse_read_timeout: float | None = None,
headers: dict[str, str] | None = None,
):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
configuration: MCPConfiguration,
authentication: MCPAuthentication | None = None,
validation_result: ServerUrlValidationResult | None = None,
) -> None:
"""
Update an MCP provider.
reconnect_result = None
Args:
validation_result: Pre-validation result from validate_server_url_change.
If provided and contains reconnect_result, it will be used
instead of performing network operations.
"""
mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
# Check for duplicate name (excluding current provider)
if name != mcp_provider.name:
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id,
MCPToolProvider.name == name,
MCPToolProvider.id != provider_id,
)
existing_provider = self._session.scalar(stmt)
if existing_provider:
raise ValueError(f"MCP tool {name} already exists")
# Get URL update data from validation result
encrypted_server_url = None
server_url_hash = None
reconnect_result = None
if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
if server_url_hash != mcp_provider.server_url_hash:
reconnect_result = cls._re_connect_mcp_provider(server_url, provider_id, tenant_id)
if validation_result and validation_result.encrypted_server_url:
# Use all data from validation result
encrypted_server_url = validation_result.encrypted_server_url
server_url_hash = validation_result.server_url_hash
reconnect_result = validation_result.reconnect_result
try:
# Update basic fields
mcp_provider.updated_at = datetime.now()
mcp_provider.name = name
mcp_provider.icon = (
json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
)
mcp_provider.icon = self._prepare_icon(icon, icon_type, icon_background)
mcp_provider.server_identifier = server_identifier
if encrypted_server_url is not None and server_url_hash is not None:
# Update server URL if changed
if encrypted_server_url and server_url_hash:
mcp_provider.server_url = encrypted_server_url
mcp_provider.server_url_hash = server_url_hash
if reconnect_result:
mcp_provider.authed = reconnect_result["authed"]
mcp_provider.tools = reconnect_result["tools"]
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
mcp_provider.authed = reconnect_result.authed
mcp_provider.tools = reconnect_result.tools
mcp_provider.encrypted_credentials = reconnect_result.encrypted_credentials
if timeout is not None:
mcp_provider.timeout = timeout
if sse_read_timeout is not None:
mcp_provider.sse_read_timeout = sse_read_timeout
# Update optional configuration fields
self._update_optional_fields(mcp_provider, configuration)
# Update headers if provided
if headers is not None:
# Merge masked headers from frontend with existing real values
if headers:
# existing decrypted and masked headers
existing_decrypted = mcp_provider.decrypted_headers
existing_masked = mcp_provider.masked_headers
mcp_provider.encrypted_headers = self._process_headers(headers, mcp_provider, tenant_id)
# Build final headers: if value equals masked existing, keep original decrypted value
final_headers: dict[str, str] = {}
for key, incoming_value in headers.items():
if (
key in existing_masked
and key in existing_decrypted
and isinstance(incoming_value, str)
and incoming_value == existing_masked.get(key)
):
# unchanged, use original decrypted value
final_headers[key] = str(existing_decrypted[key])
else:
final_headers[key] = incoming_value
# Update credentials if provided
if authentication and authentication.client_id:
mcp_provider.encrypted_credentials = self._process_credentials(authentication, mcp_provider, tenant_id)
encrypted_headers_dict = MCPToolManageService._encrypt_headers(final_headers, tenant_id)
mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
else:
# Explicitly clear headers if empty dict passed
mcp_provider.encrypted_headers = None
db.session.commit()
# Flush changes to database
self._session.flush()
except IntegrityError as e:
db.session.rollback()
error_msg = str(e.orig)
if "unique_mcp_provider_name" in error_msg:
raise ValueError(f"MCP tool {name} already exists")
if "unique_mcp_provider_server_url" in error_msg:
raise ValueError(f"MCP tool {server_url} already exists")
if "unique_mcp_provider_server_identifier" in error_msg:
raise ValueError(f"MCP tool {server_identifier} already exists")
raise
except Exception:
db.session.rollback()
raise
self._handle_integrity_error(e, name, server_url, server_identifier)
@classmethod
def update_mcp_provider_credentials(
cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
):
provider_controller = MCPToolProviderController.from_db(mcp_provider)
def delete_provider(self, *, tenant_id: str, provider_id: str) -> None:
"""Delete an MCP provider."""
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
self._session.delete(mcp_tool)
def list_providers(
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
) -> list[ToolProviderApiEntity]:
"""List all MCP providers for a tenant.
Args:
tenant_id: Tenant ID
for_list: If True, return provider ID; if False, return server identifier
include_sensitive: If False, skip expensive decryption operations (default: True for backward compatibility)
"""
from models.account import Account
stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
mcp_providers = self._session.scalars(stmt).all()
if not mcp_providers:
return []
# Batch query all users to avoid N+1 problem
user_ids = {provider.user_id for provider in mcp_providers}
users = self._session.query(Account).where(Account.id.in_(user_ids)).all()
user_name_map = {user.id: user.name for user in users}
return [
ToolTransformService.mcp_provider_to_user_provider(
provider,
for_list=for_list,
user_name=user_name_map.get(provider.user_id),
include_sensitive=include_sensitive,
)
for provider in mcp_providers
]
# ========== Tool Operations ==========
def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
"""List tools from remote MCP server."""
# Load provider and convert to entity
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
provider_entity = db_provider.to_entity()
# Verify authentication
if not provider_entity.authed:
raise ValueError("Please auth the tool first")
# Prepare headers with auth token
headers = self._prepare_auth_headers(provider_entity)
# Retrieve tools from remote server
server_url = provider_entity.decrypt_server_url()
try:
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
# Update database with retrieved tools
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
db_provider.authed = True
db_provider.updated_at = datetime.now()
self._session.flush()
# Build API response
return self._build_tool_provider_response(db_provider, provider_entity, tools)
# ========== OAuth and Credentials Operations ==========
def update_provider_credentials(
self, *, provider_id: str, tenant_id: str, credentials: dict[str, Any], authed: bool | None = None
) -> None:
"""
Update provider credentials with encryption.
Args:
provider_id: Provider ID
tenant_id: Tenant ID
credentials: Credentials to save
authed: Whether provider is authenticated (None means keep current state)
"""
from core.tools.mcp_tool.provider import MCPToolProviderController
# Get provider from current session
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
# Encrypt new credentials
provider_controller = MCPToolProviderController.from_db(provider)
tool_configuration = ProviderConfigEncrypter(
tenant_id=mcp_provider.tenant_id,
tenant_id=provider.tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_config_cache=NoOpProviderCredentialCache(),
)
credentials = tool_configuration.encrypt(credentials)
mcp_provider.updated_at = datetime.now()
mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials})
mcp_provider.authed = authed
if not authed:
mcp_provider.tools = "[]"
db.session.commit()
encrypted_credentials = tool_configuration.encrypt(credentials)
@classmethod
def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str):
# Get the existing provider to access headers and timeout settings
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
headers = mcp_provider.decrypted_headers
timeout = mcp_provider.timeout
sse_read_timeout = mcp_provider.sse_read_timeout
# Update provider
provider.updated_at = datetime.now()
provider.encrypted_credentials = json.dumps({**provider.credentials, **encrypted_credentials})
if authed is not None:
provider.authed = authed
if not authed:
provider.tools = EMPTY_TOOLS_JSON
# Flush changes to database
self._session.flush()
def save_oauth_data(
self, provider_id: str, tenant_id: str, data: dict[str, Any], data_type: OAuthDataType = OAuthDataType.MIXED
) -> None:
"""
Save OAuth-related data (tokens, client info, code verifier).
Args:
provider_id: Provider ID
tenant_id: Tenant ID
data: Data to save (tokens, client info, or code verifier)
data_type: Type of OAuth data to save
"""
# Determine if this makes the provider authenticated
authed = (
data_type == OAuthDataType.TOKENS or (data_type == OAuthDataType.MIXED and "access_token" in data) or None
)
# update_provider_credentials will validate provider existence
self.update_provider_credentials(provider_id=provider_id, tenant_id=tenant_id, credentials=data, authed=authed)
def clear_provider_credentials(self, *, provider_id: str, tenant_id: str) -> None:
"""
Clear all credentials for a provider.
Args:
provider_id: Provider ID
tenant_id: Tenant ID
"""
# Get provider from current session
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
provider.tools = EMPTY_TOOLS_JSON
provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON
provider.updated_at = datetime.now()
provider.authed = False
# ========== Private Helper Methods ==========
def _check_provider_exists(self, tenant_id: str, name: str, server_url_hash: str, server_identifier: str) -> None:
"""Check if provider with same attributes already exists."""
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id,
or_(
MCPToolProvider.name == name,
MCPToolProvider.server_url_hash == server_url_hash,
MCPToolProvider.server_identifier == server_identifier,
),
)
existing_provider = self._session.scalar(stmt)
if existing_provider:
if existing_provider.name == name:
raise ValueError(f"MCP tool {name} already exists")
if existing_provider.server_url_hash == server_url_hash:
raise ValueError("MCP tool with this server URL already exists")
if existing_provider.server_identifier == server_identifier:
raise ValueError(f"MCP tool {server_identifier} already exists")
def _prepare_icon(self, icon: str, icon_type: str, icon_background: str) -> str:
"""Prepare icon data for storage."""
if icon_type == "emoji":
return json.dumps({"content": icon, "background": icon_background})
return icon
def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> dict[str, str]:
"""Encrypt specified fields in a dictionary.
Args:
data: Dictionary containing data to encrypt
secret_fields: List of field names to encrypt
tenant_id: Tenant ID for encryption
Returns:
JSON string of encrypted data
"""
from core.entities.provider_entities import BasicProviderConfig
from core.tools.utils.encryption import create_provider_encrypter
# Create config for secret fields
config = [
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=field) for field in secret_fields
]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
encrypted_data = encrypter_instance.encrypt(data)
return encrypted_data
def _prepare_encrypted_dict(self, headers: dict[str, str], tenant_id: str) -> str:
"""Encrypt headers and prepare for storage."""
# All headers are treated as secret
return json.dumps(self._encrypt_dict_fields(headers, list(headers.keys()), tenant_id))
def _prepare_auth_headers(self, provider_entity: MCPProviderEntity) -> dict[str, str]:
"""Prepare headers with OAuth token if available."""
headers = provider_entity.decrypt_headers()
tokens = provider_entity.retrieve_tokens()
if tokens:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
return headers
def _retrieve_remote_mcp_tools(
self,
server_url: str,
headers: dict[str, str],
provider_entity: MCPProviderEntity,
):
"""Retrieve tools from remote MCP server."""
with MCPClientWithAuthRetry(
server_url=server_url,
headers=headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
provider_entity=provider_entity,
) as mcp_client:
return mcp_client.list_tools()
def execute_auth_actions(self, auth_result: Any) -> dict[str, str]:
"""
Execute the actions returned by the auth function.
This method processes the AuthResult and performs the necessary database operations.
Args:
auth_result: The result from the auth function
Returns:
The response from the auth result
"""
from core.mcp.entities import AuthAction, AuthActionType
action: AuthAction
for action in auth_result.actions:
if action.provider_id is None or action.tenant_id is None:
continue
if action.action_type == AuthActionType.SAVE_CLIENT_INFO:
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CLIENT_INFO)
elif action.action_type == AuthActionType.SAVE_TOKENS:
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.TOKENS)
elif action.action_type == AuthActionType.SAVE_CODE_VERIFIER:
self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CODE_VERIFIER)
return auth_result.response
def auth_with_actions(
self, provider_entity: MCPProviderEntity, authorization_code: str | None = None
) -> dict[str, str]:
"""
Perform authentication and execute all resulting actions.
This method is used by MCPClientWithAuthRetry for automatic re-authentication.
Args:
provider_entity: The MCP provider entity
authorization_code: Optional authorization code
Returns:
Response dictionary from auth result
"""
auth_result = auth(provider_entity, authorization_code)
return self.execute_auth_actions(auth_result)
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult:
"""Attempt to reconnect to MCP provider with new server URL."""
provider_entity = provider.to_entity()
headers = provider_entity.headers
try:
with MCPClient(
server_url,
provider_id,
tenant_id,
authed=False,
for_list=True,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
return {
"authed": True,
"tools": json.dumps([tool.model_dump() for tool in tools]),
"encrypted_credentials": "{}",
}
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
return ReconnectResult(
authed=True,
tools=json.dumps([tool.model_dump() for tool in tools]),
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
)
except MCPAuthError:
return {"authed": False, "tools": "[]", "encrypted_credentials": "{}"}
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
except MCPError as e:
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
def validate_server_url_change(
self, *, tenant_id: str, provider_id: str, new_server_url: str
) -> ServerUrlValidationResult:
"""
Validate server URL change by attempting to connect to the new server.
This method should be called BEFORE update_provider to perform network operations
outside of the database transaction.
Returns:
ServerUrlValidationResult: Validation result with connection status and tools if successful
"""
# Handle hidden/unchanged URL
if UNCHANGED_SERVER_URL_PLACEHOLDER in new_server_url:
return ServerUrlValidationResult(needs_validation=False)
# Validate URL format
if not self._is_valid_url(new_server_url):
raise ValueError("Server URL is not valid.")
# Always encrypt and hash the URL
encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()
# Get current provider
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
# Check if URL is actually different
if new_server_url_hash == provider.server_url_hash:
# URL hasn't changed, but still return the encrypted data
return ServerUrlValidationResult(
needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash
)
# Perform validation by attempting to connect
reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider)
return ServerUrlValidationResult(
needs_validation=True,
validation_passed=True,
reconnect_result=reconnect_result,
encrypted_server_url=encrypted_server_url,
server_url_hash=new_server_url_hash,
)
def _build_tool_provider_response(
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
) -> ToolProviderApiEntity:
"""Build API response for tool provider."""
user = db_provider.load_user()
response = provider_entity.to_api_response(
user_name=user.name if user else None,
)
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, tools)
response["plugin_unique_identifier"] = provider_entity.provider_id
return ToolProviderApiEntity(**response)
def _handle_integrity_error(
self, error: IntegrityError, name: str, server_url: str, server_identifier: str
) -> None:
"""Handle database integrity errors with user-friendly messages."""
error_msg = str(error.orig)
if "unique_mcp_provider_name" in error_msg:
raise ValueError(f"MCP tool {name} already exists")
if "unique_mcp_provider_server_url" in error_msg:
raise ValueError(f"MCP tool {server_url} already exists")
if "unique_mcp_provider_server_identifier" in error_msg:
raise ValueError(f"MCP tool {server_identifier} already exists")
raise
def _is_valid_url(self, url: str) -> bool:
"""Validate URL format."""
if not url:
return False
try:
parsed = urlparse(url)
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
except (ValueError, TypeError):
return False
def _update_optional_fields(self, mcp_provider: MCPToolProvider, configuration: MCPConfiguration) -> None:
"""Update optional configuration fields using setattr for cleaner code."""
field_mapping = {"timeout": configuration.timeout, "sse_read_timeout": configuration.sse_read_timeout}
for field, value in field_mapping.items():
if value is not None:
setattr(mcp_provider, field, value)
def _process_headers(self, headers: dict[str, str], mcp_provider: MCPToolProvider, tenant_id: str) -> str | None:
"""Process headers update, handling empty dict to clear headers."""
if not headers:
return None
# Merge with existing headers to preserve masked values
final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider)
return self._prepare_encrypted_dict(final_headers, tenant_id)
def _process_credentials(
self, authentication: MCPAuthentication, mcp_provider: MCPToolProvider, tenant_id: str
) -> str:
"""Process credentials update, handling masked values."""
# Merge with existing credentials
final_client_id, final_client_secret = self._merge_credentials_with_masked(
authentication.client_id, authentication.client_secret, mcp_provider
)
# Build and encrypt
return self._build_and_encrypt_credentials(final_client_id, final_client_secret, tenant_id)
def _merge_headers_with_masked(
self, incoming_headers: dict[str, str], mcp_provider: MCPToolProvider
) -> dict[str, str]:
"""Merge incoming headers with existing ones, preserving unchanged masked values.
Args:
incoming_headers: Headers from frontend (may contain masked values)
mcp_provider: The MCP provider instance
Returns:
Final headers dict with proper values (original for unchanged masked, new for changed)
"""
mcp_provider_entity = mcp_provider.to_entity()
existing_decrypted = mcp_provider_entity.decrypt_headers()
existing_masked = mcp_provider_entity.masked_headers()
return {
key: (str(existing_decrypted[key]) if key in existing_masked and value == existing_masked[key] else value)
for key, value in incoming_headers.items()
if key in existing_decrypted or value != existing_masked.get(key)
}
def _merge_credentials_with_masked(
self,
client_id: str,
client_secret: str | None,
mcp_provider: MCPToolProvider,
) -> tuple[
str,
str | None,
]:
"""Merge incoming credentials with existing ones, preserving unchanged masked values.
Args:
client_id: Client ID from frontend (may be masked)
client_secret: Client secret from frontend (may be masked)
mcp_provider: The MCP provider instance
Returns:
Tuple of (final_client_id, final_client_secret)
"""
mcp_provider_entity = mcp_provider.to_entity()
existing_decrypted = mcp_provider_entity.decrypt_credentials()
existing_masked = mcp_provider_entity.masked_credentials()
# Check if client_id is masked and unchanged
final_client_id = client_id
if existing_masked.get("client_id") and client_id == existing_masked["client_id"]:
# Use existing decrypted value
final_client_id = existing_decrypted.get("client_id", client_id)
# Check if client_secret is masked and unchanged
final_client_secret = client_secret
if existing_masked.get("client_secret") and client_secret == existing_masked["client_secret"]:
# Use existing decrypted value
final_client_secret = existing_decrypted.get("client_secret", client_secret)
return final_client_id, final_client_secret
def _build_and_encrypt_credentials(self, client_id: str, client_secret: str | None, tenant_id: str) -> str:
"""Build credentials and encrypt sensitive fields."""
# Create a flat structure with all credential data
credentials_data = {
"client_id": client_id,
"client_name": CLIENT_NAME,
"is_dynamic_registration": False,
}
secret_fields = []
if client_secret is not None:
credentials_data["encrypted_client_secret"] = client_secret
secret_fields = ["encrypted_client_secret"]
client_info = self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id)
return json.dumps({"client_information": client_info})

View File

@@ -3,9 +3,11 @@ import logging
from collections.abc import Mapping
from typing import Any, Union
from pydantic import ValidationError
from yarl import URL
from configs import dify_config
from core.entities.mcp_provider import MCPConfiguration
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.mcp.types import Tool as MCPTool
from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
@@ -232,40 +234,57 @@ class ToolTransformService:
)
@staticmethod
def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity:
user = db_provider.load_user()
return ToolProviderApiEntity(
id=db_provider.server_identifier if not for_list else db_provider.id,
author=user.name if user else "Anonymous",
name=db_provider.name,
icon=db_provider.provider_icon,
type=ToolProviderType.MCP,
is_team_authorization=db_provider.authed,
server_url=db_provider.masked_server_url,
tools=ToolTransformService.mcp_tool_to_user_tool(
db_provider, [MCPTool.model_validate(tool) for tool in json.loads(db_provider.tools)]
),
updated_at=int(db_provider.updated_at.timestamp()),
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
description=I18nObject(en_US="", zh_Hans=""),
server_identifier=db_provider.server_identifier,
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
masked_headers=db_provider.masked_headers,
original_headers=db_provider.decrypted_headers,
)
def mcp_provider_to_user_provider(
db_provider: MCPToolProvider,
for_list: bool = False,
user_name: str | None = None,
include_sensitive: bool = True,
) -> ToolProviderApiEntity:
# Use provided user_name to avoid N+1 query, fallback to load_user() if not provided
if user_name is None:
user = db_provider.load_user()
user_name = user.name if user else None
# Convert to entity and use its API response method
provider_entity = db_provider.to_entity()
response = provider_entity.to_api_response(user_name=user_name, include_sensitive=include_sensitive)
try:
mcp_tools = [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
except (ValidationError, json.JSONDecodeError):
mcp_tools = []
# Add additional fields specific to the transform
response["id"] = db_provider.server_identifier if not for_list else db_provider.id
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, mcp_tools, user_name=user_name)
response["server_identifier"] = db_provider.server_identifier
# Convert configuration dict to MCPConfiguration object
if "configuration" in response and isinstance(response["configuration"], dict):
response["configuration"] = MCPConfiguration(
timeout=float(response["configuration"]["timeout"]),
sse_read_timeout=float(response["configuration"]["sse_read_timeout"]),
)
return ToolProviderApiEntity(**response)
@staticmethod
def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]:
user = mcp_provider.load_user()
def mcp_tool_to_user_tool(
mcp_provider: MCPToolProvider, tools: list[MCPTool], user_name: str | None = None
) -> list[ToolApiEntity]:
# Use provided user_name to avoid N+1 query, fallback to load_user() if not provided
if user_name is None:
user = mcp_provider.load_user()
user_name = user.name if user else "Anonymous"
return [
ToolApiEntity(
author=user.name if user else "Anonymous",
author=user_name or "Anonymous",
name=tool.name,
label=I18nObject(en_US=tool.name, zh_Hans=tool.name),
description=I18nObject(en_US=tool.description or "", zh_Hans=tool.description or ""),
parameters=ToolTransformService.convert_mcp_schema_to_parameter(tool.inputSchema),
labels=[],
output_schema=tool.outputSchema or {},
)
for tool in tools
]
@@ -412,7 +431,7 @@ class ToolTransformService:
)
@staticmethod
def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
def convert_mcp_schema_to_parameter(schema: dict[str, Any]) -> list["ToolParameter"]:
"""
Convert MCP JSON schema to tool parameters
@@ -421,7 +440,7 @@ class ToolTransformService:
"""
def create_parameter(
name: str, description: str, param_type: str, required: bool, input_schema: dict | None = None
name: str, description: str, param_type: str, required: bool, input_schema: dict[str, Any] | None = None
) -> ToolParameter:
"""Create a ToolParameter instance with given attributes"""
input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {}
@@ -436,7 +455,9 @@ class ToolTransformService:
**input_schema_dict,
)
def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]:
def process_properties(
props: dict[str, dict[str, Any]], required: list[str], prefix: str = ""
) -> list[ToolParameter]:
"""Process properties recursively"""
TYPE_MAPPING = {"integer": "number", "float": "number"}
COMPLEX_TYPES = ["array", "object"]

View File

@@ -20,12 +20,21 @@ class TestMCPToolManageService:
patch("services.tools.mcp_tools_manage_service.ToolTransformService") as mock_tool_transform_service,
):
# Setup default mock returns
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
mock_encrypter.encrypt_token.return_value = "encrypted_server_url"
mock_tool_transform_service.mcp_provider_to_user_provider.return_value = {
"id": "test_id",
"name": "test_name",
"type": ToolProviderType.MCP,
}
mock_tool_transform_service.mcp_provider_to_user_provider.return_value = ToolProviderApiEntity(
id="test_id",
author="test_author",
name="test_name",
type=ToolProviderType.MCP,
description=I18nObject(en_US="Test Description", zh_Hans="测试描述"),
icon={"type": "emoji", "content": "🤖"},
label=I18nObject(en_US="Test Label", zh_Hans="测试标签"),
labels=[],
tools=[],
)
yield {
"encrypter": mock_encrypter,
@@ -104,9 +113,9 @@ class TestMCPToolManageService:
mcp_provider = MCPToolProvider(
tenant_id=tenant_id,
name=fake.company(),
server_identifier=fake.uuid4(),
server_identifier=str(fake.uuid4()),
server_url="encrypted_server_url",
server_url_hash=fake.sha256(),
server_url_hash=str(fake.sha256()),
user_id=user_id,
authed=False,
tools="[]",
@@ -144,7 +153,10 @@ class TestMCPToolManageService:
)
# Act: Execute the method under test
result = MCPToolManageService.get_mcp_provider_by_provider_id(mcp_provider.id, tenant.id)
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service.get_provider(provider_id=mcp_provider.id, tenant_id=tenant.id)
# Assert: Verify the expected outcomes
assert result is not None
@@ -154,8 +166,6 @@ class TestMCPToolManageService:
assert result.user_id == account.id
# Verify database state
from extensions.ext_database import db
db.session.refresh(result)
assert result.id is not None
assert result.server_identifier == mcp_provider.server_identifier
@@ -177,11 +187,14 @@ class TestMCPToolManageService:
db_session_with_containers, mock_external_service_dependencies
)
non_existent_id = fake.uuid4()
non_existent_id = str(fake.uuid4())
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool not found"):
MCPToolManageService.get_mcp_provider_by_provider_id(non_existent_id, tenant.id)
service.get_provider(provider_id=non_existent_id, tenant_id=tenant.id)
def test_get_mcp_provider_by_provider_id_tenant_isolation(
self, db_session_with_containers, mock_external_service_dependencies
@@ -210,8 +223,11 @@ class TestMCPToolManageService:
)
# Act & Assert: Verify tenant isolation
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool not found"):
MCPToolManageService.get_mcp_provider_by_provider_id(mcp_provider1.id, tenant2.id)
service.get_provider(provider_id=mcp_provider1.id, tenant_id=tenant2.id)
def test_get_mcp_provider_by_server_identifier_success(
self, db_session_with_containers, mock_external_service_dependencies
@@ -235,7 +251,10 @@ class TestMCPToolManageService:
)
# Act: Execute the method under test
result = MCPToolManageService.get_mcp_provider_by_server_identifier(mcp_provider.server_identifier, tenant.id)
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service.get_provider(server_identifier=mcp_provider.server_identifier, tenant_id=tenant.id)
# Assert: Verify the expected outcomes
assert result is not None
@@ -245,8 +264,6 @@ class TestMCPToolManageService:
assert result.user_id == account.id
# Verify database state
from extensions.ext_database import db
db.session.refresh(result)
assert result.id is not None
assert result.name == mcp_provider.name
@@ -268,11 +285,14 @@ class TestMCPToolManageService:
db_session_with_containers, mock_external_service_dependencies
)
non_existent_identifier = fake.uuid4()
non_existent_identifier = str(fake.uuid4())
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool not found"):
MCPToolManageService.get_mcp_provider_by_server_identifier(non_existent_identifier, tenant.id)
service.get_provider(server_identifier=non_existent_identifier, tenant_id=tenant.id)
def test_get_mcp_provider_by_server_identifier_tenant_isolation(
self, db_session_with_containers, mock_external_service_dependencies
@@ -301,8 +321,11 @@ class TestMCPToolManageService:
)
# Act & Assert: Verify tenant isolation
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool not found"):
MCPToolManageService.get_mcp_provider_by_server_identifier(mcp_provider1.server_identifier, tenant2.id)
service.get_provider(server_identifier=mcp_provider1.server_identifier, tenant_id=tenant2.id)
def test_create_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
@@ -322,15 +345,30 @@ class TestMCPToolManageService:
)
# Setup mocks for provider creation
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
mock_external_service_dependencies["encrypter"].encrypt_token.return_value = "encrypted_server_url"
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.return_value = {
"id": "new_provider_id",
"name": "Test MCP Provider",
"type": ToolProviderType.MCP,
}
mock_external_service_dependencies[
"tool_transform_service"
].mcp_provider_to_user_provider.return_value = ToolProviderApiEntity(
id="new_provider_id",
author=account.name,
name="Test MCP Provider",
type=ToolProviderType.MCP,
description=I18nObject(en_US="Test MCP Provider Description", zh_Hans="测试MCP提供者描述"),
icon={"type": "emoji", "content": "🤖"},
label=I18nObject(en_US="Test MCP Provider", zh_Hans="测试MCP提供者"),
labels=[],
tools=[],
)
# Act: Execute the method under test
result = MCPToolManageService.create_mcp_provider(
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider",
server_url="https://example.com/mcp",
@@ -339,14 +377,16 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#FF6B6B",
server_identifier="test_identifier_123",
timeout=30.0,
sse_read_timeout=300.0,
configuration=MCPConfiguration(
timeout=30.0,
sse_read_timeout=300.0,
),
)
# Assert: Verify the expected outcomes
assert result is not None
assert result["name"] == "Test MCP Provider"
assert result["type"] == ToolProviderType.MCP
assert result.name == "Test MCP Provider"
assert result.type == ToolProviderType.MCP
# Verify database state
from extensions.ext_database import db
@@ -386,7 +426,11 @@ class TestMCPToolManageService:
)
# Create first provider
MCPToolManageService.create_mcp_provider(
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider",
server_url="https://example1.com/mcp",
@@ -395,13 +439,15 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#FF6B6B",
server_identifier="test_identifier_1",
timeout=30.0,
sse_read_timeout=300.0,
configuration=MCPConfiguration(
timeout=30.0,
sse_read_timeout=300.0,
),
)
# Act & Assert: Verify proper error handling for duplicate name
with pytest.raises(ValueError, match="MCP tool Test MCP Provider already exists"):
MCPToolManageService.create_mcp_provider(
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider", # Duplicate name
server_url="https://example2.com/mcp",
@@ -410,8 +456,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="test_identifier_2",
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
def test_create_mcp_provider_duplicate_server_url(
@@ -432,7 +480,11 @@ class TestMCPToolManageService:
)
# Create first provider
MCPToolManageService.create_mcp_provider(
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider 1",
server_url="https://example.com/mcp",
@@ -441,13 +493,15 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#FF6B6B",
server_identifier="test_identifier_1",
timeout=30.0,
sse_read_timeout=300.0,
configuration=MCPConfiguration(
timeout=30.0,
sse_read_timeout=300.0,
),
)
# Act & Assert: Verify proper error handling for duplicate server URL
with pytest.raises(ValueError, match="MCP tool https://example.com/mcp already exists"):
MCPToolManageService.create_mcp_provider(
with pytest.raises(ValueError, match="MCP tool with this server URL already exists"):
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider 2",
server_url="https://example.com/mcp", # Duplicate URL
@@ -456,8 +510,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="test_identifier_2",
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
def test_create_mcp_provider_duplicate_server_identifier(
@@ -478,7 +534,11 @@ class TestMCPToolManageService:
)
# Create first provider
MCPToolManageService.create_mcp_provider(
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider 1",
server_url="https://example1.com/mcp",
@@ -487,13 +547,15 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#FF6B6B",
server_identifier="test_identifier_123",
timeout=30.0,
sse_read_timeout=300.0,
configuration=MCPConfiguration(
timeout=30.0,
sse_read_timeout=300.0,
),
)
# Act & Assert: Verify proper error handling for duplicate server identifier
with pytest.raises(ValueError, match="MCP tool test_identifier_123 already exists"):
MCPToolManageService.create_mcp_provider(
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider 2",
server_url="https://example2.com/mcp",
@@ -502,8 +564,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="test_identifier_123", # Duplicate identifier
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
def test_retrieve_mcp_tools_success(self, db_session_with_containers, mock_external_service_dependencies):
@@ -543,23 +607,59 @@ class TestMCPToolManageService:
db.session.commit()
# Setup mock for transformation service
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.side_effect = [
{"id": provider1.id, "name": provider1.name, "type": ToolProviderType.MCP},
{"id": provider2.id, "name": provider2.name, "type": ToolProviderType.MCP},
{"id": provider3.id, "name": provider3.name, "type": ToolProviderType.MCP},
ToolProviderApiEntity(
id=provider1.id,
author=account.name,
name=provider1.name,
type=ToolProviderType.MCP,
description=I18nObject(en_US="Alpha Provider Description", zh_Hans="Alpha提供者描述"),
icon={"type": "emoji", "content": "🅰️"},
label=I18nObject(en_US=provider1.name, zh_Hans=provider1.name),
labels=[],
tools=[],
),
ToolProviderApiEntity(
id=provider2.id,
author=account.name,
name=provider2.name,
type=ToolProviderType.MCP,
description=I18nObject(en_US="Beta Provider Description", zh_Hans="Beta提供者描述"),
icon={"type": "emoji", "content": "🅱️"},
label=I18nObject(en_US=provider2.name, zh_Hans=provider2.name),
labels=[],
tools=[],
),
ToolProviderApiEntity(
id=provider3.id,
author=account.name,
name=provider3.name,
type=ToolProviderType.MCP,
description=I18nObject(en_US="Gamma Provider Description", zh_Hans="Gamma提供者描述"),
icon={"type": "emoji", "content": "Γ"},
label=I18nObject(en_US=provider3.name, zh_Hans=provider3.name),
labels=[],
tools=[],
),
]
# Act: Execute the method under test
result = MCPToolManageService.retrieve_mcp_tools(tenant.id, for_list=True)
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service.list_providers(tenant_id=tenant.id, for_list=True)
# Assert: Verify the expected outcomes
assert result is not None
assert len(result) == 3
# Verify correct ordering by name
assert result[0]["name"] == "Alpha Provider"
assert result[1]["name"] == "Beta Provider"
assert result[2]["name"] == "Gamma Provider"
assert result[0].name == "Alpha Provider"
assert result[1].name == "Beta Provider"
assert result[2].name == "Gamma Provider"
# Verify mock interactions
assert (
@@ -584,7 +684,10 @@ class TestMCPToolManageService:
# No MCP providers created for this tenant
# Act: Execute the method under test
result = MCPToolManageService.retrieve_mcp_tools(tenant.id, for_list=False)
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service.list_providers(tenant_id=tenant.id, for_list=False)
# Assert: Verify the expected outcomes
assert result is not None
@@ -624,20 +727,46 @@ class TestMCPToolManageService:
)
# Setup mock for transformation service
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.side_effect = [
{"id": provider1.id, "name": provider1.name, "type": ToolProviderType.MCP},
{"id": provider2.id, "name": provider2.name, "type": ToolProviderType.MCP},
ToolProviderApiEntity(
id=provider1.id,
author=account1.name,
name=provider1.name,
type=ToolProviderType.MCP,
description=I18nObject(en_US="Provider 1 Description", zh_Hans="提供者1描述"),
icon={"type": "emoji", "content": "1"},
label=I18nObject(en_US=provider1.name, zh_Hans=provider1.name),
labels=[],
tools=[],
),
ToolProviderApiEntity(
id=provider2.id,
author=account2.name,
name=provider2.name,
type=ToolProviderType.MCP,
description=I18nObject(en_US="Provider 2 Description", zh_Hans="提供者2描述"),
icon={"type": "emoji", "content": "2"},
label=I18nObject(en_US=provider2.name, zh_Hans=provider2.name),
labels=[],
tools=[],
),
]
# Act: Execute the method under test for both tenants
result1 = MCPToolManageService.retrieve_mcp_tools(tenant1.id, for_list=True)
result2 = MCPToolManageService.retrieve_mcp_tools(tenant2.id, for_list=True)
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result1 = service.list_providers(tenant_id=tenant1.id, for_list=True)
result2 = service.list_providers(tenant_id=tenant2.id, for_list=True)
# Assert: Verify tenant isolation
assert len(result1) == 1
assert len(result2) == 1
assert result1[0]["id"] == provider1.id
assert result2[0]["id"] == provider2.id
assert result1[0].id == provider1.id
assert result2[0].id == provider2.id
def test_list_mcp_tool_from_remote_server_success(
self, db_session_with_containers, mock_external_service_dependencies
@@ -661,17 +790,20 @@ class TestMCPToolManageService:
mcp_provider = self._create_test_mcp_provider(
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
)
mcp_provider.server_url = "encrypted_server_url"
mcp_provider.authed = False
# Use a valid base64 encoded string to avoid decryption errors
import base64
mcp_provider.server_url = base64.b64encode(b"encrypted_server_url").decode()
mcp_provider.authed = True # Provider must be authenticated to list tools
mcp_provider.tools = "[]"
from extensions.ext_database import db
db.session.commit()
# Mock the decrypted_server_url property to avoid encryption issues
with patch("models.tools.encrypter") as mock_encrypter:
mock_encrypter.decrypt_token.return_value = "https://example.com/mcp"
# Mock the decryption process at the rsa level to avoid key file issues
with patch("libs.rsa.decrypt") as mock_decrypt:
mock_decrypt.return_value = "https://example.com/mcp"
# Mock MCPClient and its context manager
mock_tools = [
@@ -683,13 +815,16 @@ class TestMCPToolManageService:
)(),
]
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
# Setup mock client
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.return_value = mock_tools
# Act: Execute the method under test
result = MCPToolManageService.list_mcp_tool_from_remote_server(tenant.id, mcp_provider.id)
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Assert: Verify the expected outcomes
assert result is not None
@@ -705,16 +840,8 @@ class TestMCPToolManageService:
assert mcp_provider.updated_at is not None
# Verify mock interactions
mock_mcp_client.assert_called_once_with(
"https://example.com/mcp",
mcp_provider.id,
tenant.id,
authed=False,
for_list=True,
headers={},
timeout=30.0,
sse_read_timeout=300.0,
)
# MCPClientWithAuthRetry is called with different parameters
mock_mcp_client.assert_called_once()
def test_list_mcp_tool_from_remote_server_auth_error(
self, db_session_with_containers, mock_external_service_dependencies
@@ -737,7 +864,10 @@ class TestMCPToolManageService:
mcp_provider = self._create_test_mcp_provider(
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
)
mcp_provider.server_url = "encrypted_server_url"
# Use a valid base64 encoded string to avoid decryption errors
import base64
mcp_provider.server_url = base64.b64encode(b"encrypted_server_url").decode()
mcp_provider.authed = False
mcp_provider.tools = "[]"
@@ -745,20 +875,23 @@ class TestMCPToolManageService:
db.session.commit()
# Mock the decrypted_server_url property to avoid encryption issues
with patch("models.tools.encrypter") as mock_encrypter:
mock_encrypter.decrypt_token.return_value = "https://example.com/mcp"
# Mock the decryption process at the rsa level to avoid key file issues
with patch("libs.rsa.decrypt") as mock_decrypt:
mock_decrypt.return_value = "https://example.com/mcp"
# Mock MCPClient to raise authentication error
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
from core.mcp.error import MCPAuthError
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="Please auth the tool first"):
MCPToolManageService.list_mcp_tool_from_remote_server(tenant.id, mcp_provider.id)
service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Verify database state was not changed
db.session.refresh(mcp_provider)
@@ -786,32 +919,38 @@ class TestMCPToolManageService:
mcp_provider = self._create_test_mcp_provider(
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
)
mcp_provider.server_url = "encrypted_server_url"
mcp_provider.authed = False
# Use a valid base64 encoded string to avoid decryption errors
import base64
mcp_provider.server_url = base64.b64encode(b"encrypted_server_url").decode()
mcp_provider.authed = True # Provider must be authenticated to test connection errors
mcp_provider.tools = "[]"
from extensions.ext_database import db
db.session.commit()
# Mock the decrypted_server_url property to avoid encryption issues
with patch("models.tools.encrypter") as mock_encrypter:
mock_encrypter.decrypt_token.return_value = "https://example.com/mcp"
# Mock the decryption process at the rsa level to avoid key file issues
with patch("libs.rsa.decrypt") as mock_decrypt:
mock_decrypt.return_value = "https://example.com/mcp"
# Mock MCPClient to raise connection error
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
from core.mcp.error import MCPError
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="Failed to connect to MCP server: Connection failed"):
MCPToolManageService.list_mcp_tool_from_remote_server(tenant.id, mcp_provider.id)
service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Verify database state was not changed
db.session.refresh(mcp_provider)
assert mcp_provider.authed is False
assert mcp_provider.authed is True # Provider remains authenticated
assert mcp_provider.tools == "[]"
def test_delete_mcp_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
@@ -840,7 +979,8 @@ class TestMCPToolManageService:
assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None
# Act: Execute the method under test
MCPToolManageService.delete_mcp_tool(tenant.id, mcp_provider.id)
service = MCPToolManageService(db.session())
service.delete_provider(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Assert: Verify the expected outcomes
# Provider should be deleted from database
@@ -862,11 +1002,14 @@ class TestMCPToolManageService:
db_session_with_containers, mock_external_service_dependencies
)
non_existent_id = fake.uuid4()
non_existent_id = str(fake.uuid4())
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool not found"):
MCPToolManageService.delete_mcp_tool(tenant.id, non_existent_id)
service.delete_provider(tenant_id=tenant.id, provider_id=non_existent_id)
def test_delete_mcp_tool_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies):
"""
@@ -893,8 +1036,11 @@ class TestMCPToolManageService:
)
# Act & Assert: Verify tenant isolation
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool not found"):
MCPToolManageService.delete_mcp_tool(tenant2.id, mcp_provider1.id)
service.delete_provider(tenant_id=tenant2.id, provider_id=mcp_provider1.id)
# Verify provider still exists in tenant1
from extensions.ext_database import db
@@ -929,7 +1075,10 @@ class TestMCPToolManageService:
db.session.commit()
# Act: Execute the method under test
MCPToolManageService.update_mcp_provider(
from core.entities.mcp_provider import MCPConfiguration
service = MCPToolManageService(db.session())
service.update_provider(
tenant_id=tenant.id,
provider_id=mcp_provider.id,
name="Updated MCP Provider",
@@ -938,8 +1087,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="updated_identifier_123",
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
# Assert: Verify the expected outcomes
@@ -953,70 +1104,10 @@ class TestMCPToolManageService:
# Verify icon was updated
import json
icon_data = json.loads(mcp_provider.icon)
icon_data = json.loads(mcp_provider.icon or "{}")
assert icon_data["content"] == "🚀"
assert icon_data["background"] == "#4ECDC4"
def test_update_mcp_provider_with_server_url_change(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful update of MCP provider with server URL change.
This test verifies:
- Proper handling of server URL changes
- Correct reconnection logic
- Database state updates
- External service integration
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Create MCP provider
mcp_provider = self._create_test_mcp_provider(
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
)
from extensions.ext_database import db
db.session.commit()
# Mock the reconnection method
with patch.object(MCPToolManageService, "_re_connect_mcp_provider") as mock_reconnect:
mock_reconnect.return_value = {
"authed": True,
"tools": '[{"name": "test_tool"}]',
"encrypted_credentials": "{}",
}
# Act: Execute the method under test
MCPToolManageService.update_mcp_provider(
tenant_id=tenant.id,
provider_id=mcp_provider.id,
name="Updated MCP Provider",
server_url="https://new-example.com/mcp",
icon="🚀",
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="updated_identifier_123",
timeout=45.0,
sse_read_timeout=400.0,
)
# Assert: Verify the expected outcomes
db.session.refresh(mcp_provider)
assert mcp_provider.name == "Updated MCP Provider"
assert mcp_provider.server_identifier == "updated_identifier_123"
assert mcp_provider.timeout == 45.0
assert mcp_provider.sse_read_timeout == 400.0
assert mcp_provider.updated_at is not None
# Verify reconnection was called
mock_reconnect.assert_called_once_with("https://new-example.com/mcp", mcp_provider.id, tenant.id)
def test_update_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test error handling when updating MCP provider with duplicate name.
@@ -1048,8 +1139,12 @@ class TestMCPToolManageService:
db.session.commit()
# Act & Assert: Verify proper error handling for duplicate name
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool First Provider already exists"):
MCPToolManageService.update_mcp_provider(
service.update_provider(
tenant_id=tenant.id,
provider_id=provider2.id,
name="First Provider", # Duplicate name
@@ -1058,8 +1153,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="unique_identifier",
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
def test_update_mcp_provider_credentials_success(
@@ -1094,19 +1191,25 @@ class TestMCPToolManageService:
# Mock the provider controller and encryption
with (
patch("services.tools.mcp_tools_manage_service.MCPToolProviderController") as mock_controller,
patch("services.tools.mcp_tools_manage_service.ProviderConfigEncrypter") as mock_encrypter,
patch("core.tools.mcp_tool.provider.MCPToolProviderController") as mock_controller,
patch("core.tools.utils.encryption.ProviderConfigEncrypter") as mock_encrypter,
):
# Setup mocks
mock_controller_instance = mock_controller._from_db.return_value
mock_controller_instance = mock_controller.from_db.return_value
mock_controller_instance.get_credentials_schema.return_value = []
mock_encrypter_instance = mock_encrypter.return_value
mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"}
# Act: Execute the method under test
MCPToolManageService.update_mcp_provider_credentials(
mcp_provider=mcp_provider, credentials={"new_key": "new_value"}, authed=True
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service.update_provider_credentials(
provider_id=mcp_provider.id,
tenant_id=tenant.id,
credentials={"new_key": "new_value"},
authed=True,
)
# Assert: Verify the expected outcomes
@@ -1117,7 +1220,7 @@ class TestMCPToolManageService:
# Verify credentials were encrypted and merged
import json
credentials = json.loads(mcp_provider.encrypted_credentials)
credentials = json.loads(mcp_provider.encrypted_credentials or "{}")
assert "existing_key" in credentials
assert "new_key" in credentials
@@ -1152,19 +1255,25 @@ class TestMCPToolManageService:
# Mock the provider controller and encryption
with (
patch("services.tools.mcp_tools_manage_service.MCPToolProviderController") as mock_controller,
patch("services.tools.mcp_tools_manage_service.ProviderConfigEncrypter") as mock_encrypter,
patch("core.tools.mcp_tool.provider.MCPToolProviderController") as mock_controller,
patch("core.tools.utils.encryption.ProviderConfigEncrypter") as mock_encrypter,
):
# Setup mocks
mock_controller_instance = mock_controller._from_db.return_value
mock_controller_instance = mock_controller.from_db.return_value
mock_controller_instance.get_credentials_schema.return_value = []
mock_encrypter_instance = mock_encrypter.return_value
mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"}
# Act: Execute the method under test
MCPToolManageService.update_mcp_provider_credentials(
mcp_provider=mcp_provider, credentials={"new_key": "new_value"}, authed=False
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service.update_provider_credentials(
provider_id=mcp_provider.id,
tenant_id=tenant.id,
credentials={"new_key": "new_value"},
authed=False,
)
# Assert: Verify the expected outcomes
@@ -1199,41 +1308,37 @@ class TestMCPToolManageService:
type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(),
]
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
# Setup mock client
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.return_value = mock_tools
# Act: Execute the method under test
result = MCPToolManageService._re_connect_mcp_provider(
"https://example.com/mcp", mcp_provider.id, tenant.id
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service._reconnect_provider(
server_url="https://example.com/mcp",
provider=mcp_provider,
)
# Assert: Verify the expected outcomes
assert result is not None
assert result["authed"] is True
assert result["tools"] is not None
assert result["encrypted_credentials"] == "{}"
assert result.authed is True
assert result.tools is not None
assert result.encrypted_credentials == "{}"
# Verify tools were properly serialized
import json
tools_data = json.loads(result["tools"])
tools_data = json.loads(result.tools)
assert len(tools_data) == 2
assert tools_data[0]["name"] == "test_tool_1"
assert tools_data[1]["name"] == "test_tool_2"
# Verify mock interactions
mock_mcp_client.assert_called_once_with(
"https://example.com/mcp",
mcp_provider.id,
tenant.id,
authed=False,
for_list=True,
headers={},
timeout=30.0,
sse_read_timeout=300.0,
)
provider_entity = mcp_provider.to_entity()
mock_mcp_client.assert_called_once()
def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
"""
@@ -1256,22 +1361,26 @@ class TestMCPToolManageService:
)
# Mock MCPClient to raise authentication error
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
from core.mcp.error import MCPAuthError
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
# Act: Execute the method under test
result = MCPToolManageService._re_connect_mcp_provider(
"https://example.com/mcp", mcp_provider.id, tenant.id
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service._reconnect_provider(
server_url="https://example.com/mcp",
provider=mcp_provider,
)
# Assert: Verify the expected outcomes
assert result is not None
assert result["authed"] is False
assert result["tools"] == "[]"
assert result["encrypted_credentials"] == "{}"
assert result.authed is False
assert result.tools == "[]"
assert result.encrypted_credentials == "{}"
def test_re_connect_mcp_provider_connection_error(
self, db_session_with_containers, mock_external_service_dependencies
@@ -1295,12 +1404,18 @@ class TestMCPToolManageService:
)
# Mock MCPClient to raise connection error
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
from core.mcp.error import MCPError
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"):
MCPToolManageService._re_connect_mcp_provider("https://example.com/mcp", mcp_provider.id, tenant.id)
service._reconnect_provider(
server_url="https://example.com/mcp",
provider=mcp_provider,
)

View File

@@ -0,0 +1,740 @@
"""Unit tests for MCP OAuth authentication flow."""
from unittest.mock import Mock, patch
import pytest
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.auth.auth_flow import (
OAUTH_STATE_EXPIRY_SECONDS,
OAUTH_STATE_REDIS_KEY_PREFIX,
OAuthCallbackState,
_create_secure_redis_state,
_retrieve_redis_state,
auth,
check_support_resource_discovery,
discover_oauth_metadata,
exchange_authorization,
generate_pkce_challenge,
handle_callback,
refresh_authorization,
register_client,
start_authorization,
)
from core.mcp.entities import AuthActionType, AuthResult
from core.mcp.types import (
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthTokens,
)
class TestPKCEGeneration:
"""Test PKCE challenge generation."""
def test_generate_pkce_challenge(self):
"""Test PKCE challenge and verifier generation."""
code_verifier, code_challenge = generate_pkce_challenge()
# Verify format - should be URL-safe base64 without padding
assert "=" not in code_verifier
assert "+" not in code_verifier
assert "/" not in code_verifier
assert "=" not in code_challenge
assert "+" not in code_challenge
assert "/" not in code_challenge
# Verify length
assert len(code_verifier) > 40 # Should be around 54 characters
assert len(code_challenge) > 40 # Should be around 43 characters
def test_generate_pkce_challenge_uniqueness(self):
"""Test that PKCE generation produces unique values."""
results = set()
for _ in range(10):
code_verifier, code_challenge = generate_pkce_challenge()
results.add((code_verifier, code_challenge))
# All should be unique
assert len(results) == 10
class TestRedisStateManagement:
"""Test Redis state management functions."""
@patch("core.mcp.auth.auth_flow.redis_client")
def test_create_secure_redis_state(self, mock_redis):
"""Test creating secure Redis state."""
state_data = OAuthCallbackState(
provider_id="test-provider",
tenant_id="test-tenant",
server_url="https://example.com",
metadata=None,
client_information=OAuthClientInformation(client_id="test-client"),
code_verifier="test-verifier",
redirect_uri="https://redirect.example.com",
)
state_key = _create_secure_redis_state(state_data)
# Verify state key format
assert len(state_key) > 20 # Should be a secure random token
# Verify Redis call
mock_redis.setex.assert_called_once()
call_args = mock_redis.setex.call_args
assert call_args[0][0].startswith(OAUTH_STATE_REDIS_KEY_PREFIX)
assert call_args[0][1] == OAUTH_STATE_EXPIRY_SECONDS
assert state_data.model_dump_json() in call_args[0][2]
@patch("core.mcp.auth.auth_flow.redis_client")
def test_retrieve_redis_state_success(self, mock_redis):
"""Test retrieving state from Redis."""
state_data = OAuthCallbackState(
provider_id="test-provider",
tenant_id="test-tenant",
server_url="https://example.com",
metadata=None,
client_information=OAuthClientInformation(client_id="test-client"),
code_verifier="test-verifier",
redirect_uri="https://redirect.example.com",
)
mock_redis.get.return_value = state_data.model_dump_json()
result = _retrieve_redis_state("test-state-key")
# Verify result
assert result.provider_id == "test-provider"
assert result.tenant_id == "test-tenant"
assert result.server_url == "https://example.com"
# Verify Redis calls
mock_redis.get.assert_called_once_with(f"{OAUTH_STATE_REDIS_KEY_PREFIX}test-state-key")
mock_redis.delete.assert_called_once_with(f"{OAUTH_STATE_REDIS_KEY_PREFIX}test-state-key")
@patch("core.mcp.auth.auth_flow.redis_client")
def test_retrieve_redis_state_not_found(self, mock_redis):
"""Test retrieving non-existent state from Redis."""
mock_redis.get.return_value = None
with pytest.raises(ValueError) as exc_info:
_retrieve_redis_state("nonexistent-key")
assert "State parameter has expired or does not exist" in str(exc_info.value)
@patch("core.mcp.auth.auth_flow.redis_client")
def test_retrieve_redis_state_invalid_json(self, mock_redis):
"""Test retrieving invalid JSON state from Redis."""
mock_redis.get.return_value = '{"invalid": json}'
with pytest.raises(ValueError) as exc_info:
_retrieve_redis_state("test-key")
assert "Invalid state parameter" in str(exc_info.value)
# State should still be deleted
mock_redis.delete.assert_called_once()
class TestOAuthDiscovery:
"""Test OAuth discovery functions."""
@patch("core.helper.ssrf_proxy.get")
def test_check_support_resource_discovery_success(self, mock_get):
"""Test successful resource discovery check."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"authorization_server_url": ["https://auth.example.com"]}
mock_get.return_value = mock_response
supported, auth_url = check_support_resource_discovery("https://api.example.com/endpoint")
assert supported is True
assert auth_url == "https://auth.example.com"
mock_get.assert_called_once_with(
"https://api.example.com/.well-known/oauth-protected-resource",
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
)
@patch("core.helper.ssrf_proxy.get")
def test_check_support_resource_discovery_not_supported(self, mock_get):
"""Test resource discovery not supported."""
mock_response = Mock()
mock_response.status_code = 404
mock_get.return_value = mock_response
supported, auth_url = check_support_resource_discovery("https://api.example.com")
assert supported is False
assert auth_url == ""
@patch("core.helper.ssrf_proxy.get")
def test_check_support_resource_discovery_with_query_fragment(self, mock_get):
"""Test resource discovery with query and fragment."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"authorization_server_url": ["https://auth.example.com"]}
mock_get.return_value = mock_response
supported, auth_url = check_support_resource_discovery("https://api.example.com/path?query=1#fragment")
assert supported is True
assert auth_url == "https://auth.example.com"
mock_get.assert_called_once_with(
"https://api.example.com/.well-known/oauth-protected-resource?query=1#fragment",
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
)
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_metadata_with_resource_discovery(self, mock_get):
"""Test OAuth metadata discovery with resource discovery support."""
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
mock_check.return_value = (True, "https://auth.example.com")
mock_response = Mock()
mock_response.status_code = 200
mock_response.is_success = True
mock_response.json.return_value = {
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
"response_types_supported": ["code"],
}
mock_get.return_value = mock_response
metadata = discover_oauth_metadata("https://api.example.com")
assert metadata is not None
assert metadata.authorization_endpoint == "https://auth.example.com/authorize"
assert metadata.token_endpoint == "https://auth.example.com/token"
mock_get.assert_called_once_with(
"https://auth.example.com/.well-known/oauth-authorization-server",
headers={"MCP-Protocol-Version": "2025-03-26"},
)
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_metadata_without_resource_discovery(self, mock_get):
"""Test OAuth metadata discovery without resource discovery."""
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
mock_check.return_value = (False, "")
mock_response = Mock()
mock_response.status_code = 200
mock_response.is_success = True
mock_response.json.return_value = {
"authorization_endpoint": "https://api.example.com/oauth/authorize",
"token_endpoint": "https://api.example.com/oauth/token",
"response_types_supported": ["code"],
}
mock_get.return_value = mock_response
metadata = discover_oauth_metadata("https://api.example.com")
assert metadata is not None
assert metadata.authorization_endpoint == "https://api.example.com/oauth/authorize"
mock_get.assert_called_once_with(
"https://api.example.com/.well-known/oauth-authorization-server",
headers={"MCP-Protocol-Version": "2025-03-26"},
)
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_metadata_not_found(self, mock_get):
"""Test OAuth metadata discovery when not found."""
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
mock_check.return_value = (False, "")
mock_response = Mock()
mock_response.status_code = 404
mock_get.return_value = mock_response
metadata = discover_oauth_metadata("https://api.example.com")
assert metadata is None
class TestAuthorizationFlow:
"""Test authorization flow functions."""
@patch("core.mcp.auth.auth_flow._create_secure_redis_state")
def test_start_authorization_with_metadata(self, mock_create_state):
"""Test starting authorization with metadata."""
mock_create_state.return_value = "secure-state-key"
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
code_challenge_methods_supported=["S256"],
)
client_info = OAuthClientInformation(client_id="test-client-id")
auth_url, code_verifier = start_authorization(
"https://api.example.com",
metadata,
client_info,
"https://redirect.example.com",
"provider-id",
"tenant-id",
)
# Verify URL format
assert auth_url.startswith("https://auth.example.com/authorize?")
assert "response_type=code" in auth_url
assert "client_id=test-client-id" in auth_url
assert "code_challenge=" in auth_url
assert "code_challenge_method=S256" in auth_url
assert "redirect_uri=https%3A%2F%2Fredirect.example.com" in auth_url
assert "state=secure-state-key" in auth_url
# Verify code verifier
assert len(code_verifier) > 40
# Verify state was stored
mock_create_state.assert_called_once()
state_data = mock_create_state.call_args[0][0]
assert state_data.provider_id == "provider-id"
assert state_data.tenant_id == "tenant-id"
assert state_data.code_verifier == code_verifier
def test_start_authorization_without_metadata(self):
"""Test starting authorization without metadata."""
with patch("core.mcp.auth.auth_flow._create_secure_redis_state") as mock_create_state:
mock_create_state.return_value = "secure-state-key"
client_info = OAuthClientInformation(client_id="test-client-id")
auth_url, code_verifier = start_authorization(
"https://api.example.com",
None,
client_info,
"https://redirect.example.com",
"provider-id",
"tenant-id",
)
# Should use default authorization endpoint
assert auth_url.startswith("https://api.example.com/authorize?")
def test_start_authorization_invalid_metadata(self):
"""Test starting authorization with invalid metadata."""
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["token"], # No "code" support
code_challenge_methods_supported=["plain"], # No "S256" support
)
client_info = OAuthClientInformation(client_id="test-client-id")
with pytest.raises(ValueError) as exc_info:
start_authorization(
"https://api.example.com",
metadata,
client_info,
"https://redirect.example.com",
"provider-id",
"tenant-id",
)
assert "does not support response type code" in str(exc_info.value)
@patch("core.helper.ssrf_proxy.post")
def test_exchange_authorization_success(self, mock_post):
"""Test successful authorization code exchange."""
mock_response = Mock()
mock_response.is_success = True
mock_response.json.return_value = {
"access_token": "new-access-token",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "new-refresh-token",
}
mock_post.return_value = mock_response
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
client_info = OAuthClientInformation(client_id="test-client-id", client_secret="test-secret")
tokens = exchange_authorization(
"https://api.example.com",
metadata,
client_info,
"auth-code-123",
"code-verifier-xyz",
"https://redirect.example.com",
)
assert tokens.access_token == "new-access-token"
assert tokens.token_type == "Bearer"
assert tokens.expires_in == 3600
assert tokens.refresh_token == "new-refresh-token"
# Verify request
mock_post.assert_called_once_with(
"https://auth.example.com/token",
data={
"grant_type": "authorization_code",
"client_id": "test-client-id",
"client_secret": "test-secret",
"code": "auth-code-123",
"code_verifier": "code-verifier-xyz",
"redirect_uri": "https://redirect.example.com",
},
)
@patch("core.helper.ssrf_proxy.post")
def test_exchange_authorization_failure(self, mock_post):
"""Test failed authorization code exchange."""
mock_response = Mock()
mock_response.is_success = False
mock_response.status_code = 400
mock_post.return_value = mock_response
client_info = OAuthClientInformation(client_id="test-client-id")
with pytest.raises(ValueError) as exc_info:
exchange_authorization(
"https://api.example.com",
None,
client_info,
"invalid-code",
"code-verifier",
"https://redirect.example.com",
)
assert "Token exchange failed: HTTP 400" in str(exc_info.value)
@patch("core.helper.ssrf_proxy.post")
def test_refresh_authorization_success(self, mock_post):
"""Test successful token refresh."""
mock_response = Mock()
mock_response.is_success = True
mock_response.json.return_value = {
"access_token": "refreshed-access-token",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "new-refresh-token",
}
mock_post.return_value = mock_response
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["refresh_token"],
)
client_info = OAuthClientInformation(client_id="test-client-id")
tokens = refresh_authorization("https://api.example.com", metadata, client_info, "old-refresh-token")
assert tokens.access_token == "refreshed-access-token"
assert tokens.refresh_token == "new-refresh-token"
# Verify request
mock_post.assert_called_once_with(
"https://auth.example.com/token",
data={
"grant_type": "refresh_token",
"client_id": "test-client-id",
"refresh_token": "old-refresh-token",
},
)
@patch("core.helper.ssrf_proxy.post")
def test_register_client_success(self, mock_post):
"""Test successful client registration."""
mock_response = Mock()
mock_response.is_success = True
mock_response.json.return_value = {
"client_id": "new-client-id",
"client_secret": "new-client-secret",
"client_name": "Dify",
"redirect_uris": ["https://redirect.example.com"],
}
mock_post.return_value = mock_response
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
registration_endpoint="https://auth.example.com/register",
response_types_supported=["code"],
)
client_metadata = OAuthClientMetadata(
client_name="Dify",
redirect_uris=["https://redirect.example.com"],
grant_types=["authorization_code"],
response_types=["code"],
)
client_info = register_client("https://api.example.com", metadata, client_metadata)
assert isinstance(client_info, OAuthClientInformationFull)
assert client_info.client_id == "new-client-id"
assert client_info.client_secret == "new-client-secret"
# Verify request
mock_post.assert_called_once_with(
"https://auth.example.com/register",
json=client_metadata.model_dump(),
headers={"Content-Type": "application/json"},
)
def test_register_client_no_endpoint(self):
"""Test client registration when no endpoint available."""
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
registration_endpoint=None,
response_types_supported=["code"],
)
client_metadata = OAuthClientMetadata(client_name="Dify", redirect_uris=["https://redirect.example.com"])
with pytest.raises(ValueError) as exc_info:
register_client("https://api.example.com", metadata, client_metadata)
assert "does not support dynamic client registration" in str(exc_info.value)
class TestCallbackHandling:
"""Test OAuth callback handling."""
@patch("core.mcp.auth.auth_flow._retrieve_redis_state")
@patch("core.mcp.auth.auth_flow.exchange_authorization")
def test_handle_callback_success(self, mock_exchange, mock_retrieve_state):
"""Test successful callback handling."""
# Setup state
state_data = OAuthCallbackState(
provider_id="test-provider",
tenant_id="test-tenant",
server_url="https://api.example.com",
metadata=None,
client_information=OAuthClientInformation(client_id="test-client"),
code_verifier="test-verifier",
redirect_uri="https://redirect.example.com",
)
mock_retrieve_state.return_value = state_data
# Setup token exchange
tokens = OAuthTokens(
access_token="new-token",
token_type="Bearer",
expires_in=3600,
)
mock_exchange.return_value = tokens
# Setup service
mock_service = Mock()
state_result, tokens_result = handle_callback("state-key", "auth-code")
assert state_result == state_data
assert tokens_result == tokens
# Verify calls
mock_retrieve_state.assert_called_once_with("state-key")
mock_exchange.assert_called_once_with(
"https://api.example.com",
None,
state_data.client_information,
"auth-code",
"test-verifier",
"https://redirect.example.com",
)
# Note: handle_callback no longer saves tokens directly, it just returns them
# The caller (e.g., controller) is responsible for saving via execute_auth_actions
class TestAuthOrchestration:
"""Test the main auth orchestration function."""
@pytest.fixture
def mock_provider(self):
"""Create a mock provider entity."""
provider = Mock(spec=MCPProviderEntity)
provider.id = "provider-id"
provider.tenant_id = "tenant-id"
provider.decrypt_server_url.return_value = "https://api.example.com"
provider.client_metadata = OAuthClientMetadata(
client_name="Dify",
redirect_uris=["https://redirect.example.com"],
)
provider.redirect_url = "https://redirect.example.com"
provider.retrieve_client_information.return_value = None
provider.retrieve_tokens.return_value = None
return provider
@pytest.fixture
def mock_service(self):
"""Create a mock MCP service."""
return Mock()
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
@patch("core.mcp.auth.auth_flow.register_client")
@patch("core.mcp.auth.auth_flow.start_authorization")
def test_auth_new_registration(self, mock_start_auth, mock_register, mock_discover, mock_provider, mock_service):
"""Test auth flow for new client registration."""
# Setup
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
mock_register.return_value = OAuthClientInformationFull(
client_id="new-client-id",
client_name="Dify",
redirect_uris=["https://redirect.example.com"],
)
mock_start_auth.return_value = ("https://auth.example.com/authorize?...", "code-verifier")
result = auth(mock_provider)
# auth() now returns AuthResult
assert isinstance(result, AuthResult)
assert result.response == {"authorization_url": "https://auth.example.com/authorize?..."}
# Verify that the result contains the correct actions
assert len(result.actions) == 2
# Check for SAVE_CLIENT_INFO action
client_info_action = next(a for a in result.actions if a.action_type == AuthActionType.SAVE_CLIENT_INFO)
assert client_info_action.data == {"client_information": mock_register.return_value.model_dump()}
assert client_info_action.provider_id == "provider-id"
assert client_info_action.tenant_id == "tenant-id"
# Check for SAVE_CODE_VERIFIER action
verifier_action = next(a for a in result.actions if a.action_type == AuthActionType.SAVE_CODE_VERIFIER)
assert verifier_action.data == {"code_verifier": "code-verifier"}
assert verifier_action.provider_id == "provider-id"
assert verifier_action.tenant_id == "tenant-id"
# Verify calls
mock_register.assert_called_once()
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
@patch("core.mcp.auth.auth_flow._retrieve_redis_state")
@patch("core.mcp.auth.auth_flow.exchange_authorization")
def test_auth_exchange_code(self, mock_exchange, mock_retrieve_state, mock_discover, mock_provider, mock_service):
"""Test auth flow for exchanging authorization code."""
# Setup metadata discovery
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
# Setup existing client
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
# Setup state retrieval
state_data = OAuthCallbackState(
provider_id="provider-id",
tenant_id="tenant-id",
server_url="https://api.example.com",
metadata=None,
client_information=OAuthClientInformation(client_id="existing-client"),
code_verifier="test-verifier",
redirect_uri="https://redirect.example.com",
)
mock_retrieve_state.return_value = state_data
# Setup token exchange
tokens = OAuthTokens(access_token="new-token", token_type="Bearer", expires_in=3600)
mock_exchange.return_value = tokens
result = auth(mock_provider, authorization_code="auth-code", state_param="state-key")
# auth() now returns AuthResult, not a dict
assert isinstance(result, AuthResult)
assert result.response == {"result": "success"}
# Verify that the result contains the correct action
assert len(result.actions) == 1
assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS
assert result.actions[0].data == tokens.model_dump()
assert result.actions[0].provider_id == "provider-id"
assert result.actions[0].tenant_id == "tenant-id"
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
def test_auth_exchange_code_without_state(self, mock_discover, mock_provider, mock_service):
"""Test auth flow fails when exchanging code without state."""
# Setup metadata discovery
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
with pytest.raises(ValueError) as exc_info:
auth(mock_provider, authorization_code="auth-code")
assert "State parameter is required" in str(exc_info.value)
@patch("core.mcp.auth.auth_flow.refresh_authorization")
def test_auth_refresh_token(self, mock_refresh, mock_provider, mock_service):
"""Test auth flow for refreshing tokens."""
# Setup existing client and tokens
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
mock_provider.retrieve_tokens.return_value = OAuthTokens(
access_token="old-token",
token_type="Bearer",
expires_in=0,
refresh_token="refresh-token",
)
# Setup refresh
new_tokens = OAuthTokens(
access_token="refreshed-token",
token_type="Bearer",
expires_in=3600,
refresh_token="new-refresh-token",
)
mock_refresh.return_value = new_tokens
with patch("core.mcp.auth.auth_flow.discover_oauth_metadata") as mock_discover:
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
result = auth(mock_provider)
# auth() now returns AuthResult
assert isinstance(result, AuthResult)
assert result.response == {"result": "success"}
# Verify that the result contains the correct action
assert len(result.actions) == 1
assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS
assert result.actions[0].data == new_tokens.model_dump()
assert result.actions[0].provider_id == "provider-id"
assert result.actions[0].tenant_id == "tenant-id"
# Verify refresh was called
mock_refresh.assert_called_once()
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
def test_auth_registration_fails_with_code(self, mock_discover, mock_provider, mock_service):
"""Test auth fails when no client info exists but code is provided."""
# Setup metadata discovery
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
mock_provider.retrieve_client_information.return_value = None
with pytest.raises(ValueError) as exc_info:
auth(mock_provider, authorization_code="auth-code")
assert "Existing OAuth client information is required" in str(exc_info.value)

View File

@@ -0,0 +1,239 @@
"""Unit tests for MCP entities module."""
from unittest.mock import Mock
from core.mcp.entities import (
SUPPORTED_PROTOCOL_VERSIONS,
LifespanContextT,
RequestContext,
SessionT,
)
from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams
class TestProtocolVersions:
"""Test protocol version constants."""
def test_supported_protocol_versions(self):
"""Test supported protocol versions list."""
assert isinstance(SUPPORTED_PROTOCOL_VERSIONS, list)
assert len(SUPPORTED_PROTOCOL_VERSIONS) >= 3
assert "2024-11-05" in SUPPORTED_PROTOCOL_VERSIONS
assert "2025-03-26" in SUPPORTED_PROTOCOL_VERSIONS
assert LATEST_PROTOCOL_VERSION in SUPPORTED_PROTOCOL_VERSIONS
def test_latest_protocol_version_is_supported(self):
"""Test that latest protocol version is in supported versions."""
assert LATEST_PROTOCOL_VERSION in SUPPORTED_PROTOCOL_VERSIONS
class TestRequestContext:
"""Test RequestContext dataclass."""
def test_request_context_creation(self):
"""Test creating a RequestContext instance."""
mock_session = Mock(spec=BaseSession)
mock_lifespan = {"key": "value"}
mock_meta = RequestParams.Meta(progressToken="test-token")
context = RequestContext(
request_id="test-request-123",
meta=mock_meta,
session=mock_session,
lifespan_context=mock_lifespan,
)
assert context.request_id == "test-request-123"
assert context.meta == mock_meta
assert context.session == mock_session
assert context.lifespan_context == mock_lifespan
def test_request_context_with_none_meta(self):
"""Test creating RequestContext with None meta."""
mock_session = Mock(spec=BaseSession)
context = RequestContext(
request_id=42, # Can be int or string
meta=None,
session=mock_session,
lifespan_context=None,
)
assert context.request_id == 42
assert context.meta is None
assert context.session == mock_session
assert context.lifespan_context is None
def test_request_context_attributes(self):
"""Test RequestContext attributes are accessible."""
mock_session = Mock(spec=BaseSession)
context = RequestContext(
request_id="test-123",
meta=None,
session=mock_session,
lifespan_context=None,
)
# Verify attributes are accessible
assert hasattr(context, "request_id")
assert hasattr(context, "meta")
assert hasattr(context, "session")
assert hasattr(context, "lifespan_context")
# Verify values
assert context.request_id == "test-123"
assert context.meta is None
assert context.session == mock_session
assert context.lifespan_context is None
def test_request_context_generic_typing(self):
"""Test RequestContext with different generic types."""
# Create a mock session with specific type
mock_session = Mock(spec=BaseSession)
# Create context with string lifespan context
context_str = RequestContext[BaseSession, str](
request_id="test-1",
meta=None,
session=mock_session,
lifespan_context="string-context",
)
assert isinstance(context_str.lifespan_context, str)
# Create context with dict lifespan context
context_dict = RequestContext[BaseSession, dict](
request_id="test-2",
meta=None,
session=mock_session,
lifespan_context={"key": "value"},
)
assert isinstance(context_dict.lifespan_context, dict)
# Create context with custom object lifespan context
class CustomLifespan:
def __init__(self, data):
self.data = data
custom_lifespan = CustomLifespan("test-data")
context_custom = RequestContext[BaseSession, CustomLifespan](
request_id="test-3",
meta=None,
session=mock_session,
lifespan_context=custom_lifespan,
)
assert isinstance(context_custom.lifespan_context, CustomLifespan)
assert context_custom.lifespan_context.data == "test-data"
def test_request_context_with_progress_meta(self):
"""Test RequestContext with progress metadata."""
mock_session = Mock(spec=BaseSession)
progress_meta = RequestParams.Meta(progressToken="progress-123")
context = RequestContext(
request_id="req-456",
meta=progress_meta,
session=mock_session,
lifespan_context=None,
)
assert context.meta is not None
assert context.meta.progressToken == "progress-123"
def test_request_context_equality(self):
"""Test RequestContext equality comparison."""
mock_session1 = Mock(spec=BaseSession)
mock_session2 = Mock(spec=BaseSession)
context1 = RequestContext(
request_id="test-123",
meta=None,
session=mock_session1,
lifespan_context="context",
)
context2 = RequestContext(
request_id="test-123",
meta=None,
session=mock_session1,
lifespan_context="context",
)
context3 = RequestContext(
request_id="test-456",
meta=None,
session=mock_session1,
lifespan_context="context",
)
# Same values should be equal
assert context1 == context2
# Different request_id should not be equal
assert context1 != context3
# Different session should not be equal
context4 = RequestContext(
request_id="test-123",
meta=None,
session=mock_session2,
lifespan_context="context",
)
assert context1 != context4
def test_request_context_repr(self):
"""Test RequestContext string representation."""
mock_session = Mock(spec=BaseSession)
mock_session.__repr__ = Mock(return_value="<MockSession>")
context = RequestContext(
request_id="test-123",
meta=None,
session=mock_session,
lifespan_context={"data": "test"},
)
repr_str = repr(context)
assert "RequestContext" in repr_str
assert "test-123" in repr_str
assert "MockSession" in repr_str
class TestTypeVariables:
"""Test type variables defined in the module."""
def test_session_type_var(self):
"""Test SessionT type variable."""
# Create a custom session class
class CustomSession(BaseSession):
pass
# Use in generic context
def process_session(session: SessionT) -> SessionT:
return session
mock_session = Mock(spec=CustomSession)
result = process_session(mock_session)
assert result == mock_session
def test_lifespan_context_type_var(self):
"""Test LifespanContextT type variable."""
# Use in generic context
def process_lifespan(context: LifespanContextT) -> LifespanContextT:
return context
# Test with different types
str_context = "string-context"
assert process_lifespan(str_context) == str_context
dict_context = {"key": "value"}
assert process_lifespan(dict_context) == dict_context
class CustomContext:
pass
custom_context = CustomContext()
assert process_lifespan(custom_context) == custom_context

View File

@@ -0,0 +1,205 @@
"""Unit tests for MCP error classes."""
import pytest
from core.mcp.error import MCPAuthError, MCPConnectionError, MCPError
class TestMCPError:
"""Test MCPError base exception class."""
def test_mcp_error_creation(self):
"""Test creating MCPError instance."""
error = MCPError("Test error message")
assert str(error) == "Test error message"
assert isinstance(error, Exception)
def test_mcp_error_inheritance(self):
"""Test MCPError inherits from Exception."""
error = MCPError()
assert isinstance(error, Exception)
assert type(error).__name__ == "MCPError"
def test_mcp_error_with_empty_message(self):
"""Test MCPError with empty message."""
error = MCPError()
assert str(error) == ""
def test_mcp_error_raise(self):
"""Test raising MCPError."""
with pytest.raises(MCPError) as exc_info:
raise MCPError("Something went wrong")
assert str(exc_info.value) == "Something went wrong"
class TestMCPConnectionError:
"""Test MCPConnectionError exception class."""
def test_mcp_connection_error_creation(self):
"""Test creating MCPConnectionError instance."""
error = MCPConnectionError("Connection failed")
assert str(error) == "Connection failed"
assert isinstance(error, MCPError)
assert isinstance(error, Exception)
def test_mcp_connection_error_inheritance(self):
"""Test MCPConnectionError inheritance chain."""
error = MCPConnectionError()
assert isinstance(error, MCPConnectionError)
assert isinstance(error, MCPError)
assert isinstance(error, Exception)
def test_mcp_connection_error_raise(self):
"""Test raising MCPConnectionError."""
with pytest.raises(MCPConnectionError) as exc_info:
raise MCPConnectionError("Unable to connect to server")
assert str(exc_info.value) == "Unable to connect to server"
def test_mcp_connection_error_catch_as_mcp_error(self):
"""Test catching MCPConnectionError as MCPError."""
with pytest.raises(MCPError) as exc_info:
raise MCPConnectionError("Connection issue")
assert isinstance(exc_info.value, MCPConnectionError)
assert str(exc_info.value) == "Connection issue"
class TestMCPAuthError:
"""Test MCPAuthError exception class."""
def test_mcp_auth_error_creation(self):
"""Test creating MCPAuthError instance."""
error = MCPAuthError("Authentication failed")
assert str(error) == "Authentication failed"
assert isinstance(error, MCPConnectionError)
assert isinstance(error, MCPError)
assert isinstance(error, Exception)
def test_mcp_auth_error_inheritance(self):
"""Test MCPAuthError inheritance chain."""
error = MCPAuthError()
assert isinstance(error, MCPAuthError)
assert isinstance(error, MCPConnectionError)
assert isinstance(error, MCPError)
assert isinstance(error, Exception)
def test_mcp_auth_error_raise(self):
"""Test raising MCPAuthError."""
with pytest.raises(MCPAuthError) as exc_info:
raise MCPAuthError("Invalid credentials")
assert str(exc_info.value) == "Invalid credentials"
def test_mcp_auth_error_catch_hierarchy(self):
"""Test catching MCPAuthError at different levels."""
# Catch as MCPAuthError
with pytest.raises(MCPAuthError) as exc_info:
raise MCPAuthError("Auth specific error")
assert str(exc_info.value) == "Auth specific error"
# Catch as MCPConnectionError
with pytest.raises(MCPConnectionError) as exc_info:
raise MCPAuthError("Auth connection error")
assert isinstance(exc_info.value, MCPAuthError)
assert str(exc_info.value) == "Auth connection error"
# Catch as MCPError
with pytest.raises(MCPError) as exc_info:
raise MCPAuthError("Auth base error")
assert isinstance(exc_info.value, MCPAuthError)
assert str(exc_info.value) == "Auth base error"
class TestErrorHierarchy:
"""Test the complete error hierarchy."""
def test_exception_hierarchy(self):
"""Test the complete exception hierarchy."""
# Create instances
base_error = MCPError("base")
connection_error = MCPConnectionError("connection")
auth_error = MCPAuthError("auth")
# Test type relationships
assert not isinstance(base_error, MCPConnectionError)
assert not isinstance(base_error, MCPAuthError)
assert isinstance(connection_error, MCPError)
assert not isinstance(connection_error, MCPAuthError)
assert isinstance(auth_error, MCPError)
assert isinstance(auth_error, MCPConnectionError)
def test_error_handling_patterns(self):
"""Test common error handling patterns."""
def raise_auth_error():
raise MCPAuthError("401 Unauthorized")
def raise_connection_error():
raise MCPConnectionError("Connection timeout")
def raise_base_error():
raise MCPError("Generic error")
# Pattern 1: Catch specific errors first
errors_caught = []
for error_func in [raise_auth_error, raise_connection_error, raise_base_error]:
try:
error_func()
except MCPAuthError:
errors_caught.append("auth")
except MCPConnectionError:
errors_caught.append("connection")
except MCPError:
errors_caught.append("base")
assert errors_caught == ["auth", "connection", "base"]
# Pattern 2: Catch all as base error
for error_func in [raise_auth_error, raise_connection_error, raise_base_error]:
with pytest.raises(MCPError) as exc_info:
error_func()
assert isinstance(exc_info.value, MCPError)
def test_error_with_cause(self):
"""Test errors with cause (chained exceptions)."""
original_error = ValueError("Original error")
def raise_chained_error():
try:
raise original_error
except ValueError as e:
raise MCPConnectionError("Connection failed") from e
with pytest.raises(MCPConnectionError) as exc_info:
raise_chained_error()
assert str(exc_info.value) == "Connection failed"
assert exc_info.value.__cause__ == original_error
def test_error_comparison(self):
"""Test error instance comparison."""
error1 = MCPError("Test message")
error2 = MCPError("Test message")
error3 = MCPError("Different message")
# Errors are not equal even with same message (different instances)
assert error1 != error2
assert error1 != error3
# But they have the same type
assert type(error1) == type(error2) == type(error3)
def test_error_representation(self):
"""Test error string representation."""
base_error = MCPError("Base error message")
connection_error = MCPConnectionError("Connection error message")
auth_error = MCPAuthError("Auth error message")
assert repr(base_error) == "MCPError('Base error message')"
assert repr(connection_error) == "MCPConnectionError('Connection error message')"
assert repr(auth_error) == "MCPAuthError('Auth error message')"

View File

@@ -0,0 +1,382 @@
"""Unit tests for MCP client."""
from contextlib import ExitStack
from types import TracebackType
from unittest.mock import Mock, patch
import pytest
from core.mcp.error import MCPConnectionError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import CallToolResult, ListToolsResult, TextContent, Tool, ToolAnnotations
class TestMCPClient:
"""Test suite for MCPClient."""
def test_init(self):
"""Test client initialization."""
client = MCPClient(
server_url="http://test.example.com/mcp",
headers={"Authorization": "Bearer test"},
timeout=30.0,
sse_read_timeout=60.0,
)
assert client.server_url == "http://test.example.com/mcp"
assert client.headers == {"Authorization": "Bearer test"}
assert client.timeout == 30.0
assert client.sse_read_timeout == 60.0
assert client._session is None
assert isinstance(client._exit_stack, ExitStack)
assert client._initialized is False
def test_init_defaults(self):
"""Test client initialization with defaults."""
client = MCPClient(server_url="http://test.example.com")
assert client.server_url == "http://test.example.com"
assert client.headers == {}
assert client.timeout is None
assert client.sse_read_timeout is None
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_initialize_with_mcp_url(self, mock_client_session, mock_streamable_client):
"""Test initialization with MCP URL."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_client_context = Mock()
mock_streamable_client.return_value.__enter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_client_context,
)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com/mcp")
client._initialize()
# Verify streamable client was called
mock_streamable_client.assert_called_once_with(
url="http://test.example.com/mcp",
headers={},
timeout=None,
sse_read_timeout=None,
)
# Verify session was created
mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
mock_session.initialize.assert_called_once()
assert client._session == mock_session
@patch("core.mcp.mcp_client.sse_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_initialize_with_sse_url(self, mock_client_session, mock_sse_client):
"""Test initialization with SSE URL."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com/sse")
client._initialize()
# Verify SSE client was called
mock_sse_client.assert_called_once_with(
url="http://test.example.com/sse",
headers={},
timeout=None,
sse_read_timeout=None,
)
# Verify session was created
mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
mock_session.initialize.assert_called_once()
assert client._session == mock_session
@patch("core.mcp.mcp_client.sse_client")
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_initialize_with_unknown_method_fallback_to_sse(
self, mock_client_session, mock_streamable_client, mock_sse_client
):
"""Test initialization with unknown method falls back to SSE."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com/unknown")
client._initialize()
# Verify SSE client was tried
mock_sse_client.assert_called_once()
mock_streamable_client.assert_not_called()
# Verify session was created
assert client._session == mock_session
@patch("core.mcp.mcp_client.sse_client")
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_initialize_fallback_from_sse_to_mcp(self, mock_client_session, mock_streamable_client, mock_sse_client):
"""Test initialization falls back from SSE to MCP on connection error."""
# Setup SSE to fail
mock_sse_client.side_effect = MCPConnectionError("SSE connection failed")
# Setup MCP to succeed
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_client_context = Mock()
mock_streamable_client.return_value.__enter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_client_context,
)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com/unknown")
client._initialize()
# Verify both were tried
mock_sse_client.assert_called_once()
mock_streamable_client.assert_called_once()
# Verify session was created with MCP
assert client._session == mock_session
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_connect_server_mcp(self, mock_client_session, mock_streamable_client):
"""Test connect_server with MCP method."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_client_context = Mock()
mock_streamable_client.return_value.__enter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_client_context,
)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com")
client.connect_server(mock_streamable_client, "mcp")
# Verify correct streams were passed
mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
mock_session.initialize.assert_called_once()
@patch("core.mcp.mcp_client.sse_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_connect_server_sse(self, mock_client_session, mock_sse_client):
"""Test connect_server with SSE method."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com")
client.connect_server(mock_sse_client, "sse")
# Verify correct streams were passed
mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
mock_session.initialize.assert_called_once()
def test_context_manager_enter(self):
"""Test context manager enter."""
client = MCPClient(server_url="http://test.example.com")
with patch.object(client, "_initialize") as mock_initialize:
result = client.__enter__()
assert result == client
assert client._initialized is True
mock_initialize.assert_called_once()
def test_context_manager_exit(self):
"""Test context manager exit."""
client = MCPClient(server_url="http://test.example.com")
with patch.object(client, "cleanup") as mock_cleanup:
exc_type: type[BaseException] | None = None
exc_val: BaseException | None = None
exc_tb: TracebackType | None = None
client.__exit__(exc_type, exc_val, exc_tb)
mock_cleanup.assert_called_once()
def test_list_tools_not_initialized(self):
"""Test list_tools when session not initialized."""
client = MCPClient(server_url="http://test.example.com")
with pytest.raises(ValueError) as exc_info:
client.list_tools()
assert "Session not initialized" in str(exc_info.value)
def test_list_tools_success(self):
"""Test successful list_tools call."""
client = MCPClient(server_url="http://test.example.com")
# Setup mock session
mock_session = Mock()
expected_tools = [
Tool(
name="test-tool",
description="A test tool",
inputSchema={"type": "object", "properties": {}},
annotations=ToolAnnotations(title="Test Tool"),
)
]
mock_session.list_tools.return_value = ListToolsResult(tools=expected_tools)
client._session = mock_session
result = client.list_tools()
assert result == expected_tools
mock_session.list_tools.assert_called_once()
def test_invoke_tool_not_initialized(self):
"""Test invoke_tool when session not initialized."""
client = MCPClient(server_url="http://test.example.com")
with pytest.raises(ValueError) as exc_info:
client.invoke_tool("test-tool", {"arg": "value"})
assert "Session not initialized" in str(exc_info.value)
def test_invoke_tool_success(self):
"""Test successful invoke_tool call."""
client = MCPClient(server_url="http://test.example.com")
# Setup mock session
mock_session = Mock()
expected_result = CallToolResult(
content=[TextContent(type="text", text="Tool executed successfully")],
isError=False,
)
mock_session.call_tool.return_value = expected_result
client._session = mock_session
result = client.invoke_tool("test-tool", {"arg": "value"})
assert result == expected_result
mock_session.call_tool.assert_called_once_with("test-tool", {"arg": "value"})
def test_cleanup(self):
"""Test cleanup method."""
client = MCPClient(server_url="http://test.example.com")
mock_exit_stack = Mock(spec=ExitStack)
client._exit_stack = mock_exit_stack
client._session = Mock()
client._initialized = True
client.cleanup()
mock_exit_stack.close.assert_called_once()
assert client._session is None
assert client._initialized is False
def test_cleanup_with_error(self):
"""Test cleanup method with error."""
client = MCPClient(server_url="http://test.example.com")
mock_exit_stack = Mock(spec=ExitStack)
mock_exit_stack.close.side_effect = Exception("Cleanup error")
client._exit_stack = mock_exit_stack
client._session = Mock()
client._initialized = True
with pytest.raises(ValueError) as exc_info:
client.cleanup()
assert "Error during cleanup: Cleanup error" in str(exc_info.value)
assert client._session is None
assert client._initialized is False
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_full_context_manager_flow(self, mock_client_session, mock_streamable_client):
"""Test full context manager flow."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_client_context = Mock()
mock_streamable_client.return_value.__enter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_client_context,
)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
expected_tools = [Tool(name="test-tool", description="Test", inputSchema={})]
mock_session.list_tools.return_value = ListToolsResult(tools=expected_tools)
with MCPClient(server_url="http://test.example.com/mcp") as client:
assert client._initialized is True
assert client._session == mock_session
# Test tool operations
tools = client.list_tools()
assert tools == expected_tools
# After exit, should be cleaned up
assert client._initialized is False
assert client._session is None
def test_headers_passed_to_clients(self):
"""Test that headers are properly passed to underlying clients."""
custom_headers = {
"Authorization": "Bearer test-token",
"X-Custom-Header": "test-value",
}
with patch("core.mcp.mcp_client.streamablehttp_client") as mock_streamable_client:
with patch("core.mcp.mcp_client.ClientSession") as mock_client_session:
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_client_context = Mock()
mock_streamable_client.return_value.__enter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_client_context,
)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(
server_url="http://test.example.com/mcp",
headers=custom_headers,
timeout=30.0,
sse_read_timeout=60.0,
)
client._initialize()
# Verify headers were passed
mock_streamable_client.assert_called_once_with(
url="http://test.example.com/mcp",
headers=custom_headers,
timeout=30.0,
sse_read_timeout=60.0,
)

View File

@@ -0,0 +1,492 @@
"""Unit tests for MCP types module."""
import pytest
from pydantic import ValidationError
from core.mcp.types import (
INTERNAL_ERROR,
INVALID_PARAMS,
INVALID_REQUEST,
LATEST_PROTOCOL_VERSION,
METHOD_NOT_FOUND,
PARSE_ERROR,
SERVER_LATEST_PROTOCOL_VERSION,
Annotations,
CallToolRequest,
CallToolRequestParams,
CallToolResult,
ClientCapabilities,
CompleteRequest,
CompleteRequestParams,
CompleteResult,
Completion,
CompletionArgument,
CompletionContext,
ErrorData,
ImageContent,
Implementation,
InitializeRequest,
InitializeRequestParams,
InitializeResult,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
ListToolsRequest,
ListToolsResult,
OAuthClientInformation,
OAuthClientMetadata,
OAuthMetadata,
OAuthTokens,
PingRequest,
ProgressNotification,
ProgressNotificationParams,
PromptReference,
RequestParams,
ResourceTemplateReference,
Result,
ServerCapabilities,
TextContent,
Tool,
ToolAnnotations,
)
class TestConstants:
"""Test module constants."""
def test_protocol_versions(self):
"""Test protocol version constants."""
assert LATEST_PROTOCOL_VERSION == "2025-03-26"
assert SERVER_LATEST_PROTOCOL_VERSION == "2024-11-05"
def test_error_codes(self):
"""Test JSON-RPC error code constants."""
assert PARSE_ERROR == -32700
assert INVALID_REQUEST == -32600
assert METHOD_NOT_FOUND == -32601
assert INVALID_PARAMS == -32602
assert INTERNAL_ERROR == -32603
class TestRequestParams:
"""Test RequestParams and related classes."""
def test_request_params_basic(self):
"""Test basic RequestParams creation."""
params = RequestParams()
assert params.meta is None
def test_request_params_with_meta(self):
"""Test RequestParams with meta."""
meta = RequestParams.Meta(progressToken="test-token")
params = RequestParams(_meta=meta)
assert params.meta is not None
assert params.meta.progressToken == "test-token"
def test_request_params_meta_extra_fields(self):
"""Test RequestParams.Meta allows extra fields."""
meta = RequestParams.Meta(progressToken="token", customField="value")
assert meta.progressToken == "token"
assert meta.customField == "value" # type: ignore
def test_request_params_serialization(self):
"""Test RequestParams serialization with _meta alias."""
meta = RequestParams.Meta(progressToken="test")
params = RequestParams(_meta=meta)
# Model dump should use the alias
dumped = params.model_dump(by_alias=True)
assert "_meta" in dumped
assert dumped["_meta"] is not None
assert dumped["_meta"]["progressToken"] == "test"
class TestJSONRPCMessages:
"""Test JSON-RPC message types."""
def test_jsonrpc_request(self):
"""Test JSONRPCRequest creation and validation."""
request = JSONRPCRequest(jsonrpc="2.0", id="test-123", method="test_method", params={"key": "value"})
assert request.jsonrpc == "2.0"
assert request.id == "test-123"
assert request.method == "test_method"
assert request.params == {"key": "value"}
def test_jsonrpc_request_numeric_id(self):
"""Test JSONRPCRequest with numeric ID."""
request = JSONRPCRequest(jsonrpc="2.0", id=123, method="test", params=None)
assert request.id == 123
def test_jsonrpc_notification(self):
"""Test JSONRPCNotification creation."""
notification = JSONRPCNotification(jsonrpc="2.0", method="notification_method", params={"data": "test"})
assert notification.jsonrpc == "2.0"
assert notification.method == "notification_method"
assert not hasattr(notification, "id") # Notifications don't have ID
def test_jsonrpc_response(self):
"""Test JSONRPCResponse creation."""
response = JSONRPCResponse(jsonrpc="2.0", id="req-123", result={"success": True})
assert response.jsonrpc == "2.0"
assert response.id == "req-123"
assert response.result == {"success": True}
def test_jsonrpc_error(self):
"""Test JSONRPCError creation."""
error_data = ErrorData(code=INVALID_PARAMS, message="Invalid parameters", data={"field": "missing"})
error = JSONRPCError(jsonrpc="2.0", id="req-123", error=error_data)
assert error.jsonrpc == "2.0"
assert error.id == "req-123"
assert error.error.code == INVALID_PARAMS
assert error.error.message == "Invalid parameters"
assert error.error.data == {"field": "missing"}
def test_jsonrpc_message_parsing(self):
"""Test JSONRPCMessage parsing different message types."""
# Parse request
request_json = '{"jsonrpc": "2.0", "id": 1, "method": "test", "params": null}'
msg = JSONRPCMessage.model_validate_json(request_json)
assert isinstance(msg.root, JSONRPCRequest)
# Parse response
response_json = '{"jsonrpc": "2.0", "id": 1, "result": {"data": "test"}}'
msg = JSONRPCMessage.model_validate_json(response_json)
assert isinstance(msg.root, JSONRPCResponse)
# Parse error
error_json = '{"jsonrpc": "2.0", "id": 1, "error": {"code": -32600, "message": "Invalid Request"}}'
msg = JSONRPCMessage.model_validate_json(error_json)
assert isinstance(msg.root, JSONRPCError)
class TestCapabilities:
"""Test capability classes."""
def test_client_capabilities(self):
"""Test ClientCapabilities creation."""
caps = ClientCapabilities(
experimental={"feature": {"enabled": True}},
sampling={"model_config": {"extra": "allow"}},
roots={"listChanged": True},
)
assert caps.experimental == {"feature": {"enabled": True}}
assert caps.sampling is not None
assert caps.roots.listChanged is True # type: ignore
def test_server_capabilities(self):
"""Test ServerCapabilities creation."""
caps = ServerCapabilities(
tools={"listChanged": True},
resources={"subscribe": True, "listChanged": False},
prompts={"listChanged": True},
logging={},
completions={},
)
assert caps.tools.listChanged is True # type: ignore
assert caps.resources.subscribe is True # type: ignore
assert caps.resources.listChanged is False # type: ignore
class TestInitialization:
"""Test initialization request/response types."""
def test_initialize_request(self):
"""Test InitializeRequest creation."""
client_info = Implementation(name="test-client", version="1.0.0")
capabilities = ClientCapabilities()
params = InitializeRequestParams(
protocolVersion=LATEST_PROTOCOL_VERSION, capabilities=capabilities, clientInfo=client_info
)
request = InitializeRequest(params=params)
assert request.method == "initialize"
assert request.params.protocolVersion == LATEST_PROTOCOL_VERSION
assert request.params.clientInfo.name == "test-client"
def test_initialize_result(self):
"""Test InitializeResult creation."""
server_info = Implementation(name="test-server", version="1.0.0")
capabilities = ServerCapabilities()
result = InitializeResult(
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=capabilities,
serverInfo=server_info,
instructions="Welcome to test server",
)
assert result.protocolVersion == LATEST_PROTOCOL_VERSION
assert result.serverInfo.name == "test-server"
assert result.instructions == "Welcome to test server"
class TestTools:
"""Test tool-related types."""
def test_tool_creation(self):
"""Test Tool creation with all fields."""
tool = Tool(
name="test_tool",
title="Test Tool",
description="A tool for testing",
inputSchema={"type": "object", "properties": {"input": {"type": "string"}}, "required": ["input"]},
outputSchema={"type": "object", "properties": {"result": {"type": "string"}}},
annotations=ToolAnnotations(
title="Test Tool", readOnlyHint=False, destructiveHint=False, idempotentHint=True
),
)
assert tool.name == "test_tool"
assert tool.title == "Test Tool"
assert tool.description == "A tool for testing"
assert tool.inputSchema["properties"]["input"]["type"] == "string"
assert tool.annotations.idempotentHint is True
def test_call_tool_request(self):
"""Test CallToolRequest creation."""
params = CallToolRequestParams(name="test_tool", arguments={"input": "test value"})
request = CallToolRequest(params=params)
assert request.method == "tools/call"
assert request.params.name == "test_tool"
assert request.params.arguments == {"input": "test value"}
def test_call_tool_result(self):
"""Test CallToolResult creation."""
result = CallToolResult(
content=[TextContent(type="text", text="Tool executed successfully")],
structuredContent={"status": "success", "data": "test"},
isError=False,
)
assert len(result.content) == 1
assert result.content[0].text == "Tool executed successfully" # type: ignore
assert result.structuredContent == {"status": "success", "data": "test"}
assert result.isError is False
def test_list_tools_request(self):
"""Test ListToolsRequest creation."""
request = ListToolsRequest()
assert request.method == "tools/list"
def test_list_tools_result(self):
"""Test ListToolsResult creation."""
tool1 = Tool(name="tool1", inputSchema={})
tool2 = Tool(name="tool2", inputSchema={})
result = ListToolsResult(tools=[tool1, tool2])
assert len(result.tools) == 2
assert result.tools[0].name == "tool1"
assert result.tools[1].name == "tool2"
class TestContent:
"""Test content types."""
def test_text_content(self):
"""Test TextContent creation."""
annotations = Annotations(audience=["user"], priority=0.8)
content = TextContent(type="text", text="Hello, world!", annotations=annotations)
assert content.type == "text"
assert content.text == "Hello, world!"
assert content.annotations is not None
assert content.annotations.priority == 0.8
def test_image_content(self):
"""Test ImageContent creation."""
content = ImageContent(type="image", data="base64encodeddata", mimeType="image/png")
assert content.type == "image"
assert content.data == "base64encodeddata"
assert content.mimeType == "image/png"
class TestOAuth:
"""Test OAuth-related types."""
def test_oauth_client_metadata(self):
"""Test OAuthClientMetadata creation."""
metadata = OAuthClientMetadata(
client_name="Test Client",
redirect_uris=["https://example.com/callback"],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
token_endpoint_auth_method="none",
client_uri="https://example.com",
scope="read write",
)
assert metadata.client_name == "Test Client"
assert len(metadata.redirect_uris) == 1
assert "authorization_code" in metadata.grant_types
def test_oauth_client_information(self):
"""Test OAuthClientInformation creation."""
info = OAuthClientInformation(client_id="test-client-id", client_secret="test-secret")
assert info.client_id == "test-client-id"
assert info.client_secret == "test-secret"
def test_oauth_client_information_without_secret(self):
"""Test OAuthClientInformation without secret."""
info = OAuthClientInformation(client_id="public-client")
assert info.client_id == "public-client"
assert info.client_secret is None
def test_oauth_tokens(self):
"""Test OAuthTokens creation."""
tokens = OAuthTokens(
access_token="access-token-123",
token_type="Bearer",
expires_in=3600,
refresh_token="refresh-token-456",
scope="read write",
)
assert tokens.access_token == "access-token-123"
assert tokens.token_type == "Bearer"
assert tokens.expires_in == 3600
assert tokens.refresh_token == "refresh-token-456"
assert tokens.scope == "read write"
def test_oauth_metadata(self):
"""Test OAuthMetadata creation."""
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
registration_endpoint="https://auth.example.com/register",
response_types_supported=["code", "token"],
grant_types_supported=["authorization_code", "refresh_token"],
code_challenge_methods_supported=["plain", "S256"],
)
assert metadata.authorization_endpoint == "https://auth.example.com/authorize"
assert "code" in metadata.response_types_supported
assert "S256" in metadata.code_challenge_methods_supported
class TestNotifications:
"""Test notification types."""
def test_progress_notification(self):
"""Test ProgressNotification creation."""
params = ProgressNotificationParams(
progressToken="progress-123", progress=50.0, total=100.0, message="Processing... 50%"
)
notification = ProgressNotification(params=params)
assert notification.method == "notifications/progress"
assert notification.params.progressToken == "progress-123"
assert notification.params.progress == 50.0
assert notification.params.total == 100.0
assert notification.params.message == "Processing... 50%"
def test_ping_request(self):
"""Test PingRequest creation."""
request = PingRequest()
assert request.method == "ping"
assert request.params is None
class TestCompletion:
"""Test completion-related types."""
def test_completion_context(self):
"""Test CompletionContext creation."""
context = CompletionContext(arguments={"template_var": "value"})
assert context.arguments == {"template_var": "value"}
def test_resource_template_reference(self):
"""Test ResourceTemplateReference creation."""
ref = ResourceTemplateReference(type="ref/resource", uri="file:///path/to/{filename}")
assert ref.type == "ref/resource"
assert ref.uri == "file:///path/to/{filename}"
def test_prompt_reference(self):
"""Test PromptReference creation."""
ref = PromptReference(type="ref/prompt", name="test_prompt")
assert ref.type == "ref/prompt"
assert ref.name == "test_prompt"
def test_complete_request(self):
"""Test CompleteRequest creation."""
ref = PromptReference(type="ref/prompt", name="test_prompt")
arg = CompletionArgument(name="arg1", value="val")
params = CompleteRequestParams(ref=ref, argument=arg, context=CompletionContext(arguments={"key": "value"}))
request = CompleteRequest(params=params)
assert request.method == "completion/complete"
assert request.params.ref.name == "test_prompt" # type: ignore
assert request.params.argument.name == "arg1"
def test_complete_result(self):
"""Test CompleteResult creation."""
completion = Completion(values=["option1", "option2", "option3"], total=10, hasMore=True)
result = CompleteResult(completion=completion)
assert len(result.completion.values) == 3
assert result.completion.total == 10
assert result.completion.hasMore is True
class TestValidation:
"""Test validation of various types."""
def test_invalid_jsonrpc_version(self):
"""Test invalid JSON-RPC version validation."""
with pytest.raises(ValidationError):
JSONRPCRequest(
jsonrpc="1.0", # Invalid version
id=1,
method="test",
)
def test_tool_annotations_validation(self):
"""Test ToolAnnotations with invalid values."""
# Valid annotations
annotations = ToolAnnotations(
title="Test", readOnlyHint=True, destructiveHint=False, idempotentHint=True, openWorldHint=False
)
assert annotations.title == "Test"
def test_extra_fields_allowed(self):
"""Test that extra fields are allowed in models."""
# Most models should allow extra fields
tool = Tool(
name="test",
inputSchema={},
customField="allowed", # type: ignore
)
assert tool.customField == "allowed" # type: ignore
def test_result_meta_alias(self):
"""Test Result model with _meta alias."""
# Create with the field name (not alias)
result = Result(_meta={"key": "value"})
# Verify the field is set correctly
assert result.meta == {"key": "value"}
# Dump with alias
dumped = result.model_dump(by_alias=True)
assert "_meta" in dumped
assert dumped["_meta"] == {"key": "value"}

View File

@@ -0,0 +1,355 @@
"""Unit tests for MCP utils module."""
import json
from collections.abc import Generator
from unittest.mock import MagicMock, Mock, patch
import httpx
import httpx_sse
import pytest
from core.mcp.utils import (
STATUS_FORCELIST,
create_mcp_error_response,
create_ssrf_proxy_mcp_http_client,
ssrf_proxy_sse_connect,
)
class TestConstants:
"""Test module constants."""
def test_status_forcelist(self):
"""Test STATUS_FORCELIST contains expected HTTP status codes."""
assert STATUS_FORCELIST == [429, 500, 502, 503, 504]
assert 429 in STATUS_FORCELIST # Too Many Requests
assert 500 in STATUS_FORCELIST # Internal Server Error
assert 502 in STATUS_FORCELIST # Bad Gateway
assert 503 in STATUS_FORCELIST # Service Unavailable
assert 504 in STATUS_FORCELIST # Gateway Timeout
class TestCreateSSRFProxyMCPHTTPClient:
"""Test create_ssrf_proxy_mcp_http_client function."""
@patch("core.mcp.utils.dify_config")
def test_create_client_with_all_url_proxy(self, mock_config):
"""Test client creation with SSRF_PROXY_ALL_URL configured."""
mock_config.SSRF_PROXY_ALL_URL = "http://proxy.example.com:8080"
mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True
client = create_ssrf_proxy_mcp_http_client(
headers={"Authorization": "Bearer token"}, timeout=httpx.Timeout(30.0)
)
assert isinstance(client, httpx.Client)
assert client.headers["Authorization"] == "Bearer token"
assert client.timeout.connect == 30.0
assert client.follow_redirects is True
# Clean up
client.close()
@patch("core.mcp.utils.dify_config")
def test_create_client_with_http_https_proxies(self, mock_config):
"""Test client creation with separate HTTP/HTTPS proxies."""
mock_config.SSRF_PROXY_ALL_URL = None
mock_config.SSRF_PROXY_HTTP_URL = "http://http-proxy.example.com:8080"
mock_config.SSRF_PROXY_HTTPS_URL = "http://https-proxy.example.com:8443"
mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = False
client = create_ssrf_proxy_mcp_http_client()
assert isinstance(client, httpx.Client)
assert client.follow_redirects is True
# Clean up
client.close()
@patch("core.mcp.utils.dify_config")
def test_create_client_without_proxy(self, mock_config):
"""Test client creation without proxy configuration."""
mock_config.SSRF_PROXY_ALL_URL = None
mock_config.SSRF_PROXY_HTTP_URL = None
mock_config.SSRF_PROXY_HTTPS_URL = None
mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True
headers = {"X-Custom-Header": "value"}
timeout = httpx.Timeout(timeout=30.0, connect=5.0, read=10.0, write=30.0)
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
assert isinstance(client, httpx.Client)
assert client.headers["X-Custom-Header"] == "value"
assert client.timeout.connect == 5.0
assert client.timeout.read == 10.0
assert client.follow_redirects is True
# Clean up
client.close()
@patch("core.mcp.utils.dify_config")
def test_create_client_default_params(self, mock_config):
"""Test client creation with default parameters."""
mock_config.SSRF_PROXY_ALL_URL = None
mock_config.SSRF_PROXY_HTTP_URL = None
mock_config.SSRF_PROXY_HTTPS_URL = None
mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True
client = create_ssrf_proxy_mcp_http_client()
assert isinstance(client, httpx.Client)
# httpx.Client adds default headers, so we just check it's a Headers object
assert isinstance(client.headers, httpx.Headers)
# When no timeout is provided, httpx uses its default timeout
assert client.timeout is not None
# Clean up
client.close()
class TestSSRFProxySSEConnect:
"""Test ssrf_proxy_sse_connect function."""
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
def test_sse_connect_with_provided_client(self, mock_create_client, mock_connect_sse):
"""Test SSE connection with pre-configured client."""
# Setup mocks
mock_client = Mock(spec=httpx.Client)
mock_event_source = Mock(spec=httpx_sse.EventSource)
mock_context = MagicMock()
mock_context.__enter__.return_value = mock_event_source
mock_connect_sse.return_value = mock_context
# Call with provided client
result = ssrf_proxy_sse_connect(
"http://example.com/sse", client=mock_client, method="POST", headers={"Authorization": "Bearer token"}
)
# Verify client creation was not called
mock_create_client.assert_not_called()
# Verify connect_sse was called correctly
mock_connect_sse.assert_called_once_with(
mock_client, "POST", "http://example.com/sse", headers={"Authorization": "Bearer token"}
)
# Verify result
assert result == mock_context
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
@patch("core.mcp.utils.dify_config")
def test_sse_connect_without_client(self, mock_config, mock_create_client, mock_connect_sse):
"""Test SSE connection without pre-configured client."""
# Setup config
mock_config.SSRF_DEFAULT_TIME_OUT = 30.0
mock_config.SSRF_DEFAULT_CONNECT_TIME_OUT = 10.0
mock_config.SSRF_DEFAULT_READ_TIME_OUT = 60.0
mock_config.SSRF_DEFAULT_WRITE_TIME_OUT = 30.0
# Setup mocks
mock_client = Mock(spec=httpx.Client)
mock_create_client.return_value = mock_client
mock_event_source = Mock(spec=httpx_sse.EventSource)
mock_context = MagicMock()
mock_context.__enter__.return_value = mock_event_source
mock_connect_sse.return_value = mock_context
# Call without client
result = ssrf_proxy_sse_connect("http://example.com/sse", headers={"X-Custom": "value"})
# Verify client was created
mock_create_client.assert_called_once()
call_args = mock_create_client.call_args
assert call_args[1]["headers"] == {"X-Custom": "value"}
timeout = call_args[1]["timeout"]
# httpx.Timeout object has these attributes
assert isinstance(timeout, httpx.Timeout)
assert timeout.connect == 10.0
assert timeout.read == 60.0
assert timeout.write == 30.0
# Verify connect_sse was called
mock_connect_sse.assert_called_once_with(
mock_client,
"GET", # Default method
"http://example.com/sse",
)
# Verify result
assert result == mock_context
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
def test_sse_connect_with_custom_timeout(self, mock_create_client, mock_connect_sse):
"""Test SSE connection with custom timeout."""
# Setup mocks
mock_client = Mock(spec=httpx.Client)
mock_create_client.return_value = mock_client
mock_event_source = Mock(spec=httpx_sse.EventSource)
mock_context = MagicMock()
mock_context.__enter__.return_value = mock_event_source
mock_connect_sse.return_value = mock_context
custom_timeout = httpx.Timeout(timeout=60.0, read=120.0)
# Call with custom timeout
result = ssrf_proxy_sse_connect("http://example.com/sse", timeout=custom_timeout)
# Verify client was created with custom timeout
mock_create_client.assert_called_once()
call_args = mock_create_client.call_args
assert call_args[1]["timeout"] == custom_timeout
# Verify result
assert result == mock_context
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
def test_sse_connect_error_cleanup(self, mock_create_client, mock_connect_sse):
"""Test SSE connection cleans up client on error."""
# Setup mocks
mock_client = Mock(spec=httpx.Client)
mock_create_client.return_value = mock_client
# Make connect_sse raise an exception
mock_connect_sse.side_effect = httpx.ConnectError("Connection failed")
# Call should raise the exception
with pytest.raises(httpx.ConnectError):
ssrf_proxy_sse_connect("http://example.com/sse")
# Verify client was cleaned up
mock_client.close.assert_called_once()
@patch("core.mcp.utils.connect_sse")
def test_sse_connect_error_no_cleanup_with_provided_client(self, mock_connect_sse):
"""Test SSE connection doesn't clean up provided client on error."""
# Setup mocks
mock_client = Mock(spec=httpx.Client)
# Make connect_sse raise an exception
mock_connect_sse.side_effect = httpx.ConnectError("Connection failed")
# Call should raise the exception
with pytest.raises(httpx.ConnectError):
ssrf_proxy_sse_connect("http://example.com/sse", client=mock_client)
# Verify client was NOT cleaned up (because it was provided)
mock_client.close.assert_not_called()
class TestCreateMCPErrorResponse:
"""Test create_mcp_error_response function."""
def test_create_error_response_basic(self):
"""Test creating basic error response."""
generator = create_mcp_error_response(request_id="req-123", code=-32600, message="Invalid Request")
# Generator should yield bytes
assert isinstance(generator, Generator)
# Get the response
response_bytes = next(generator)
assert isinstance(response_bytes, bytes)
# Parse the response
response_str = response_bytes.decode("utf-8")
response_json = json.loads(response_str)
assert response_json["jsonrpc"] == "2.0"
assert response_json["id"] == "req-123"
assert response_json["error"]["code"] == -32600
assert response_json["error"]["message"] == "Invalid Request"
assert response_json["error"]["data"] is None
# Generator should be exhausted
with pytest.raises(StopIteration):
next(generator)
def test_create_error_response_with_data(self):
"""Test creating error response with additional data."""
error_data = {"field": "username", "reason": "required"}
generator = create_mcp_error_response(
request_id=456, # Numeric ID
code=-32602,
message="Invalid params",
data=error_data,
)
response_bytes = next(generator)
response_json = json.loads(response_bytes.decode("utf-8"))
assert response_json["id"] == 456
assert response_json["error"]["code"] == -32602
assert response_json["error"]["message"] == "Invalid params"
assert response_json["error"]["data"] == error_data
def test_create_error_response_without_request_id(self):
"""Test creating error response without request ID."""
generator = create_mcp_error_response(request_id=None, code=-32700, message="Parse error")
response_bytes = next(generator)
response_json = json.loads(response_bytes.decode("utf-8"))
# Should default to ID 1
assert response_json["id"] == 1
assert response_json["error"]["code"] == -32700
assert response_json["error"]["message"] == "Parse error"
def test_create_error_response_with_complex_data(self):
"""Test creating error response with complex error data."""
complex_data = {
"errors": [{"field": "name", "message": "Too short"}, {"field": "email", "message": "Invalid format"}],
"timestamp": "2024-01-01T00:00:00Z",
}
generator = create_mcp_error_response(
request_id="complex-req", code=-32602, message="Validation failed", data=complex_data
)
response_bytes = next(generator)
response_json = json.loads(response_bytes.decode("utf-8"))
assert response_json["error"]["data"] == complex_data
assert len(response_json["error"]["data"]["errors"]) == 2
def test_create_error_response_encoding(self):
"""Test error response with non-ASCII characters."""
generator = create_mcp_error_response(
request_id="unicode-req",
code=-32603,
message="内部错误", # Chinese characters
data={"details": "エラー詳細"}, # Japanese characters
)
response_bytes = next(generator)
# Should be valid UTF-8
response_str = response_bytes.decode("utf-8")
response_json = json.loads(response_str)
assert response_json["error"]["message"] == "内部错误"
assert response_json["error"]["data"]["details"] == "エラー詳細"
def test_create_error_response_yields_once(self):
"""Test that error response generator yields exactly once."""
generator = create_mcp_error_response(request_id="test", code=-32600, message="Test")
# First yield should work
first_yield = next(generator)
assert isinstance(first_yield, bytes)
# Second yield should raise StopIteration
with pytest.raises(StopIteration):
next(generator)
# Subsequent calls should also raise
with pytest.raises(StopIteration):
next(generator)

View File

@@ -180,6 +180,25 @@ class TestMCPToolTransform:
# Set tools data with null description
mock_provider_full.tools = '[{"name": "tool1", "description": null, "inputSchema": {}}]'
# Mock the to_entity and to_api_response methods
mock_entity = Mock()
mock_entity.to_api_response.return_value = {
"name": "Test MCP Provider",
"type": ToolProviderType.MCP,
"is_team_authorization": True,
"server_url": "https://*****.com/mcp",
"provider_icon": "icon.png",
"masked_headers": {"Authorization": "Bearer *****"},
"updated_at": 1234567890,
"labels": [],
"author": "Test User",
"description": I18nObject(en_US="Test MCP Provider Description", zh_Hans="Test MCP Provider Description"),
"icon": "icon.png",
"label": I18nObject(en_US="Test MCP Provider", zh_Hans="Test MCP Provider"),
"masked_credentials": {},
}
mock_provider_full.to_entity.return_value = mock_entity
# Call the method with for_list=True
result = ToolTransformService.mcp_provider_to_user_provider(mock_provider_full, for_list=True)
@@ -198,6 +217,27 @@ class TestMCPToolTransform:
# Set tools data with description
mock_provider_full.tools = '[{"name": "tool1", "description": "Tool description", "inputSchema": {}}]'
# Mock the to_entity and to_api_response methods
mock_entity = Mock()
mock_entity.to_api_response.return_value = {
"name": "Test MCP Provider",
"type": ToolProviderType.MCP,
"is_team_authorization": True,
"server_url": "https://*****.com/mcp",
"provider_icon": "icon.png",
"masked_headers": {"Authorization": "Bearer *****"},
"updated_at": 1234567890,
"labels": [],
"configuration": {"timeout": "30", "sse_read_timeout": "300"},
"original_headers": {"Authorization": "Bearer secret-token"},
"author": "Test User",
"description": I18nObject(en_US="Test MCP Provider Description", zh_Hans="Test MCP Provider Description"),
"icon": "icon.png",
"label": I18nObject(en_US="Test MCP Provider", zh_Hans="Test MCP Provider"),
"masked_credentials": {},
}
mock_provider_full.to_entity.return_value = mock_entity
# Call the method with for_list=False
result = ToolTransformService.mcp_provider_to_user_provider(mock_provider_full, for_list=False)
@@ -205,8 +245,9 @@ class TestMCPToolTransform:
assert isinstance(result, ToolProviderApiEntity)
assert result.id == "server-identifier-456" # Should use server_identifier when for_list=False
assert result.server_identifier == "server-identifier-456"
assert result.timeout == 30
assert result.sse_read_timeout == 300
assert result.configuration is not None
assert result.configuration.timeout == 30
assert result.configuration.sse_read_timeout == 300
assert result.original_headers == {"Authorization": "Bearer secret-token"}
assert len(result.tools) == 1
assert result.tools[0].description.en_US == "Tool description"