diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index 10e3082aa3..9470d39f41 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -2,7 +2,7 @@ from collections.abc import Callable from dataclasses import dataclass from typing import Annotated, Any, Literal -from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel +from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel, model_validator from pydantic.networks import AnyUrl, UrlConstraints """ @@ -173,7 +173,21 @@ class JSONRPCError(BaseModel): class JSONRPCMessage(RootModel[JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError]): - pass + @model_validator(mode="before") + @classmethod + def _select_message_type( + cls, value: Any + ) -> JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError | Any: + if isinstance(value, dict): + if "result" in value: + return JSONRPCResponse.model_validate(value) + if "error" in value: + return JSONRPCError.model_validate(value) + if "method" in value: + if "id" in value: + return JSONRPCRequest.model_validate(value) + return JSONRPCNotification.model_validate(value) + return value class EmptyResult(Result): diff --git a/api/tests/unit_tests/core/mcp/client/test_sse.py b/api/tests/unit_tests/core/mcp/client/test_sse.py index e6eeb6cd59..ea0ff7395d 100644 --- a/api/tests/unit_tests/core/mcp/client/test_sse.py +++ b/api/tests/unit_tests/core/mcp/client/test_sse.py @@ -34,6 +34,17 @@ def test_sse_message_id_coercion(): assert msg.root.jsonrpc == expected.root.jsonrpc +def test_sse_message_without_id_stays_notification(): + """Test that method messages without an ID still parse as notifications.""" + json_message = '{"jsonrpc": "2.0", "method": "ping", "params": null}' + + msg = types.JSONRPCMessage.model_validate_json(json_message) + + assert isinstance(msg.root, types.JSONRPCNotification) + assert msg.root.method == "ping" + assert msg.root.jsonrpc == "2.0" + + class MockSSEClient: """Mock SSE client for testing."""