diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py
index 9ee8469892..c034662cf4 100644
--- a/api/core/plugin/impl/base.py
+++ b/api/core/plugin/impl/base.py
@@ -3,6 +3,7 @@ import json
 import logging
 from collections.abc import Callable, Generator
 from typing import Any, cast
+from urllib.parse import unquote
 
 import httpx
 from pydantic import BaseModel
@@ -53,6 +54,9 @@ else:
 
 logger = logging.getLogger(__name__)
 
+PLUGIN_DAEMON_MAX_PATH_LENGTH = 4096
+PLUGIN_DAEMON_MAX_PATH_DECODE_DEPTH = 8
+
 _httpx_client: httpx.Client = get_pooled_http_client(
     "plugin_daemon",
     lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100), trust_env=False),
@@ -103,6 +107,20 @@ class BasePluginClient:
         params: dict[str, Any] | None,
         files: dict[str, Any] | None,
     ) -> tuple[str, dict[str, str], bytes | dict[str, Any] | str | None, dict[str, Any] | None, dict[str, Any] | None]:
+        if len(path) > PLUGIN_DAEMON_MAX_PATH_LENGTH:
+            raise ValueError(f"Invalid plugin daemon path: path length exceeds {PLUGIN_DAEMON_MAX_PATH_LENGTH}")
+
+        decoded_path = path
+        for _ in range(PLUGIN_DAEMON_MAX_PATH_DECODE_DEPTH):
+            next_decoded_path = unquote(decoded_path)
+            if next_decoded_path == decoded_path:
+                break
+            decoded_path = next_decoded_path
+        else:
+            raise ValueError("Invalid plugin daemon path: path is too deeply encoded")
+
+        if any(seg == ".." for seg in decoded_path.split("/")):
+            raise ValueError(f"Invalid plugin daemon path: traversal sequence detected in {path!r}")
         url = plugin_daemon_inner_api_baseurl / path
         prepared_headers = dict(headers or {})
         prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
diff --git a/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py b/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py
index b154f056ca..bea808516d 100644
--- a/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py
+++ b/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py
@@ -1,11 +1,12 @@
 import json
+from urllib.parse import quote
 
 import pytest
 from pytest_mock import MockerFixture
 
 from core.plugin.endpoint.exc import EndpointSetupFailedError
 from core.plugin.entities.plugin_daemon import PluginDaemonInnerError
-from core.plugin.impl.base import BasePluginClient
+from core.plugin.impl.base import PLUGIN_DAEMON_MAX_PATH_LENGTH, BasePluginClient
 from core.trigger.errors import (
     EventIgnoreError,
     TriggerInvokeError,
@@ -67,6 +68,36 @@ class TestBasePluginClientImpl:
         assert result == ["hello", "world"]
         assert stream_mock.call_args.kwargs["data"] == {"k": "v"}
 
+    @pytest.mark.parametrize(
+        "path",
+        [
+            "plugin/tenant/%252e%252e%252ftarget",
+            "plugin/tenant/%2e%2e%252ftarget",
+        ],
+    )
+    def test_prepare_request_rejects_encoded_traversal_with_encoded_separator(self, path: str):
+        client = BasePluginClient()
+
+        with pytest.raises(ValueError, match="traversal sequence detected"):
+            client._prepare_request(path, None, None, None, None)
+
+    def test_prepare_request_rejects_path_exceeding_max_length(self):
+        client = BasePluginClient()
+        path = "a" * (PLUGIN_DAEMON_MAX_PATH_LENGTH + 1)
+
+        with pytest.raises(ValueError, match="path length exceeds"):
+            client._prepare_request(path, None, None, None, None)
+
+    def test_prepare_request_rejects_excessively_encoded_path(self):
+        client = BasePluginClient()
+        segment = "..%2Ftarget"
+        for _ in range(9):
+            segment = quote(segment, safe="")
+        path = f"plugin/tenant/{segment}"
+
+        with pytest.raises(ValueError, match="too deeply encoded"):
+            client._prepare_request(path, None, None, None, None)
+
     def test_request_with_plugin_daemon_response_handles_request_exception(self, mocker: MockerFixture):
         client = BasePluginClient()
         mocker.patch.object(client, "_request", side_effect=RuntimeError("boom"))