From 318a3d03087e1102b856381fc004acb0d781bb5f Mon Sep 17 00:00:00 2001 From: 99 Date: Thu, 2 Apr 2026 17:36:58 +0800 Subject: [PATCH] refactor(api): tighten login and wrapper typing (#34447) --- .../console/app/workflow_draft_variable.py | 4 +- api/controllers/console/app/wraps.py | 54 +++++-- .../rag_pipeline_draft_variable.py | 5 +- api/controllers/service_api/wraps.py | 140 ++++++++---------- api/dify_app.py | 11 +- api/extensions/ext_login.py | 39 ++++- api/libs/login.py | 27 ++-- .../console/test_workspace_account.py | 2 +- .../console/test_workspace_members.py | 2 +- .../unit_tests/extensions/test_ext_login.py | 17 +++ api/tests/unit_tests/libs/test_login.py | 49 +++++- 11 files changed, 229 insertions(+), 121 deletions(-) create mode 100644 api/tests/unit_tests/extensions/test_ext_login.py diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 366f145360..f6d076320c 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -193,7 +193,7 @@ workflow_draft_variable_list_model = console_ns.model( ) -def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]: +def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]: """Common prerequisites for all draft workflow variable APIs. It ensures the following conditions are satisfied: @@ -210,7 +210,7 @@ def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]: @edit_permission_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @wraps(f) - def wrapper(*args: Any, **kwargs: Any): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response: return f(*args, **kwargs) return wrapper diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index bd6f019eac..c9cf08072a 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Any +from typing import overload from sqlalchemy import select @@ -23,14 +23,30 @@ def _load_app_model_with_trial(app_id: str) -> App | None: return app_model -def get_app_model( - view: Callable[..., Any] | None = None, +@overload +def get_app_model[**P, R]( + view: Callable[P, R], *, mode: AppMode | list[AppMode] | None = None, -) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]: - def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]: +) -> Callable[P, R]: ... + + +@overload +def get_app_model[**P, R]( + view: None = None, + *, + mode: AppMode | list[AppMode] | None = None, +) -> Callable[[Callable[P, R]], Callable[P, R]]: ... + + +def get_app_model[**P, R]( + view: Callable[P, R] | None = None, + *, + mode: AppMode | list[AppMode] | None = None, +) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]: + def decorator(view_func: Callable[P, R]) -> Callable[P, R]: @wraps(view_func) - def decorated_view(*args: Any, **kwargs: Any): + def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R: if not kwargs.get("app_id"): raise ValueError("missing app_id in path parameters") @@ -68,14 +84,30 @@ def get_app_model( return decorator(view) -def get_app_model_with_trial( - view: Callable[..., Any] | None = None, +@overload +def get_app_model_with_trial[**P, R]( + view: Callable[P, R], *, mode: AppMode | list[AppMode] | None = None, -) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]: - def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]: +) -> Callable[P, R]: ... + + +@overload +def get_app_model_with_trial[**P, R]( + view: None = None, + *, + mode: AppMode | list[AppMode] | None = None, +) -> Callable[[Callable[P, R]], Callable[P, R]]: ... + + +def get_app_model_with_trial[**P, R]( + view: Callable[P, R] | None = None, + *, + mode: AppMode | list[AppMode] | None = None, +) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]: + def decorator(view_func: Callable[P, R]) -> Callable[P, R]: @wraps(view_func) - def decorated_view(*args: Any, **kwargs: Any): + def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R: if not kwargs.get("app_id"): raise ValueError("missing app_id in path parameters") diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index d635dcb530..93feec0019 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -1,4 +1,5 @@ import logging +from collections.abc import Callable from typing import Any, NoReturn from flask import Response, request @@ -55,7 +56,7 @@ class WorkflowDraftVariablePatchPayload(BaseModel): register_schema_models(console_ns, WorkflowDraftVariablePatchPayload) -def _api_prerequisite(f): +def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]: """Common prerequisites for all draft workflow variable APIs. It ensures the following conditions are satisfied: @@ -70,7 +71,7 @@ def _api_prerequisite(f): @login_required @account_initialization_required @get_rag_pipeline - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response: if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() return f(*args, **kwargs) diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 2dd916bb31..b9389ccc47 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -1,9 +1,10 @@ +import inspect import logging import time from collections.abc import Callable from enum import StrEnum, auto from functools import wraps -from typing import Any, cast, overload +from typing import cast, overload from flask import current_app, request from flask_login import user_logged_in @@ -230,94 +231,73 @@ def cloud_edition_billing_rate_limit_check[**P, R]( return interceptor -def validate_dataset_token( - view: Callable[..., Any] | None = None, -) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]: - def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]: - @wraps(view_func) - def decorated(*args: Any, **kwargs: Any) -> Any: - api_token = validate_and_get_api_token("dataset") +def validate_dataset_token[R](view: Callable[..., R]) -> Callable[..., R]: + positional_parameters = [ + parameter + for parameter in inspect.signature(view).parameters.values() + if parameter.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) + ] + expects_bound_instance = bool(positional_parameters and positional_parameters[0].name in {"self", "cls"}) - # get url path dataset_id from positional args or kwargs - # Flask passes URL path parameters as positional arguments - dataset_id = None + @wraps(view) + def decorated(*args: object, **kwargs: object) -> R: + api_token = validate_and_get_api_token("dataset") - # First try to get from kwargs (explicit parameter) - dataset_id = kwargs.get("dataset_id") + # Flask may pass URL path parameters positionally, so inspect both kwargs and args. + dataset_id = kwargs.get("dataset_id") - # If not in kwargs, try to extract from positional args - if not dataset_id and args: - # For class methods: args[0] is self, args[1] is dataset_id (if exists) - # Check if first arg is likely a class instance (has __dict__ or __class__) - if len(args) > 1 and hasattr(args[0], "__dict__"): - # This is a class method, dataset_id should be in args[1] - potential_id = args[1] - # Validate it's a string-like UUID, not another object - try: - # Try to convert to string and check if it's a valid UUID format - str_id = str(potential_id) - # Basic check: UUIDs are 36 chars with hyphens - if len(str_id) == 36 and str_id.count("-") == 4: - dataset_id = str_id - except Exception: - logger.exception("Failed to parse dataset_id from class method args") - elif len(args) > 0: - # Not a class method, check if args[0] looks like a UUID - potential_id = args[0] - try: - str_id = str(potential_id) - if len(str_id) == 36 and str_id.count("-") == 4: - dataset_id = str_id - except Exception: - logger.exception("Failed to parse dataset_id from positional args") + if not dataset_id and args: + potential_id = args[0] + try: + str_id = str(potential_id) + if len(str_id) == 36 and str_id.count("-") == 4: + dataset_id = str_id + except Exception: + logger.exception("Failed to parse dataset_id from positional args") - # Validate dataset if dataset_id is provided - if dataset_id: - dataset_id = str(dataset_id) - dataset = db.session.scalar( - select(Dataset) - .where( - Dataset.id == dataset_id, - Dataset.tenant_id == api_token.tenant_id, - ) - .limit(1) + if dataset_id: + dataset_id = str(dataset_id) + dataset = db.session.scalar( + select(Dataset) + .where( + Dataset.id == dataset_id, + Dataset.tenant_id == api_token.tenant_id, ) - if not dataset: - raise NotFound("Dataset not found.") - if not dataset.enable_api: - raise Forbidden("Dataset api access is not enabled.") - tenant_account_join = db.session.execute( - select(Tenant, TenantAccountJoin) - .where(Tenant.id == api_token.tenant_id) - .where(TenantAccountJoin.tenant_id == Tenant.id) - .where(TenantAccountJoin.role.in_(["owner"])) - .where(Tenant.status == TenantStatus.NORMAL) - ).one_or_none() # TODO: only owner information is required, so only one is returned. - if tenant_account_join: - tenant, ta = tenant_account_join - account = db.session.get(Account, ta.account_id) - # Login admin - if account: - account.current_tenant = tenant - current_app.login_manager._update_request_context_with_user(account) # type: ignore - user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore - else: - raise Unauthorized("Tenant owner account does not exist.") + .limit(1) + ) + if not dataset: + raise NotFound("Dataset not found.") + if not dataset.enable_api: + raise Forbidden("Dataset api access is not enabled.") + + tenant_account_join = db.session.execute( + select(Tenant, TenantAccountJoin) + .where(Tenant.id == api_token.tenant_id) + .where(TenantAccountJoin.tenant_id == Tenant.id) + .where(TenantAccountJoin.role.in_(["owner"])) + .where(Tenant.status == TenantStatus.NORMAL) + ).one_or_none() # TODO: only owner information is required, so only one is returned. + if tenant_account_join: + tenant, ta = tenant_account_join + account = db.session.get(Account, ta.account_id) + # Login admin + if account: + account.current_tenant = tenant + current_app.login_manager._update_request_context_with_user(account) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore else: - raise Unauthorized("Tenant does not exist.") - if args and isinstance(args[0], Resource): - return view_func(args[0], api_token.tenant_id, *args[1:], **kwargs) + raise Unauthorized("Tenant owner account does not exist.") + else: + raise Unauthorized("Tenant does not exist.") - return view_func(api_token.tenant_id, *args, **kwargs) + if expects_bound_instance: + if not args: + raise TypeError("validate_dataset_token expected a bound resource instance.") + return view(args[0], api_token.tenant_id, *args[1:], **kwargs) - return decorated + return view(api_token.tenant_id, *args, **kwargs) - if view: - return decorator(view) - - # if view is None, it means that the decorator is used without parentheses - # use the decorator as a function for method_decorators - return decorator + return decorated def validate_and_get_api_token(scope: str | None = None): diff --git a/api/dify_app.py b/api/dify_app.py index d6deb8e007..bbe3f33787 100644 --- a/api/dify_app.py +++ b/api/dify_app.py @@ -1,5 +1,14 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from flask import Flask +if TYPE_CHECKING: + from extensions.ext_login import DifyLoginManager + class DifyApp(Flask): - pass + """Flask application type with Dify-specific extension attributes.""" + + login_manager: DifyLoginManager diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 02e50a90fc..bc59eaca63 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -1,7 +1,8 @@ import json +from typing import cast import flask_login -from flask import Response, request +from flask import Request, Response, request from flask_login import user_loaded_from_request, user_logged_in from sqlalchemy import select from werkzeug.exceptions import NotFound, Unauthorized @@ -16,13 +17,35 @@ from models import Account, Tenant, TenantAccountJoin from models.model import AppMCPServer, EndUser from services.account_service import AccountService -login_manager = flask_login.LoginManager() +type LoginUser = Account | EndUser + + +class DifyLoginManager(flask_login.LoginManager): + """Project-specific Flask-Login manager with a stable unauthorized contract. + + Dify registers `unauthorized_handler` below to always return a JSON `Response`. + Overriding this method lets callers rely on that narrower return type instead of + Flask-Login's broader callback contract. + """ + + def unauthorized(self) -> Response: + """Return the registered unauthorized handler result as a Flask `Response`.""" + return cast(Response, super().unauthorized()) + + def load_user_from_request_context(self) -> None: + """Populate Flask-Login's request-local user cache for the current request.""" + self._load_user() + + +login_manager = DifyLoginManager() # Flask-Login configuration @login_manager.request_loader -def load_user_from_request(request_from_flask_login): +def load_user_from_request(request_from_flask_login: Request) -> LoginUser | None: """Load user based on the request.""" + del request_from_flask_login + # Skip authentication for documentation endpoints if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")): return None @@ -100,10 +123,12 @@ def load_user_from_request(request_from_flask_login): raise NotFound("End user not found.") return end_user + return None + @user_logged_in.connect @user_loaded_from_request.connect -def on_user_logged_in(_sender, user): +def on_user_logged_in(_sender: object, user: LoginUser) -> None: """Called when a user logged in. Note: AccountService.load_logged_in_account will populate user.current_tenant_id @@ -114,8 +139,10 @@ def on_user_logged_in(_sender, user): @login_manager.unauthorized_handler -def unauthorized_handler(): +def unauthorized_handler() -> Response: """Handle unauthorized requests.""" + # Keep this as a concrete `Response`; `DifyLoginManager.unauthorized()` narrows + # Flask-Login's callback contract based on this override. return Response( json.dumps({"code": "unauthorized", "message": "Unauthorized."}), status=401, @@ -123,5 +150,5 @@ def unauthorized_handler(): ) -def init_app(app: DifyApp): +def init_app(app: DifyApp) -> None: login_manager.init_app(app) diff --git a/api/libs/login.py b/api/libs/login.py index 68a2050747..067597cb3c 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -2,19 +2,19 @@ from __future__ import annotations from collections.abc import Callable from functools import wraps -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast -from flask import current_app, g, has_request_context, request +from flask import Response, current_app, g, has_request_context, request from flask_login.config import EXEMPT_METHODS from werkzeug.local import LocalProxy from configs import dify_config +from dify_app import DifyApp +from extensions.ext_login import DifyLoginManager from libs.token import check_csrf_token from models import Account if TYPE_CHECKING: - from flask.typing import ResponseReturnValue - from models.model import EndUser @@ -29,7 +29,13 @@ def _resolve_current_user() -> EndUser | Account | None: return get_current_object() if callable(get_current_object) else user_proxy # type: ignore -def current_account_with_tenant(): +def _get_login_manager() -> DifyLoginManager: + """Return the project login manager with Dify's narrowed unauthorized contract.""" + app = cast(DifyApp, current_app) + return app.login_manager + + +def current_account_with_tenant() -> tuple[Account, str]: """ Resolve the underlying account for the current user proxy and ensure tenant context exists. Allows tests to supply plain Account mocks without the LocalProxy helper. @@ -42,7 +48,7 @@ def current_account_with_tenant(): return user, user.current_tenant_id -def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]: +def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | Response]: """ If you decorate a view with this, it will ensure that the current user is logged in and authenticated before calling the actual view. (If they are @@ -77,13 +83,16 @@ def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseRetu """ @wraps(func) - def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue: + def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | Response: if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: return current_app.ensure_sync(func)(*args, **kwargs) user = _resolve_current_user() if user is None or not user.is_authenticated: - return current_app.login_manager.unauthorized() # type: ignore + # `DifyLoginManager` guarantees that the registered unauthorized handler + # is surfaced here as a concrete Flask `Response`. + unauthorized_response: Response = _get_login_manager().unauthorized() + return unauthorized_response g._login_user = user # we put csrf validation here for less conflicts # TODO: maybe find a better place for it. @@ -96,7 +105,7 @@ def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseRetu def _get_user() -> EndUser | Account | None: if has_request_context(): if "_login_user" not in g: - current_app.login_manager._load_user() # type: ignore + _get_login_manager().load_user_from_request_context() return g._login_user diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index 9afc1c4166..7f9fe9cbf9 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -20,7 +20,7 @@ def app(): app = Flask(__name__) app.config["TESTING"] = True app.config["RESTX_MASK_HEADER"] = "X-Fields" - app.login_manager = SimpleNamespace(_load_user=lambda: None) + app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None) return app diff --git a/api/tests/unit_tests/controllers/console/test_workspace_members.py b/api/tests/unit_tests/controllers/console/test_workspace_members.py index 368892b922..239fec8430 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_members.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_members.py @@ -12,7 +12,7 @@ from models.account import Account, TenantAccountRole def app(): flask_app = Flask(__name__) flask_app.config["TESTING"] = True - flask_app.login_manager = SimpleNamespace(_load_user=lambda: None) + flask_app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None) return flask_app diff --git a/api/tests/unit_tests/extensions/test_ext_login.py b/api/tests/unit_tests/extensions/test_ext_login.py new file mode 100644 index 0000000000..64abc19427 --- /dev/null +++ b/api/tests/unit_tests/extensions/test_ext_login.py @@ -0,0 +1,17 @@ +import json + +from flask import Response + +from extensions.ext_login import unauthorized_handler + + +def test_unauthorized_handler_returns_json_response() -> None: + response = unauthorized_handler() + + assert isinstance(response, Response) + assert response.status_code == 401 + assert response.content_type == "application/json" + assert json.loads(response.get_data(as_text=True)) == { + "code": "unauthorized", + "message": "Unauthorized.", + } diff --git a/api/tests/unit_tests/libs/test_login.py b/api/tests/unit_tests/libs/test_login.py index 0c9e73299b..2bf2212844 100644 --- a/api/tests/unit_tests/libs/test_login.py +++ b/api/tests/unit_tests/libs/test_login.py @@ -2,11 +2,12 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from flask import Flask, g -from flask_login import LoginManager, UserMixin +from flask import Flask, Response, g +from flask_login import UserMixin from pytest_mock import MockerFixture import libs.login as login_module +from extensions.ext_login import DifyLoginManager from libs.login import current_user from models.account import Account @@ -39,9 +40,12 @@ def login_app(mocker: MockerFixture) -> Flask: app = Flask(__name__) app.config["TESTING"] = True - login_manager = LoginManager() + login_manager = DifyLoginManager() login_manager.init_app(app) - login_manager.unauthorized = mocker.Mock(name="unauthorized", return_value="Unauthorized") + login_manager.unauthorized = mocker.Mock( + name="unauthorized", + return_value=Response("Unauthorized", status=401, content_type="application/json"), + ) @login_manager.user_loader def load_user(_user_id: str): @@ -109,18 +113,43 @@ class TestLoginRequired: resolved_user: MockUser | None, description: str, ): - """Test that missing or unauthenticated users are redirected.""" + """Test that missing or unauthenticated users return the manager response.""" resolve_user = resolve_current_user(resolved_user) with login_app.test_request_context(): result = protected_view() - assert result == "Unauthorized", description + assert result is login_app.login_manager.unauthorized.return_value, description + assert isinstance(result, Response) + assert result.status_code == 401 resolve_user.assert_called_once_with() login_app.login_manager.unauthorized.assert_called_once_with() csrf_check.assert_not_called() + def test_unauthorized_access_propagates_response_object( + self, + login_app: Flask, + protected_view, + csrf_check: MagicMock, + resolve_current_user, + mocker: MockerFixture, + ) -> None: + """Test that unauthorized responses are propagated as Flask Response objects.""" + resolve_user = resolve_current_user(None) + response = Response("Unauthorized", status=401, content_type="application/json") + mocker.patch.object( + login_module, "_get_login_manager", return_value=SimpleNamespace(unauthorized=lambda: response) + ) + + with login_app.test_request_context(): + result = protected_view() + + assert result is response + assert isinstance(result, Response) + resolve_user.assert_called_once_with() + csrf_check.assert_not_called() + @pytest.mark.parametrize( ("method", "login_disabled"), [ @@ -168,10 +197,14 @@ class TestGetUser: """Test that _get_user loads user if not already in g.""" mock_user = MockUser("test_user") - def _load_user() -> None: + def load_user_from_request_context() -> None: g._login_user = mock_user - load_user = mocker.patch.object(login_app.login_manager, "_load_user", side_effect=_load_user) + load_user = mocker.patch.object( + login_app.login_manager, + "load_user_from_request_context", + side_effect=load_user_from_request_context, + ) with login_app.test_request_context(): user = login_module._get_user()