mirror of
https://github.com/langgenius/dify.git
synced 2026-05-15 04:00:46 -04:00
Revert "fix(auth): persist oauth init preferences through callback"
This reverts commit 73a667f348.
This commit is contained in:
@@ -13,11 +13,8 @@ from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.helper import timezone as validate_timezone_string
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthState, OAuthUserInfo, decode_oauth_state, encode_oauth_state
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo, decode_oauth_state
|
||||
from libs.token import (
|
||||
_cookie_domain,
|
||||
_real_cookie_name,
|
||||
is_secure,
|
||||
set_access_token_to_cookie,
|
||||
set_csrf_token_to_cookie,
|
||||
set_refresh_token_to_cookie,
|
||||
@@ -33,50 +30,6 @@ from .. import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OAUTH_INIT_STATE_COOKIE_NAME = "oauth_init_state"
|
||||
OAUTH_INIT_STATE_COOKIE_MAX_AGE_SECONDS = 10 * 60
|
||||
|
||||
|
||||
def _clear_oauth_init_state_cookie(response) -> None:
|
||||
response.set_cookie(
|
||||
_real_cookie_name(OAUTH_INIT_STATE_COOKIE_NAME),
|
||||
"",
|
||||
expires=0,
|
||||
httponly=True,
|
||||
domain=_cookie_domain(),
|
||||
secure=is_secure(),
|
||||
samesite="Lax",
|
||||
path="/",
|
||||
)
|
||||
|
||||
|
||||
def _set_oauth_init_state_cookie(response, timezone: str | None, language: str | None) -> None:
|
||||
state = encode_oauth_state(timezone=timezone, language=language)
|
||||
if not state:
|
||||
_clear_oauth_init_state_cookie(response)
|
||||
return
|
||||
|
||||
response.set_cookie(
|
||||
_real_cookie_name(OAUTH_INIT_STATE_COOKIE_NAME),
|
||||
value=state,
|
||||
httponly=True,
|
||||
domain=_cookie_domain(),
|
||||
secure=is_secure(),
|
||||
samesite="Lax",
|
||||
max_age=OAUTH_INIT_STATE_COOKIE_MAX_AGE_SECONDS,
|
||||
path="/",
|
||||
)
|
||||
|
||||
|
||||
def _get_oauth_init_state_from_cookie() -> OAuthState:
|
||||
return decode_oauth_state(request.cookies.get(_real_cookie_name(OAUTH_INIT_STATE_COOKIE_NAME)))
|
||||
|
||||
|
||||
def _redirect_with_cleared_oauth_init_state(location: str):
|
||||
response = redirect(location)
|
||||
_clear_oauth_init_state_cookie(response)
|
||||
return response
|
||||
|
||||
|
||||
def get_oauth_providers():
|
||||
with current_app.app_context():
|
||||
@@ -150,9 +103,7 @@ class OAuthLogin(Resource):
|
||||
timezone=timezone,
|
||||
language=language,
|
||||
)
|
||||
response = redirect(auth_url)
|
||||
_set_oauth_init_state_cookie(response, timezone=timezone, language=language)
|
||||
return response
|
||||
return redirect(auth_url)
|
||||
|
||||
|
||||
@console_ns.route("/oauth/authorize/<provider>")
|
||||
@@ -178,10 +129,9 @@ class OAuthCallback(Resource):
|
||||
code = request.args.get("code")
|
||||
state = request.args.get("state")
|
||||
oauth_state = decode_oauth_state(state)
|
||||
cookie_state = _get_oauth_init_state_from_cookie()
|
||||
invite_token = oauth_state.get("invite_token")
|
||||
timezone = _validated_timezone(oauth_state.get("timezone") or cookie_state.get("timezone"))
|
||||
language = _validated_language(oauth_state.get("language") or cookie_state.get("language"))
|
||||
timezone = _validated_timezone(oauth_state.get("timezone"))
|
||||
language = _validated_language(oauth_state.get("language"))
|
||||
|
||||
if not code:
|
||||
return {"error": "Authorization code is required"}, 400
|
||||
@@ -197,9 +147,7 @@ class OAuthCallback(Resource):
|
||||
return {"error": "OAuth process failed"}, 400
|
||||
except ValueError as e:
|
||||
logger.warning("OAuth error with %s", provider, exc_info=True)
|
||||
return _redirect_with_cleared_oauth_init_state(
|
||||
f"{dify_config.CONSOLE_WEB_URL}/signin?message={urllib.parse.quote(str(e))}"
|
||||
)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={urllib.parse.quote(str(e))}")
|
||||
|
||||
if invite_token and RegisterService.is_valid_invite_token(invite_token):
|
||||
invitation = RegisterService.get_invitation_by_token(token=invite_token)
|
||||
@@ -209,35 +157,25 @@ class OAuthCallback(Resource):
|
||||
invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
|
||||
)
|
||||
if invitation_email_normalized != user_info.email.lower():
|
||||
return _redirect_with_cleared_oauth_init_state(
|
||||
f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token."
|
||||
)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
|
||||
|
||||
return _redirect_with_cleared_oauth_init_state(
|
||||
f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}"
|
||||
)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
|
||||
|
||||
try:
|
||||
account, oauth_new_user = _generate_account(provider, user_info, timezone=timezone, language=language)
|
||||
except AccountNotFoundError:
|
||||
return _redirect_with_cleared_oauth_init_state(
|
||||
f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found."
|
||||
)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
|
||||
except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
|
||||
return _redirect_with_cleared_oauth_init_state(
|
||||
return redirect(
|
||||
f"{dify_config.CONSOLE_WEB_URL}/signin"
|
||||
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
|
||||
)
|
||||
except AccountRegisterError as e:
|
||||
return _redirect_with_cleared_oauth_init_state(
|
||||
f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}"
|
||||
)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}")
|
||||
|
||||
# Check account status
|
||||
if account.status == AccountStatus.BANNED:
|
||||
return _redirect_with_cleared_oauth_init_state(
|
||||
f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned."
|
||||
)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
|
||||
|
||||
if account.status == AccountStatus.PENDING:
|
||||
account.status = AccountStatus.ACTIVE
|
||||
@@ -247,11 +185,9 @@ class OAuthCallback(Resource):
|
||||
try:
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
except Unauthorized:
|
||||
return _redirect_with_cleared_oauth_init_state(
|
||||
f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found."
|
||||
)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
return _redirect_with_cleared_oauth_init_state(
|
||||
return redirect(
|
||||
f"{dify_config.CONSOLE_WEB_URL}/signin"
|
||||
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
|
||||
)
|
||||
@@ -264,7 +200,7 @@ class OAuthCallback(Resource):
|
||||
base_url = dify_config.CONSOLE_WEB_URL
|
||||
query_char = "&" if "?" in base_url else "?"
|
||||
target_url = f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"
|
||||
response = _redirect_with_cleared_oauth_init_state(target_url)
|
||||
response = redirect(target_url)
|
||||
|
||||
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.oauth import OAUTH_INIT_STATE_COOKIE_NAME, OAuthCallback, OAuthLogin, _generate_account
|
||||
from libs.oauth import OAuthUserInfo, encode_oauth_state
|
||||
from models.account import AccountStatus
|
||||
from controllers.console.auth.oauth import OAuthLogin, _generate_account
|
||||
from libs.oauth import OAuthUserInfo
|
||||
from services.errors.account import AccountRegisterError
|
||||
|
||||
|
||||
@@ -39,101 +37,6 @@ def test_oauth_login_passes_language_and_timezone_to_authorization_url(
|
||||
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?state=...")
|
||||
|
||||
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
def test_oauth_login_sets_init_state_cookie(
|
||||
mock_get_oauth_providers,
|
||||
app: Flask,
|
||||
):
|
||||
oauth_provider = MagicMock()
|
||||
oauth_provider.get_authorization_url.return_value = "https://github.com/login/oauth/authorize?state=..."
|
||||
mock_get_oauth_providers.return_value = {"github": oauth_provider}
|
||||
|
||||
with app.test_request_context("/oauth/login/github?language=zh-Hans&timezone=Asia/Shanghai"):
|
||||
response = OAuthLogin().get("github")
|
||||
|
||||
set_cookie_headers = response.headers.getlist("Set-Cookie")
|
||||
assert any(header.startswith(f"{OAUTH_INIT_STATE_COOKIE_NAME}=") for header in set_cookie_headers)
|
||||
assert any("HttpOnly" in header and "SameSite=Lax" in header for header in set_cookie_headers)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("state", "cookie_state", "expected_timezone", "expected_language"),
|
||||
[
|
||||
(None, encode_oauth_state(timezone="Asia/Shanghai", language="zh-Hans"), "Asia/Shanghai", "zh-Hans"),
|
||||
(
|
||||
encode_oauth_state(timezone="Europe/Paris", language="fr-FR"),
|
||||
encode_oauth_state(timezone="Asia/Shanghai", language="zh-Hans"),
|
||||
"Europe/Paris",
|
||||
"fr-FR",
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.auth.oauth.set_csrf_token_to_cookie")
|
||||
@patch("controllers.console.auth.oauth.set_refresh_token_to_cookie")
|
||||
@patch("controllers.console.auth.oauth.set_access_token_to_cookie")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth._generate_account")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
def test_oauth_callback_uses_state_then_cookie_for_registration_preferences(
|
||||
mock_get_oauth_providers,
|
||||
mock_generate_account,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_set_access_token,
|
||||
mock_set_refresh_token,
|
||||
mock_set_csrf_token,
|
||||
state,
|
||||
cookie_state,
|
||||
expected_timezone,
|
||||
expected_language,
|
||||
app: Flask,
|
||||
):
|
||||
oauth_provider = MagicMock()
|
||||
oauth_provider.get_access_token.return_value = "access-token"
|
||||
oauth_provider.get_user_info.return_value = OAuthUserInfo(
|
||||
id="github-123",
|
||||
name="Test User",
|
||||
email="user@example.com",
|
||||
)
|
||||
mock_get_oauth_providers.return_value = {"github": oauth_provider}
|
||||
|
||||
account = MagicMock()
|
||||
account.status = AccountStatus.ACTIVE
|
||||
mock_generate_account.return_value = (account, True)
|
||||
mock_account_service.login.return_value = SimpleNamespace(
|
||||
access_token="access-token",
|
||||
refresh_token="refresh-token",
|
||||
csrf_token="csrf-token",
|
||||
)
|
||||
|
||||
query = "code=test-code"
|
||||
if state:
|
||||
query = f"{query}&state={state}"
|
||||
|
||||
headers = {"Cookie": f"{OAUTH_INIT_STATE_COOKIE_NAME}={cookie_state}"}
|
||||
with (
|
||||
patch("controllers.console.auth.oauth.dify_config.CONSOLE_WEB_URL", "http://localhost:3000"),
|
||||
app.test_request_context(f"/oauth/authorize/github?{query}", headers=headers),
|
||||
):
|
||||
response = OAuthCallback().get("github")
|
||||
|
||||
mock_generate_account.assert_called_once_with(
|
||||
"github",
|
||||
oauth_provider.get_user_info.return_value,
|
||||
timezone=expected_timezone,
|
||||
language=expected_language,
|
||||
)
|
||||
mock_tenant_service.create_owner_tenant_if_not_exist.assert_called_once_with(account)
|
||||
mock_account_service.login.assert_called_once()
|
||||
assert response.status_code == 302
|
||||
assert response.headers["Location"] == "http://localhost:3000?oauth_new_user=true"
|
||||
assert any(
|
||||
header.startswith(f"{OAUTH_INIT_STATE_COOKIE_NAME}=;") and "Expires=Thu, 01 Jan 1970 00:00:00 GMT" in header
|
||||
for header in response.headers.getlist("Set-Cookie")
|
||||
)
|
||||
|
||||
|
||||
@patch("controllers.console.auth.oauth.AccountService.link_account_integrate")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.FeatureService")
|
||||
|
||||
Reference in New Issue
Block a user