Revert "fix(auth): persist oauth init preferences through callback"

This reverts commit 73a667f348.
This commit is contained in:
yyh
2026-05-15 12:30:49 +08:00
parent 73a667f348
commit a1f752dc87
2 changed files with 16 additions and 177 deletions

View File

@@ -13,11 +13,8 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_remote_ip
from libs.helper import timezone as validate_timezone_string
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthState, OAuthUserInfo, decode_oauth_state, encode_oauth_state
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo, decode_oauth_state
from libs.token import (
_cookie_domain,
_real_cookie_name,
is_secure,
set_access_token_to_cookie,
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
@@ -33,50 +30,6 @@ from .. import console_ns
logger = logging.getLogger(__name__)
OAUTH_INIT_STATE_COOKIE_NAME = "oauth_init_state"
OAUTH_INIT_STATE_COOKIE_MAX_AGE_SECONDS = 10 * 60
def _clear_oauth_init_state_cookie(response) -> None:
response.set_cookie(
_real_cookie_name(OAUTH_INIT_STATE_COOKIE_NAME),
"",
expires=0,
httponly=True,
domain=_cookie_domain(),
secure=is_secure(),
samesite="Lax",
path="/",
)
def _set_oauth_init_state_cookie(response, timezone: str | None, language: str | None) -> None:
state = encode_oauth_state(timezone=timezone, language=language)
if not state:
_clear_oauth_init_state_cookie(response)
return
response.set_cookie(
_real_cookie_name(OAUTH_INIT_STATE_COOKIE_NAME),
value=state,
httponly=True,
domain=_cookie_domain(),
secure=is_secure(),
samesite="Lax",
max_age=OAUTH_INIT_STATE_COOKIE_MAX_AGE_SECONDS,
path="/",
)
def _get_oauth_init_state_from_cookie() -> OAuthState:
return decode_oauth_state(request.cookies.get(_real_cookie_name(OAUTH_INIT_STATE_COOKIE_NAME)))
def _redirect_with_cleared_oauth_init_state(location: str):
response = redirect(location)
_clear_oauth_init_state_cookie(response)
return response
def get_oauth_providers():
with current_app.app_context():
@@ -150,9 +103,7 @@ class OAuthLogin(Resource):
timezone=timezone,
language=language,
)
response = redirect(auth_url)
_set_oauth_init_state_cookie(response, timezone=timezone, language=language)
return response
return redirect(auth_url)
@console_ns.route("/oauth/authorize/<provider>")
@@ -178,10 +129,9 @@ class OAuthCallback(Resource):
code = request.args.get("code")
state = request.args.get("state")
oauth_state = decode_oauth_state(state)
cookie_state = _get_oauth_init_state_from_cookie()
invite_token = oauth_state.get("invite_token")
timezone = _validated_timezone(oauth_state.get("timezone") or cookie_state.get("timezone"))
language = _validated_language(oauth_state.get("language") or cookie_state.get("language"))
timezone = _validated_timezone(oauth_state.get("timezone"))
language = _validated_language(oauth_state.get("language"))
if not code:
return {"error": "Authorization code is required"}, 400
@@ -197,9 +147,7 @@ class OAuthCallback(Resource):
return {"error": "OAuth process failed"}, 400
except ValueError as e:
logger.warning("OAuth error with %s", provider, exc_info=True)
return _redirect_with_cleared_oauth_init_state(
f"{dify_config.CONSOLE_WEB_URL}/signin?message={urllib.parse.quote(str(e))}"
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={urllib.parse.quote(str(e))}")
if invite_token and RegisterService.is_valid_invite_token(invite_token):
invitation = RegisterService.get_invitation_by_token(token=invite_token)
@@ -209,35 +157,25 @@ class OAuthCallback(Resource):
invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
)
if invitation_email_normalized != user_info.email.lower():
return _redirect_with_cleared_oauth_init_state(
f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token."
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
return _redirect_with_cleared_oauth_init_state(
f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}"
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
try:
account, oauth_new_user = _generate_account(provider, user_info, timezone=timezone, language=language)
except AccountNotFoundError:
return _redirect_with_cleared_oauth_init_state(
f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found."
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
return _redirect_with_cleared_oauth_init_state(
return redirect(
f"{dify_config.CONSOLE_WEB_URL}/signin"
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
)
except AccountRegisterError as e:
return _redirect_with_cleared_oauth_init_state(
f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}"
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}")
# Check account status
if account.status == AccountStatus.BANNED:
return _redirect_with_cleared_oauth_init_state(
f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned."
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
if account.status == AccountStatus.PENDING:
account.status = AccountStatus.ACTIVE
@@ -247,11 +185,9 @@ class OAuthCallback(Resource):
try:
TenantService.create_owner_tenant_if_not_exist(account)
except Unauthorized:
return _redirect_with_cleared_oauth_init_state(
f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found."
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
except WorkSpaceNotAllowedCreateError:
return _redirect_with_cleared_oauth_init_state(
return redirect(
f"{dify_config.CONSOLE_WEB_URL}/signin"
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
)
@@ -264,7 +200,7 @@ class OAuthCallback(Resource):
base_url = dify_config.CONSOLE_WEB_URL
query_char = "&" if "?" in base_url else "?"
target_url = f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"
response = _redirect_with_cleared_oauth_init_state(target_url)
response = redirect(target_url)
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)

View File

@@ -1,12 +1,10 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.oauth import OAUTH_INIT_STATE_COOKIE_NAME, OAuthCallback, OAuthLogin, _generate_account
from libs.oauth import OAuthUserInfo, encode_oauth_state
from models.account import AccountStatus
from controllers.console.auth.oauth import OAuthLogin, _generate_account
from libs.oauth import OAuthUserInfo
from services.errors.account import AccountRegisterError
@@ -39,101 +37,6 @@ def test_oauth_login_passes_language_and_timezone_to_authorization_url(
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?state=...")
@patch("controllers.console.auth.oauth.get_oauth_providers")
def test_oauth_login_sets_init_state_cookie(
mock_get_oauth_providers,
app: Flask,
):
oauth_provider = MagicMock()
oauth_provider.get_authorization_url.return_value = "https://github.com/login/oauth/authorize?state=..."
mock_get_oauth_providers.return_value = {"github": oauth_provider}
with app.test_request_context("/oauth/login/github?language=zh-Hans&timezone=Asia/Shanghai"):
response = OAuthLogin().get("github")
set_cookie_headers = response.headers.getlist("Set-Cookie")
assert any(header.startswith(f"{OAUTH_INIT_STATE_COOKIE_NAME}=") for header in set_cookie_headers)
assert any("HttpOnly" in header and "SameSite=Lax" in header for header in set_cookie_headers)
@pytest.mark.parametrize(
("state", "cookie_state", "expected_timezone", "expected_language"),
[
(None, encode_oauth_state(timezone="Asia/Shanghai", language="zh-Hans"), "Asia/Shanghai", "zh-Hans"),
(
encode_oauth_state(timezone="Europe/Paris", language="fr-FR"),
encode_oauth_state(timezone="Asia/Shanghai", language="zh-Hans"),
"Europe/Paris",
"fr-FR",
),
],
)
@patch("controllers.console.auth.oauth.set_csrf_token_to_cookie")
@patch("controllers.console.auth.oauth.set_refresh_token_to_cookie")
@patch("controllers.console.auth.oauth.set_access_token_to_cookie")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth._generate_account")
@patch("controllers.console.auth.oauth.get_oauth_providers")
def test_oauth_callback_uses_state_then_cookie_for_registration_preferences(
mock_get_oauth_providers,
mock_generate_account,
mock_tenant_service,
mock_account_service,
mock_set_access_token,
mock_set_refresh_token,
mock_set_csrf_token,
state,
cookie_state,
expected_timezone,
expected_language,
app: Flask,
):
oauth_provider = MagicMock()
oauth_provider.get_access_token.return_value = "access-token"
oauth_provider.get_user_info.return_value = OAuthUserInfo(
id="github-123",
name="Test User",
email="user@example.com",
)
mock_get_oauth_providers.return_value = {"github": oauth_provider}
account = MagicMock()
account.status = AccountStatus.ACTIVE
mock_generate_account.return_value = (account, True)
mock_account_service.login.return_value = SimpleNamespace(
access_token="access-token",
refresh_token="refresh-token",
csrf_token="csrf-token",
)
query = "code=test-code"
if state:
query = f"{query}&state={state}"
headers = {"Cookie": f"{OAUTH_INIT_STATE_COOKIE_NAME}={cookie_state}"}
with (
patch("controllers.console.auth.oauth.dify_config.CONSOLE_WEB_URL", "http://localhost:3000"),
app.test_request_context(f"/oauth/authorize/github?{query}", headers=headers),
):
response = OAuthCallback().get("github")
mock_generate_account.assert_called_once_with(
"github",
oauth_provider.get_user_info.return_value,
timezone=expected_timezone,
language=expected_language,
)
mock_tenant_service.create_owner_tenant_if_not_exist.assert_called_once_with(account)
mock_account_service.login.assert_called_once()
assert response.status_code == 302
assert response.headers["Location"] == "http://localhost:3000?oauth_new_user=true"
assert any(
header.startswith(f"{OAUTH_INIT_STATE_COOKIE_NAME}=;") and "Expires=Thu, 01 Jan 1970 00:00:00 GMT" in header
for header in response.headers.getlist("Set-Cookie")
)
@patch("controllers.console.auth.oauth.AccountService.link_account_integrate")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.FeatureService")