From a1f752dc87a92c4eccc29ca17b918f07ef55a81b Mon Sep 17 00:00:00 2001 From: yyh Date: Fri, 15 May 2026 12:30:49 +0800 Subject: [PATCH] Revert "fix(auth): persist oauth init preferences through callback" This reverts commit 73a667f348b4eb2a32d9af3ef48b29a62636683c. --- api/controllers/console/auth/oauth.py | 92 +++------------- .../console/auth/test_oauth_timezone.py | 101 +----------------- 2 files changed, 16 insertions(+), 177 deletions(-) diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 0fbefab30e..2254fa4981 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -13,11 +13,8 @@ from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from libs.helper import extract_remote_ip from libs.helper import timezone as validate_timezone_string -from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthState, OAuthUserInfo, decode_oauth_state, encode_oauth_state +from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo, decode_oauth_state from libs.token import ( - _cookie_domain, - _real_cookie_name, - is_secure, set_access_token_to_cookie, set_csrf_token_to_cookie, set_refresh_token_to_cookie, @@ -33,50 +30,6 @@ from .. import console_ns logger = logging.getLogger(__name__) -OAUTH_INIT_STATE_COOKIE_NAME = "oauth_init_state" -OAUTH_INIT_STATE_COOKIE_MAX_AGE_SECONDS = 10 * 60 - - -def _clear_oauth_init_state_cookie(response) -> None: - response.set_cookie( - _real_cookie_name(OAUTH_INIT_STATE_COOKIE_NAME), - "", - expires=0, - httponly=True, - domain=_cookie_domain(), - secure=is_secure(), - samesite="Lax", - path="/", - ) - - -def _set_oauth_init_state_cookie(response, timezone: str | None, language: str | None) -> None: - state = encode_oauth_state(timezone=timezone, language=language) - if not state: - _clear_oauth_init_state_cookie(response) - return - - response.set_cookie( - _real_cookie_name(OAUTH_INIT_STATE_COOKIE_NAME), - value=state, - httponly=True, - domain=_cookie_domain(), - secure=is_secure(), - samesite="Lax", - max_age=OAUTH_INIT_STATE_COOKIE_MAX_AGE_SECONDS, - path="/", - ) - - -def _get_oauth_init_state_from_cookie() -> OAuthState: - return decode_oauth_state(request.cookies.get(_real_cookie_name(OAUTH_INIT_STATE_COOKIE_NAME))) - - -def _redirect_with_cleared_oauth_init_state(location: str): - response = redirect(location) - _clear_oauth_init_state_cookie(response) - return response - def get_oauth_providers(): with current_app.app_context(): @@ -150,9 +103,7 @@ class OAuthLogin(Resource): timezone=timezone, language=language, ) - response = redirect(auth_url) - _set_oauth_init_state_cookie(response, timezone=timezone, language=language) - return response + return redirect(auth_url) @console_ns.route("/oauth/authorize/") @@ -178,10 +129,9 @@ class OAuthCallback(Resource): code = request.args.get("code") state = request.args.get("state") oauth_state = decode_oauth_state(state) - cookie_state = _get_oauth_init_state_from_cookie() invite_token = oauth_state.get("invite_token") - timezone = _validated_timezone(oauth_state.get("timezone") or cookie_state.get("timezone")) - language = _validated_language(oauth_state.get("language") or cookie_state.get("language")) + timezone = _validated_timezone(oauth_state.get("timezone")) + language = _validated_language(oauth_state.get("language")) if not code: return {"error": "Authorization code is required"}, 400 @@ -197,9 +147,7 @@ class OAuthCallback(Resource): return {"error": "OAuth process failed"}, 400 except ValueError as e: logger.warning("OAuth error with %s", provider, exc_info=True) - return _redirect_with_cleared_oauth_init_state( - f"{dify_config.CONSOLE_WEB_URL}/signin?message={urllib.parse.quote(str(e))}" - ) + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={urllib.parse.quote(str(e))}") if invite_token and RegisterService.is_valid_invite_token(invite_token): invitation = RegisterService.get_invitation_by_token(token=invite_token) @@ -209,35 +157,25 @@ class OAuthCallback(Resource): invitation_email.lower() if isinstance(invitation_email, str) else invitation_email ) if invitation_email_normalized != user_info.email.lower(): - return _redirect_with_cleared_oauth_init_state( - f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token." - ) + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.") - return _redirect_with_cleared_oauth_init_state( - f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}" - ) + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}") try: account, oauth_new_user = _generate_account(provider, user_info, timezone=timezone, language=language) except AccountNotFoundError: - return _redirect_with_cleared_oauth_init_state( - f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found." - ) + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.") except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError): - return _redirect_with_cleared_oauth_init_state( + return redirect( f"{dify_config.CONSOLE_WEB_URL}/signin" "?message=Workspace not found, please contact system admin to invite you to join in a workspace." ) except AccountRegisterError as e: - return _redirect_with_cleared_oauth_init_state( - f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}" - ) + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}") # Check account status if account.status == AccountStatus.BANNED: - return _redirect_with_cleared_oauth_init_state( - f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned." - ) + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.") if account.status == AccountStatus.PENDING: account.status = AccountStatus.ACTIVE @@ -247,11 +185,9 @@ class OAuthCallback(Resource): try: TenantService.create_owner_tenant_if_not_exist(account) except Unauthorized: - return _redirect_with_cleared_oauth_init_state( - f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found." - ) + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.") except WorkSpaceNotAllowedCreateError: - return _redirect_with_cleared_oauth_init_state( + return redirect( f"{dify_config.CONSOLE_WEB_URL}/signin" "?message=Workspace not found, please contact system admin to invite you to join in a workspace." ) @@ -264,7 +200,7 @@ class OAuthCallback(Resource): base_url = dify_config.CONSOLE_WEB_URL query_char = "&" if "?" in base_url else "?" target_url = f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}" - response = _redirect_with_cleared_oauth_init_state(target_url) + response = redirect(target_url) set_access_token_to_cookie(request, response, token_pair.access_token) set_refresh_token_to_cookie(request, response, token_pair.refresh_token) diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth_timezone.py b/api/tests/unit_tests/controllers/console/auth/test_oauth_timezone.py index 88f7ede01f..36c707dbf9 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth_timezone.py +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth_timezone.py @@ -1,12 +1,10 @@ -from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest from flask import Flask -from controllers.console.auth.oauth import OAUTH_INIT_STATE_COOKIE_NAME, OAuthCallback, OAuthLogin, _generate_account -from libs.oauth import OAuthUserInfo, encode_oauth_state -from models.account import AccountStatus +from controllers.console.auth.oauth import OAuthLogin, _generate_account +from libs.oauth import OAuthUserInfo from services.errors.account import AccountRegisterError @@ -39,101 +37,6 @@ def test_oauth_login_passes_language_and_timezone_to_authorization_url( mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?state=...") -@patch("controllers.console.auth.oauth.get_oauth_providers") -def test_oauth_login_sets_init_state_cookie( - mock_get_oauth_providers, - app: Flask, -): - oauth_provider = MagicMock() - oauth_provider.get_authorization_url.return_value = "https://github.com/login/oauth/authorize?state=..." - mock_get_oauth_providers.return_value = {"github": oauth_provider} - - with app.test_request_context("/oauth/login/github?language=zh-Hans&timezone=Asia/Shanghai"): - response = OAuthLogin().get("github") - - set_cookie_headers = response.headers.getlist("Set-Cookie") - assert any(header.startswith(f"{OAUTH_INIT_STATE_COOKIE_NAME}=") for header in set_cookie_headers) - assert any("HttpOnly" in header and "SameSite=Lax" in header for header in set_cookie_headers) - - -@pytest.mark.parametrize( - ("state", "cookie_state", "expected_timezone", "expected_language"), - [ - (None, encode_oauth_state(timezone="Asia/Shanghai", language="zh-Hans"), "Asia/Shanghai", "zh-Hans"), - ( - encode_oauth_state(timezone="Europe/Paris", language="fr-FR"), - encode_oauth_state(timezone="Asia/Shanghai", language="zh-Hans"), - "Europe/Paris", - "fr-FR", - ), - ], -) -@patch("controllers.console.auth.oauth.set_csrf_token_to_cookie") -@patch("controllers.console.auth.oauth.set_refresh_token_to_cookie") -@patch("controllers.console.auth.oauth.set_access_token_to_cookie") -@patch("controllers.console.auth.oauth.AccountService") -@patch("controllers.console.auth.oauth.TenantService") -@patch("controllers.console.auth.oauth._generate_account") -@patch("controllers.console.auth.oauth.get_oauth_providers") -def test_oauth_callback_uses_state_then_cookie_for_registration_preferences( - mock_get_oauth_providers, - mock_generate_account, - mock_tenant_service, - mock_account_service, - mock_set_access_token, - mock_set_refresh_token, - mock_set_csrf_token, - state, - cookie_state, - expected_timezone, - expected_language, - app: Flask, -): - oauth_provider = MagicMock() - oauth_provider.get_access_token.return_value = "access-token" - oauth_provider.get_user_info.return_value = OAuthUserInfo( - id="github-123", - name="Test User", - email="user@example.com", - ) - mock_get_oauth_providers.return_value = {"github": oauth_provider} - - account = MagicMock() - account.status = AccountStatus.ACTIVE - mock_generate_account.return_value = (account, True) - mock_account_service.login.return_value = SimpleNamespace( - access_token="access-token", - refresh_token="refresh-token", - csrf_token="csrf-token", - ) - - query = "code=test-code" - if state: - query = f"{query}&state={state}" - - headers = {"Cookie": f"{OAUTH_INIT_STATE_COOKIE_NAME}={cookie_state}"} - with ( - patch("controllers.console.auth.oauth.dify_config.CONSOLE_WEB_URL", "http://localhost:3000"), - app.test_request_context(f"/oauth/authorize/github?{query}", headers=headers), - ): - response = OAuthCallback().get("github") - - mock_generate_account.assert_called_once_with( - "github", - oauth_provider.get_user_info.return_value, - timezone=expected_timezone, - language=expected_language, - ) - mock_tenant_service.create_owner_tenant_if_not_exist.assert_called_once_with(account) - mock_account_service.login.assert_called_once() - assert response.status_code == 302 - assert response.headers["Location"] == "http://localhost:3000?oauth_new_user=true" - assert any( - header.startswith(f"{OAUTH_INIT_STATE_COOKIE_NAME}=;") and "Expires=Thu, 01 Jan 1970 00:00:00 GMT" in header - for header in response.headers.getlist("Set-Cookie") - ) - - @patch("controllers.console.auth.oauth.AccountService.link_account_integrate") @patch("controllers.console.auth.oauth.RegisterService") @patch("controllers.console.auth.oauth.FeatureService")