From a252fbddfa7c666c9f2d5acb986c7543d61ae35e Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Fri, 15 May 2026 16:12:52 +0800 Subject: [PATCH] feat: initialize user timezone and language from browser (#36170) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../console/auth/email_register.py | 46 ++++-- api/controllers/console/auth/login.py | 12 +- api/controllers/console/auth/oauth.py | 67 ++++++--- api/libs/oauth.py | 72 ++++++++- api/openapi/markdown/console-swagger.md | 3 + api/services/account_service.py | 31 ++-- .../console/auth/test_email_register.py | 113 +++++++++++++- .../controllers/console/auth/test_oauth.py | 133 ++++++++++++++++- .../auth/test_email_register_language.py | 40 +++++ .../console/auth/test_email_verification.py | 23 ++- .../console/auth/test_oauth_timezone.py | 123 ++++++++++++++++ api/tests/unit_tests/libs/test_oauth_base.py | 16 +- .../unit_tests/libs/test_oauth_clients.py | 50 +++++-- .../services/test_account_service.py | 26 +++- web/app/signin/check-code/page.tsx | 9 +- .../components/__tests__/social-auth.spec.tsx | 86 +++++++++++ web/app/signin/components/social-auth.tsx | 14 +- .../invite-settings/__tests__/page.spec.tsx | 139 ++++++++++++++++++ web/app/signin/invite-settings/page.tsx | 19 ++- .../set-password/__tests__/page.spec.tsx | 85 +++++++++++ web/app/signup/set-password/page.tsx | 7 +- web/service/common.ts | 8 +- web/service/use-common.ts | 8 +- web/utils/timezone.ts | 7 + 24 files changed, 1050 insertions(+), 87 deletions(-) create mode 100644 api/tests/unit_tests/controllers/console/auth/test_email_register_language.py create mode 100644 api/tests/unit_tests/controllers/console/auth/test_oauth_timezone.py create mode 100644 web/app/signin/components/__tests__/social-auth.spec.tsx create mode 100644 web/app/signin/invite-settings/__tests__/page.spec.tsx create mode 100644 web/app/signup/set-password/__tests__/page.spec.tsx diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index f6b8aedf22..e1f3f0eaeb 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -3,7 +3,7 @@ from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from configs import dify_config -from constants.languages import languages +from constants.languages import get_valid_language, languages from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( @@ -15,11 +15,12 @@ from controllers.console.auth.error import ( PasswordMismatchError, ) from libs.helper import EmailStr, extract_remote_ip +from libs.helper import timezone as validate_timezone_string from libs.password import valid_password from models import Account from services.account_service import AccountService from services.billing_service import BillingService -from services.errors.account import AccountNotFoundError, AccountRegisterError +from services.errors.account import AccountRegisterError from ..error import AccountInFreezeError, EmailSendIpLimitError from ..wraps import email_password_login_enabled, email_register_enabled, setup_required @@ -40,12 +41,21 @@ class EmailRegisterResetPayload(BaseModel): token: str = Field(...) new_password: str = Field(...) password_confirm: str = Field(...) + language: str | None = Field(default=None) + timezone: str | None = Field(default=None) @field_validator("new_password", "password_confirm") @classmethod def validate_password(cls, value: str) -> str: return valid_password(value) + @field_validator("timezone") + @classmethod + def validate_timezone(cls, value: str | None) -> str | None: + if value is None: + return None + return validate_timezone_string(value) + register_schema_models(console_ns, EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload) @@ -144,26 +154,32 @@ class EmailRegisterResetApi(Resource): if account: raise EmailAlreadyInUseError() - else: - account = self._create_new_account(normalized_email, args.password_confirm) - if not account: - raise AccountNotFoundError() - token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) - AccountService.reset_login_error_rate_limit(normalized_email) + + account = self._create_new_account( + email=normalized_email, + password=args.password_confirm, + timezone=args.timezone, + language=args.language, + ) + token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) + AccountService.reset_login_error_rate_limit(normalized_email) return {"result": "success", "data": token_pair.model_dump()} - def _create_new_account(self, email: str, password: str) -> Account | None: - # Create new account if allowed - account = None + def _create_new_account( + self, + email: str, + password: str, + timezone: str | None = None, + language: str | None = None, + ) -> Account: try: - account = AccountService.create_account_and_tenant( + return AccountService.create_account_and_tenant( email=email, name=email, password=password, - interface_language=languages[0], + interface_language=get_valid_language(language), + timezone=timezone, ) except AccountRegisterError: raise AccountInFreezeError() - - return account diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 19c98f3a1a..3121470b84 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -3,7 +3,7 @@ import logging import flask_login from flask import make_response, request from flask_restx import Resource -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import Unauthorized import services @@ -34,6 +34,7 @@ from controllers.console.wraps import ( ) from events.tenant_event import tenant_was_created from libs.helper import EmailStr, extract_remote_ip +from libs.helper import timezone as validate_timezone_string from libs.login import current_account_with_tenant from libs.token import ( clear_access_token_from_cookie, @@ -69,6 +70,14 @@ class EmailCodeLoginPayload(BaseModel): code: str = Field(...) token: str = Field(...) language: str | None = Field(default=None) + timezone: str | None = Field(default=None) + + @field_validator("timezone") + @classmethod + def validate_timezone(cls, value: str | None) -> str | None: + if value is None: + return None + return validate_timezone_string(value) register_schema_models(console_ns, LoginPayload, EmailPayload, EmailCodeLoginPayload) @@ -288,6 +297,7 @@ class EmailCodeLoginApi(Resource): email=user_email, name=user_email, interface_language=get_valid_language(language), + timezone=args.timezone, ) except WorkSpaceNotAllowedCreateError: raise NotAllowedCreateWorkspace() diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index d31fb4a46c..2254fa4981 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -12,7 +12,8 @@ from events.tenant_event import tenant_was_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from libs.helper import extract_remote_ip -from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo +from libs.helper import timezone as validate_timezone_string +from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo, decode_oauth_state from libs.token import ( set_access_token_to_cookie, set_csrf_token_to_cookie, @@ -53,6 +54,31 @@ def get_oauth_providers(): return OAUTH_PROVIDERS +def _validated_timezone(value: str | None) -> str | None: + if not value: + return None + try: + return validate_timezone_string(value) + except ValueError: + return None + + +def _validated_language(value: str | None) -> str | None: + if value and value in languages: + return value + return None + + +def _preferred_interface_language(language: str | None = None) -> str: + if language: + return language + + preferred_lang = request.accept_languages.best_match(languages) + if preferred_lang and preferred_lang in languages: + return preferred_lang + return languages[0] + + @console_ns.route("/oauth/login/") class OAuthLogin(Resource): @console_ns.doc("oauth_login") @@ -64,13 +90,19 @@ class OAuthLogin(Resource): @console_ns.response(400, "Invalid provider") def get(self, provider: str): invite_token = request.args.get("invite_token") or None + timezone = _validated_timezone(request.args.get("timezone") or None) + language = _validated_language(request.args.get("language") or None) OAUTH_PROVIDERS = get_oauth_providers() with current_app.app_context(): oauth_provider = OAUTH_PROVIDERS.get(provider) if not oauth_provider: return {"error": "Invalid provider"}, 400 - auth_url = oauth_provider.get_authorization_url(invite_token=invite_token) + auth_url = oauth_provider.get_authorization_url( + invite_token=invite_token, + timezone=timezone, + language=language, + ) return redirect(auth_url) @@ -96,9 +128,10 @@ class OAuthCallback(Resource): code = request.args.get("code") state = request.args.get("state") - invite_token = None - if state: - invite_token = state + oauth_state = decode_oauth_state(state) + invite_token = oauth_state.get("invite_token") + timezone = _validated_timezone(oauth_state.get("timezone")) + language = _validated_language(oauth_state.get("language")) if not code: return {"error": "Authorization code is required"}, 400 @@ -129,7 +162,7 @@ class OAuthCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}") try: - account, oauth_new_user = _generate_account(provider, user_info) + account, oauth_new_user = _generate_account(provider, user_info, timezone=timezone, language=language) except AccountNotFoundError: return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.") except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError): @@ -184,7 +217,12 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> return account -def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, bool]: +def _generate_account( + provider: str, + user_info: OAuthUserInfo, + timezone: str | None = None, + language: str | None = None, +) -> tuple[Account, bool]: # Get account by openid or email. account = _get_account_by_openid_or_email(provider, user_info) oauth_new_user = False @@ -211,26 +249,19 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, "30 days and is temporarily unavailable for new account registration" ) ) - else: - raise AccountRegisterError(description=("Invalid email or password")) + raise AccountRegisterError(description=("Invalid email or password")) account_name = user_info.name or "Dify" + interface_language = _preferred_interface_language(language) account = RegisterService.register( email=normalized_email, name=account_name, password=None, open_id=user_info.id, provider=provider, + language=interface_language, + timezone=timezone, ) - # Set interface language - preferred_lang = request.accept_languages.best_match(languages) - if preferred_lang and preferred_lang in languages: - interface_language = preferred_lang - else: - interface_language = languages[0] - account.interface_language = interface_language - db.session.commit() - # Link account AccountService.link_account_integrate(provider, user_info.id, account) diff --git a/api/libs/oauth.py b/api/libs/oauth.py index 3daaa038e0..309f2aa812 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -1,3 +1,6 @@ +import base64 +import binascii +import json import logging import urllib.parse from dataclasses import dataclass @@ -27,6 +30,12 @@ class AccessTokenResponse(TypedDict, total=False): access_token: str +class OAuthState(TypedDict, total=False): + invite_token: str + timezone: str + language: str + + class GitHubEmailRecord(TypedDict, total=False): email: str primary: bool @@ -46,6 +55,7 @@ class GoogleRawUserInfo(TypedDict): ACCESS_TOKEN_RESPONSE_ADAPTER = TypeAdapter(AccessTokenResponse) +OAUTH_STATE_ADAPTER = TypeAdapter(OAuthState) GITHUB_RAW_USER_INFO_ADAPTER = TypeAdapter(GitHubRawUserInfo) GITHUB_EMAIL_RECORDS_ADAPTER = TypeAdapter(list[GitHubEmailRecord]) GOOGLE_RAW_USER_INFO_ADAPTER = TypeAdapter(GoogleRawUserInfo) @@ -58,6 +68,37 @@ class OAuthUserInfo: email: str +def encode_oauth_state( + invite_token: str | None = None, + timezone: str | None = None, + language: str | None = None, +) -> str | None: + state: OAuthState = {} + if invite_token: + state["invite_token"] = invite_token + if timezone: + state["timezone"] = timezone + if language: + state["language"] = language + if not state: + return None + + raw_state = json.dumps(state, separators=(",", ":")).encode("utf-8") + return base64.urlsafe_b64encode(raw_state).decode("ascii").rstrip("=") + + +def decode_oauth_state(state: str | None) -> OAuthState: + if not state: + return {} + + try: + padded_state = state + "=" * (-len(state) % 4) + raw_state = base64.urlsafe_b64decode(padded_state.encode("ascii")).decode("utf-8") + return OAUTH_STATE_ADAPTER.validate_python(json.loads(raw_state)) + except (binascii.Error, ValueError, UnicodeDecodeError, json.JSONDecodeError, ValidationError): + return {} + + def _json_object(response: httpx.Response) -> JsonObject: return JSON_OBJECT_ADAPTER.validate_python(response.json()) @@ -76,7 +117,12 @@ class OAuth: self.client_secret = client_secret self.redirect_uri = redirect_uri - def get_authorization_url(self, invite_token: str | None = None) -> str: + def get_authorization_url( + self, + invite_token: str | None = None, + timezone: str | None = None, + language: str | None = None, + ) -> str: raise NotImplementedError() def get_access_token(self, code: str) -> str: @@ -99,14 +145,20 @@ class GitHubOAuth(OAuth): _USER_INFO_URL = "https://api.github.com/user" _EMAIL_INFO_URL = "https://api.github.com/user/emails" - def get_authorization_url(self, invite_token: str | None = None) -> str: + def get_authorization_url( + self, + invite_token: str | None = None, + timezone: str | None = None, + language: str | None = None, + ) -> str: params = { "client_id": self.client_id, "redirect_uri": self.redirect_uri, "scope": "user:email", # Request only basic user information } - if invite_token: - params["state"] = invite_token + state = encode_oauth_state(invite_token=invite_token, timezone=timezone, language=language) + if state: + params["state"] = state return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" def get_access_token(self, code: str) -> str: @@ -186,15 +238,21 @@ class GoogleOAuth(OAuth): _TOKEN_URL = "https://oauth2.googleapis.com/token" _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo" - def get_authorization_url(self, invite_token: str | None = None) -> str: + def get_authorization_url( + self, + invite_token: str | None = None, + timezone: str | None = None, + language: str | None = None, + ) -> str: params = { "client_id": self.client_id, "response_type": "code", "redirect_uri": self.redirect_uri, "scope": "openid email", } - if invite_token: - params["state"] = invite_token + state = encode_oauth_state(invite_token=invite_token, timezone=timezone, language=language) + if state: + params["state"] = state return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" def get_access_token(self, code: str) -> str: diff --git a/api/openapi/markdown/console-swagger.md b/api/openapi/markdown/console-swagger.md index 88410d6507..b2f5bb1215 100644 --- a/api/openapi/markdown/console-swagger.md +++ b/api/openapi/markdown/console-swagger.md @@ -11596,6 +11596,7 @@ Request payload for bulk downloading documents as a zip archive. | code | string | | Yes | | email | string | | Yes | | language | | | No | +| timezone | | | No | | token | string | | Yes | #### EmailPayload @@ -11609,8 +11610,10 @@ Request payload for bulk downloading documents as a zip archive. | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | +| language | | | No | | new_password | string | | Yes | | password_confirm | string | | Yes | +| timezone | | | No | | token | string | | Yes | #### EmailRegisterSendPayload diff --git a/api/services/account_service.py b/api/services/account_service.py index 744c17d126..6533526b60 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -29,6 +29,7 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client, redis_fallback from libs.datetime_utils import naive_utc_now from libs.helper import RateLimiter, TokenManager +from libs.helper import timezone as validate_timezone from libs.passport import PassportService from libs.password import compare_password, hash_password, valid_password from libs.rsa import generate_key_pair @@ -271,8 +272,9 @@ class AccountService: password: str | None = None, interface_theme: str = "light", is_setup: bool | None = False, + timezone: str | None = None, ) -> Account: - """create account""" + """Create an account, preferring explicit user timezone over language-derived defaults.""" if not FeatureService.get_system_features().is_allow_register and not is_setup: from controllers.console.error import AccountNotFound @@ -302,6 +304,10 @@ class AccountService: password_to_set = base64_password_hashed salt_to_set = base64_salt + resolved_timezone = language_timezone_mapping.get(interface_language, "UTC") + if timezone is not None: + resolved_timezone = validate_timezone(timezone) + account = Account( name=name, email=email, @@ -309,7 +315,7 @@ class AccountService: password_salt=salt_to_set, interface_language=interface_language, interface_theme=interface_theme, - timezone=language_timezone_mapping.get(interface_language, "UTC"), + timezone=resolved_timezone, ) db.session.add(account) @@ -318,11 +324,15 @@ class AccountService: @staticmethod def create_account_and_tenant( - email: str, name: str, interface_language: str, password: str | None = None + email: str, name: str, interface_language: str, password: str | None = None, timezone: str | None = None ) -> Account: - """create account""" + """Create an account and owner workspace.""" account = AccountService.create_account( - email=email, name=name, interface_language=interface_language, password=password + email=email, + name=name, + interface_language=interface_language, + password=password, + timezone=timezone, ) try: @@ -1474,8 +1484,8 @@ class RegisterService: @classmethod def register( cls, - email, - name, + email: str, + name: str, password: str | None = None, open_id: str | None = None, provider: str | None = None, @@ -1483,16 +1493,19 @@ class RegisterService: status: AccountStatus | None = None, is_setup: bool | None = False, create_workspace_required: bool | None = True, + timezone: str | None = None, ) -> Account: - db.session.begin_nested() """Register account""" + db.session.begin_nested() try: + interface_language = get_valid_language(language) account = AccountService.create_account( email=email, name=name, - interface_language=get_valid_language(language), + interface_language=interface_language, password=password, is_setup=is_setup, + timezone=timezone, ) account.status = status or AccountStatus.ACTIVE account.initialized_at = naive_utc_now() diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py index 1fcce9ca44..bb7921a5f4 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py @@ -143,7 +143,118 @@ class TestEmailRegisterResetApi: response = EmailRegisterResetApi().post() assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}} - mock_create_account.assert_called_once_with("invitee@example.com", "ValidPass123!") + mock_create_account.assert_called_once_with( + email="invitee@example.com", + password="ValidPass123!", + timezone=None, + language=None, + ) + mock_reset_login_rate.assert_called_once_with("invitee@example.com") + mock_revoke_token.assert_called_once_with("token-123") + mock_extract_ip.assert_called_once() + + @patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit") + @patch("controllers.console.auth.email_register.AccountService.login") + @patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account") + @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token") + @patch("controllers.console.auth.email_register.AccountService.get_email_register_data") + @patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1") + def test_reset_passes_timezone_to_new_account( + self, + mock_extract_ip, + mock_get_data, + mock_revoke_token, + mock_get_account, + mock_create_account, + mock_login, + mock_reset_login_rate, + app: Flask, + ): + mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"} + mock_create_account.return_value = MagicMock() + token_pair = MagicMock() + token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"} + mock_login.return_value = token_pair + mock_get_account.return_value = None + + feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) + with ( + patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), + ): + with app.test_request_context( + "/email-register", + method="POST", + json={ + "token": "token-123", + "new_password": "ValidPass123!", + "password_confirm": "ValidPass123!", + "timezone": "Asia/Shanghai", + }, + ): + response = EmailRegisterResetApi().post() + + assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}} + mock_create_account.assert_called_once_with( + email="invitee@example.com", + password="ValidPass123!", + timezone="Asia/Shanghai", + language=None, + ) + mock_reset_login_rate.assert_called_once_with("invitee@example.com") + mock_revoke_token.assert_called_once_with("token-123") + mock_extract_ip.assert_called_once() + + @patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit") + @patch("controllers.console.auth.email_register.AccountService.login") + @patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account") + @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token") + @patch("controllers.console.auth.email_register.AccountService.get_email_register_data") + @patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1") + def test_reset_passes_language_to_new_account( + self, + mock_extract_ip, + mock_get_data, + mock_revoke_token, + mock_get_account, + mock_create_account, + mock_login, + mock_reset_login_rate, + app: Flask, + ): + mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"} + mock_create_account.return_value = MagicMock() + token_pair = MagicMock() + token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"} + mock_login.return_value = token_pair + mock_get_account.return_value = None + + feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) + with ( + patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), + ): + with app.test_request_context( + "/email-register", + method="POST", + json={ + "token": "token-123", + "new_password": "ValidPass123!", + "password_confirm": "ValidPass123!", + "language": "zh-Hans", + }, + ): + response = EmailRegisterResetApi().post() + + assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}} + mock_create_account.assert_called_once_with( + email="invitee@example.com", + password="ValidPass123!", + timezone=None, + language="zh-Hans", + ) mock_reset_login_rate.assert_called_once_with("invitee@example.com") mock_revoke_token.assert_called_once_with("token-123") mock_extract_ip.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index 55b6a919d8..a5ae83739c 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -14,7 +14,7 @@ from controllers.console.auth.oauth import ( _get_account_by_openid_or_email, get_oauth_providers, ) -from libs.oauth import OAuthUserInfo +from libs.oauth import OAuthUserInfo, encode_oauth_state from models.account import AccountStatus from services.account_service import AccountService from services.errors.account import AccountRegisterError @@ -101,7 +101,55 @@ class TestOAuthLogin: with app.test_request_context(f"/auth/oauth/github?{query_string}"): resource.get("github") - mock_oauth_provider.get_authorization_url.assert_called_once_with(invite_token=expected_token) + mock_oauth_provider.get_authorization_url.assert_called_once_with( + invite_token=expected_token, + timezone=None, + language=None, + ) + mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...") + + @patch("controllers.console.auth.oauth.get_oauth_providers") + @patch("controllers.console.auth.oauth.redirect") + def test_should_pass_timezone_to_oauth_state( + self, + mock_redirect, + mock_get_providers, + resource, + app: Flask, + mock_oauth_provider, + ): + mock_get_providers.return_value = {"github": mock_oauth_provider, "google": None} + + with app.test_request_context("/auth/oauth/github?timezone=Asia/Shanghai"): + resource.get("github") + + mock_oauth_provider.get_authorization_url.assert_called_once_with( + invite_token=None, + timezone="Asia/Shanghai", + language=None, + ) + mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...") + + @patch("controllers.console.auth.oauth.get_oauth_providers") + @patch("controllers.console.auth.oauth.redirect") + def test_should_pass_language_to_oauth_state( + self, + mock_redirect, + mock_get_providers, + resource, + app: Flask, + mock_oauth_provider, + ): + mock_get_providers.return_value = {"github": mock_oauth_provider, "google": None} + + with app.test_request_context("/auth/oauth/github?language=zh-Hans"): + resource.get("github") + + mock_oauth_provider.get_authorization_url.assert_called_once_with( + invite_token=None, + timezone=None, + language="zh-Hans", + ) mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...") @pytest.mark.parametrize( @@ -229,7 +277,8 @@ class TestOAuthCallback: mock_register_service.is_valid_invite_token.return_value = True mock_register_service.get_invitation_by_token.return_value = {"email": "user@example.com"} - with app.test_request_context("/auth/oauth/github/callback?code=test_code&state=invite123"): + state = encode_oauth_state(invite_token="invite123", timezone="Asia/Shanghai") + with app.test_request_context(f"/auth/oauth/github/callback?code=test_code&state={state}"): resource.get("github") mock_register_service.get_invitation_by_token.assert_called_once_with(token="invite123") @@ -488,7 +537,13 @@ class TestAccountGeneration: if should_create: mock_register_service.register.assert_called_once_with( - email="test@example.com", name="Test User", password=None, open_id="123", provider="github" + email="test@example.com", + name="Test User", + password=None, + open_id="123", + provider="github", + language="en-US", + timezone=None, ) else: mock_register_service.register.assert_not_called() @@ -515,7 +570,75 @@ class TestAccountGeneration: _generate_account("github", user_info) mock_register_service.register.assert_called_once_with( - email="upper@example.com", name="Test User", password=None, open_id="123", provider="github" + email="upper@example.com", + name="Test User", + password=None, + open_id="123", + provider="github", + language="en-US", + timezone=None, + ) + + @patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None) + @patch("controllers.console.auth.oauth.FeatureService") + @patch("controllers.console.auth.oauth.RegisterService") + @patch("controllers.console.auth.oauth.AccountService") + @patch("controllers.console.auth.oauth.TenantService") + def test_should_register_with_browser_timezone( + self, + mock_tenant_service, + mock_account_service, + mock_register_service, + mock_feature_service, + mock_get_account, + app: Flask, + user_info, + ): + mock_feature_service.get_system_features.return_value.is_allow_register = True + mock_register_service.register.return_value = MagicMock() + + with app.test_request_context(headers={"Accept-Language": "zh-Hans,zh;q=0.9"}): + _generate_account("github", user_info, timezone="Asia/Shanghai") + + mock_register_service.register.assert_called_once_with( + email="test@example.com", + name="Test User", + password=None, + open_id="123", + provider="github", + language="zh-Hans", + timezone="Asia/Shanghai", + ) + + @patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None) + @patch("controllers.console.auth.oauth.FeatureService") + @patch("controllers.console.auth.oauth.RegisterService") + @patch("controllers.console.auth.oauth.AccountService") + @patch("controllers.console.auth.oauth.TenantService") + def test_should_register_with_state_language( + self, + mock_tenant_service, + mock_account_service, + mock_register_service, + mock_feature_service, + mock_get_account, + app: Flask, + user_info, + ): + mock_feature_service.get_system_features.return_value.is_allow_register = True + mock_register_service.register.return_value = MagicMock() + + with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}): + _generate_account("github", user_info, language="zh-Hans") + + mock_register_service.register.assert_called_once_with( + email="test@example.com", + name="Test User", + password=None, + open_id="123", + provider="github", + language="zh-Hans", + timezone=None, ) @patch("controllers.console.auth.oauth._get_account_by_openid_or_email") diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_register_language.py b/api/tests/unit_tests/controllers/console/auth/test_email_register_language.py new file mode 100644 index 0000000000..df282880af --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_email_register_language.py @@ -0,0 +1,40 @@ +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from controllers.console.auth.email_register import EmailRegisterResetApi, EmailRegisterResetPayload + + +@patch("controllers.console.auth.email_register.AccountService.create_account_and_tenant") +def test_create_new_account_uses_requested_language(mock_create_account): + account = MagicMock() + mock_create_account.return_value = account + + result = EmailRegisterResetApi()._create_new_account( + "invitee@example.com", + "ValidPass123!", + timezone="Asia/Shanghai", + language="zh-Hans", + ) + + assert result is account + mock_create_account.assert_called_once_with( + email="invitee@example.com", + name="invitee@example.com", + password="ValidPass123!", + interface_language="zh-Hans", + timezone="Asia/Shanghai", + ) + + +def test_reset_payload_rejects_invalid_timezone(): + with pytest.raises(ValidationError): + EmailRegisterResetPayload.model_validate( + { + "token": "token-123", + "new_password": "ValidPass123!", + "password_confirm": "ValidPass123!", + "timezone": "", + } + ) diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py index 102af9b250..fa23942c65 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py +++ b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py @@ -13,9 +13,10 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask +from pydantic import ValidationError from controllers.console.auth.error import EmailCodeError, InvalidEmailError, InvalidTokenError -from controllers.console.auth.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi +from controllers.console.auth.login import EmailCodeLoginApi, EmailCodeLoginPayload, EmailCodeLoginSendEmailApi from controllers.console.error import ( AccountInFreezeError, AccountNotFound, @@ -31,6 +32,18 @@ def encode_code(code: str) -> str: return base64.b64encode(code.encode("utf-8")).decode() +def test_email_code_login_payload_rejects_invalid_timezone(): + with pytest.raises(ValidationError): + EmailCodeLoginPayload.model_validate( + { + "email": "newuser@example.com", + "code": "123456", + "token": "token-123", + "timezone": "", + } + ) + + class TestEmailCodeLoginSendEmailApi: """Test cases for sending email verification codes.""" @@ -342,6 +355,7 @@ class TestEmailCodeLoginApi: "code": encode_code("123456"), "token": "valid_token", "language": "en-US", + "timezone": "Asia/Shanghai", }, ): api = EmailCodeLoginApi() @@ -349,7 +363,12 @@ class TestEmailCodeLoginApi: # Assert assert response.json["result"] == "success" - mock_create_account.assert_called_once() + mock_create_account.assert_called_once_with( + email="newuser@example.com", + name="newuser@example.com", + interface_language="en-US", + timezone="Asia/Shanghai", + ) @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.get_email_code_login_data") diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth_timezone.py b/api/tests/unit_tests/controllers/console/auth/test_oauth_timezone.py new file mode 100644 index 0000000000..36c707dbf9 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth_timezone.py @@ -0,0 +1,123 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.console.auth.oauth import OAuthLogin, _generate_account +from libs.oauth import OAuthUserInfo +from services.errors.account import AccountRegisterError + + +@pytest.fixture +def app() -> Flask: + app = Flask(__name__) + app.config["TESTING"] = True + return app + + +@patch("controllers.console.auth.oauth.redirect") +@patch("controllers.console.auth.oauth.get_oauth_providers") +def test_oauth_login_passes_language_and_timezone_to_authorization_url( + mock_get_oauth_providers, + mock_redirect, + app: Flask, +): + oauth_provider = MagicMock() + oauth_provider.get_authorization_url.return_value = "https://github.com/login/oauth/authorize?state=..." + mock_get_oauth_providers.return_value = {"github": oauth_provider} + + with app.test_request_context("/oauth/login/github?language=zh-Hans&timezone=Asia/Shanghai"): + OAuthLogin().get("github") + + oauth_provider.get_authorization_url.assert_called_once_with( + invite_token=None, + timezone="Asia/Shanghai", + language="zh-Hans", + ) + mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?state=...") + + +@patch("controllers.console.auth.oauth.AccountService.link_account_integrate") +@patch("controllers.console.auth.oauth.RegisterService") +@patch("controllers.console.auth.oauth.FeatureService") +@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None) +def test_generate_account_registers_with_browser_timezone( + mock_get_account, + mock_feature_service, + mock_register_service, + mock_link_account, + app: Flask, +): + account = MagicMock() + mock_register_service.register.return_value = account + mock_feature_service.get_system_features.return_value.is_allow_register = True + user_info = OAuthUserInfo(id="github-123", name="Test User", email="User@Example.com") + + with app.test_request_context(headers={"Accept-Language": "zh-Hans,zh;q=0.9"}): + result, oauth_new_user = _generate_account("github", user_info, timezone="Asia/Shanghai") + + assert result is account + assert oauth_new_user is True + mock_register_service.register.assert_called_once_with( + email="user@example.com", + name="Test User", + password=None, + open_id="github-123", + provider="github", + language="zh-Hans", + timezone="Asia/Shanghai", + ) + mock_link_account.assert_called_once_with("github", "github-123", account) + + +@patch("controllers.console.auth.oauth.AccountService.link_account_integrate") +@patch("controllers.console.auth.oauth.RegisterService") +@patch("controllers.console.auth.oauth.FeatureService") +@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None) +def test_generate_account_prefers_state_language_over_accept_language( + mock_get_account, + mock_feature_service, + mock_register_service, + mock_link_account, + app: Flask, +): + account = MagicMock() + mock_register_service.register.return_value = account + mock_feature_service.get_system_features.return_value.is_allow_register = True + user_info = OAuthUserInfo(id="github-123", name="Test User", email="User@Example.com") + + with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}): + _generate_account("github", user_info, language="zh-Hans") + + mock_register_service.register.assert_called_once_with( + email="user@example.com", + name="Test User", + password=None, + open_id="github-123", + provider="github", + language="zh-Hans", + timezone=None, + ) + mock_link_account.assert_called_once_with("github", "github-123", account) + + +@patch("controllers.console.auth.oauth.dify_config") +@patch("controllers.console.auth.oauth.RegisterService") +@patch("controllers.console.auth.oauth.FeatureService") +@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None) +def test_generate_account_rejects_new_user_when_registration_disabled( + mock_get_account, + mock_feature_service, + mock_register_service, + mock_config, + app: Flask, +): + mock_feature_service.get_system_features.return_value.is_allow_register = False + mock_config.BILLING_ENABLED = False + user_info = OAuthUserInfo(id="github-123", name="Test User", email="user@example.com") + + with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}): + with pytest.raises(AccountRegisterError): + _generate_account("github", user_info) + + mock_register_service.register.assert_not_called() diff --git a/api/tests/unit_tests/libs/test_oauth_base.py b/api/tests/unit_tests/libs/test_oauth_base.py index 7b7f086dac..1c0066ed9a 100644 --- a/api/tests/unit_tests/libs/test_oauth_base.py +++ b/api/tests/unit_tests/libs/test_oauth_base.py @@ -1,6 +1,6 @@ import pytest -from libs.oauth import OAuth +from libs.oauth import OAuth, decode_oauth_state, encode_oauth_state def test_oauth_base_methods_raise_not_implemented(): @@ -17,3 +17,17 @@ def test_oauth_base_methods_raise_not_implemented(): with pytest.raises(NotImplementedError): oauth._transform_user_info({}) + + +def test_oauth_state_round_trips_invite_token_timezone_and_language(): + state = encode_oauth_state(invite_token="invite-123", timezone="Asia/Shanghai", language="zh-Hans") + + assert decode_oauth_state(state) == { + "invite_token": "invite-123", + "timezone": "Asia/Shanghai", + "language": "zh-Hans", + } + + +def test_oauth_state_returns_empty_payload_for_invalid_state(): + assert decode_oauth_state("invalid-state") == {} diff --git a/api/tests/unit_tests/libs/test_oauth_clients.py b/api/tests/unit_tests/libs/test_oauth_clients.py index 830284e697..b3ecc5a06d 100644 --- a/api/tests/unit_tests/libs/test_oauth_clients.py +++ b/api/tests/unit_tests/libs/test_oauth_clients.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch import httpx import pytest -from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo +from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo, decode_oauth_state class BaseOAuthTest: @@ -37,15 +37,25 @@ class TestGitHubOAuth(BaseOAuthTest): return GitHubOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"]) @pytest.mark.parametrize( - ("invite_token", "expected_state"), + ("invite_token", "timezone", "language", "expected_state"), [ - (None, None), - ("test_invite_token", "test_invite_token"), - ("", None), + (None, None, None, None), + ("test_invite_token", None, None, {"invite_token": "test_invite_token"}), + ("", None, None, None), + (None, "Asia/Shanghai", None, {"timezone": "Asia/Shanghai"}), + (None, None, "zh-Hans", {"language": "zh-Hans"}), + ( + "test_invite_token", + "Asia/Shanghai", + "zh-Hans", + {"invite_token": "test_invite_token", "timezone": "Asia/Shanghai", "language": "zh-Hans"}, + ), ], ) - def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state): - url = oauth.get_authorization_url(invite_token) + def test_should_generate_authorization_url_correctly( + self, oauth, oauth_config, invite_token, timezone, language, expected_state + ): + url = oauth.get_authorization_url(invite_token, timezone=timezone, language=language) parsed, params = self.parse_auth_url(url) assert parsed.scheme == "https" @@ -56,7 +66,7 @@ class TestGitHubOAuth(BaseOAuthTest): assert params["scope"][0] == "user:email" if expected_state: - assert params["state"][0] == expected_state + assert decode_oauth_state(params["state"][0]) == expected_state else: assert "state" not in params @@ -208,15 +218,25 @@ class TestGoogleOAuth(BaseOAuthTest): return GoogleOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"]) @pytest.mark.parametrize( - ("invite_token", "expected_state"), + ("invite_token", "timezone", "language", "expected_state"), [ - (None, None), - ("test_invite_token", "test_invite_token"), - ("", None), + (None, None, None, None), + ("test_invite_token", None, None, {"invite_token": "test_invite_token"}), + ("", None, None, None), + (None, "Asia/Shanghai", None, {"timezone": "Asia/Shanghai"}), + (None, None, "zh-Hans", {"language": "zh-Hans"}), + ( + "test_invite_token", + "Asia/Shanghai", + "zh-Hans", + {"invite_token": "test_invite_token", "timezone": "Asia/Shanghai", "language": "zh-Hans"}, + ), ], ) - def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state): - url = oauth.get_authorization_url(invite_token) + def test_should_generate_authorization_url_correctly( + self, oauth, oauth_config, invite_token, timezone, language, expected_state + ): + url = oauth.get_authorization_url(invite_token, timezone=timezone, language=language) parsed, params = self.parse_auth_url(url) assert parsed.scheme == "https" @@ -228,7 +248,7 @@ class TestGoogleOAuth(BaseOAuthTest): assert params["scope"][0] == "openid email" if expected_state: - assert params["state"][0] == expected_state + assert decode_oauth_state(params["state"][0]) == expected_state else: assert "state" not in params diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 02013392fc..8c554e012d 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -260,7 +260,7 @@ class TestAccountService: assert result.interface_theme == "light" assert result.password is not None assert result.password_salt is not None - assert result.timezone is not None + assert result.timezone == "America/New_York" # Verify database operations mock_db_dependencies["db"].session.add.assert_called_once() @@ -271,7 +271,28 @@ class TestAccountService: assert added_account.interface_theme == "light" assert added_account.password is not None assert added_account.password_salt is not None - assert added_account.timezone is not None + assert added_account.timezone == "America/New_York" + self._assert_database_operations_called(mock_db_dependencies["db"]) + + def test_create_account_uses_explicit_timezone( + self, mock_db_dependencies, mock_password_dependencies, mock_external_service_dependencies + ): + """Test account creation prefers explicit browser timezone.""" + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_password_dependencies["hash_password"].return_value = b"hashed_password" + + result = AccountService.create_account( + email="test@example.com", + name="Test User", + interface_language="en-US", + password="password123", + timezone="Asia/Shanghai", + ) + + assert result.timezone == "Asia/Shanghai" + added_account = mock_db_dependencies["db"].session.add.call_args[0][0] + assert added_account.timezone == "Asia/Shanghai" self._assert_database_operations_called(mock_db_dependencies["db"]) def test_create_account_registration_disabled(self, mock_external_service_dependencies): @@ -1221,6 +1242,7 @@ class TestRegisterService: interface_language="en-US", password="password123", is_setup=False, + timezone=None, ) mock_create_tenant.assert_called_once_with("Test User's Workspace") mock_create_member.assert_called_once_with(mock_tenant, mock_account, role="owner") diff --git a/web/app/signin/check-code/page.tsx b/web/app/signin/check-code/page.tsx index 42024c561b..c90f426fe2 100644 --- a/web/app/signin/check-code/page.tsx +++ b/web/app/signin/check-code/page.tsx @@ -13,6 +13,7 @@ import { useLocale } from '@/context/i18n' import { useRouter, useSearchParams } from '@/next/navigation' import { emailLoginWithCode, sendEMailLoginCode } from '@/service/common' import { encryptVerificationCode } from '@/utils/encryption' +import { getBrowserTimezone } from '@/utils/timezone' import { resolvePostLoginRedirect } from '../utils/post-login-redirect' export default function CheckCode() { @@ -39,7 +40,13 @@ export default function CheckCode() { return } setIsLoading(true) - const ret = await emailLoginWithCode({ email, code: encryptVerificationCode(code), token, language }) + const ret = await emailLoginWithCode({ + email, + code: encryptVerificationCode(code), + token, + language, + timezone: getBrowserTimezone(), + }) if (ret.result === 'success') { // Track login success event trackEvent('user_login_success', { diff --git a/web/app/signin/components/__tests__/social-auth.spec.tsx b/web/app/signin/components/__tests__/social-auth.spec.tsx new file mode 100644 index 0000000000..bf1183f324 --- /dev/null +++ b/web/app/signin/components/__tests__/social-auth.spec.tsx @@ -0,0 +1,86 @@ +import { render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { useLocale } from '@/context/i18n' +import { useSearchParams } from '@/next/navigation' +import { getBrowserTimezone } from '@/utils/timezone' +import SocialAuth from '../social-auth' + +vi.mock('@/next/navigation', () => ({ + useSearchParams: vi.fn(), +})) + +vi.mock('@/context/i18n', () => ({ + useLocale: vi.fn(), +})) + +vi.mock('@/utils/timezone', () => ({ + getBrowserTimezone: vi.fn(), +})) + +const mockUseSearchParams = vi.mocked(useSearchParams) +const mockUseLocale = vi.mocked(useLocale) +const mockGetBrowserTimezone = vi.mocked(getBrowserTimezone) + +describe('SocialAuth', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseSearchParams.mockReturnValue(new URLSearchParams() as unknown as ReturnType) + mockUseLocale.mockReturnValue('zh-Hans') + mockGetBrowserTimezone.mockReturnValue('Asia/Shanghai') + }) + + describe('Rendering', () => { + it('should render oauth provider links', () => { + render() + + expect(screen.getByRole('link', { name: 'login.withGitHub' })).toBeInTheDocument() + expect(screen.getByRole('link', { name: 'login.withGoogle' })).toBeInTheDocument() + }) + }) + + describe('OAuth params', () => { + it('should include browser timezone and locale in oauth links', () => { + render() + + expect(screen.getByRole('link', { name: 'login.withGitHub' })).toHaveAttribute( + 'href', + expect.stringContaining('timezone=Asia%2FShanghai'), + ) + expect(screen.getByRole('link', { name: 'login.withGitHub' })).toHaveAttribute( + 'href', + expect.stringContaining('language=zh-Hans'), + ) + expect(screen.getByRole('link', { name: 'login.withGoogle' })).toHaveAttribute( + 'href', + expect.stringContaining('timezone=Asia%2FShanghai'), + ) + expect(screen.getByRole('link', { name: 'login.withGoogle' })).toHaveAttribute( + 'href', + expect.stringContaining('language=zh-Hans'), + ) + }) + + it('should preserve invite token when adding timezone', () => { + mockUseSearchParams.mockReturnValue( + new URLSearchParams('invite_token=invite-123') as unknown as ReturnType, + ) + + render() + + const githubLink = screen.getByRole('link', { name: 'login.withGitHub' }) + expect(githubLink).toHaveAttribute('href', expect.stringContaining('invite_token=invite-123')) + expect(githubLink).toHaveAttribute('href', expect.stringContaining('timezone=Asia%2FShanghai')) + expect(githubLink).toHaveAttribute('href', expect.stringContaining('language=zh-Hans')) + }) + }) + + describe('Edge Cases', () => { + it('should omit timezone when browser timezone is unavailable', () => { + mockGetBrowserTimezone.mockReturnValue(undefined) + + render() + + expect(screen.getByRole('link', { name: 'login.withGitHub' }).getAttribute('href')).not.toContain('timezone=') + }) + }) +}) diff --git a/web/app/signin/components/social-auth.tsx b/web/app/signin/components/social-auth.tsx index 09fa528d1f..17455cebf4 100644 --- a/web/app/signin/components/social-auth.tsx +++ b/web/app/signin/components/social-auth.tsx @@ -2,8 +2,10 @@ import { Button } from '@langgenius/dify-ui/button' import { cn } from '@langgenius/dify-ui/cn' import { useTranslation } from 'react-i18next' import { API_PREFIX } from '@/config' +import { useLocale } from '@/context/i18n' import { useSearchParams } from '@/next/navigation' import { getPurifyHref } from '@/utils' +import { getBrowserTimezone } from '@/utils/timezone' import style from '../page.module.css' type SocialAuthProps = { @@ -13,11 +15,19 @@ type SocialAuthProps = { export default function SocialAuth(props: SocialAuthProps) { const { t } = useTranslation() const searchParams = useSearchParams() + const locale = useLocale() const getOAuthLink = (href: string) => { const url = getPurifyHref(`${API_PREFIX}${href}`) - if (searchParams.has('invite_token')) - return `${url}?${searchParams.toString()}` + const params = new URLSearchParams(searchParams.toString()) + const timezone = getBrowserTimezone() + if (timezone) + params.set('timezone', timezone) + params.set('language', locale) + + const query = params.toString() + if (query) + return `${url}?${query}` return url } diff --git a/web/app/signin/invite-settings/__tests__/page.spec.tsx b/web/app/signin/invite-settings/__tests__/page.spec.tsx new file mode 100644 index 0000000000..12e6f4dfb7 --- /dev/null +++ b/web/app/signin/invite-settings/__tests__/page.spec.tsx @@ -0,0 +1,139 @@ +import type { MockedFunction } from 'vitest' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { useLocale } from '@/context/i18n' +import { useRouter, useSearchParams } from '@/next/navigation' +import { activateMember } from '@/service/common' +import { useInvitationCheck } from '@/service/use-common' +import { getBrowserTimezone } from '@/utils/timezone' +import InviteSettingsPage from '../page' + +vi.mock('@tanstack/react-query', async () => { + const actual = await vi.importActual('@tanstack/react-query') + return { + ...actual, + useSuspenseQuery: vi.fn(() => ({ + data: { + branding: { + enabled: true, + }, + }, + })), + } +}) + +vi.mock('@/context/i18n', () => ({ + useLocale: vi.fn(), +})) + +vi.mock('@/i18n-config', () => ({ + i18n: { + defaultLocale: 'en-US', + }, + setLocaleOnClient: vi.fn(() => Promise.resolve()), +})) + +vi.mock('@/next/navigation', () => ({ + useRouter: vi.fn(), + useSearchParams: vi.fn(), +})) + +vi.mock('@/service/common', () => ({ + activateMember: vi.fn(), +})) + +vi.mock('@/service/use-common', () => ({ + useInvitationCheck: vi.fn(), +})) + +vi.mock('@/utils/timezone', () => ({ + getBrowserTimezone: vi.fn(), + timezones: [ + { value: 'Asia/Shanghai', name: 'Asia/Shanghai' }, + { value: 'America/Los_Angeles', name: 'America/Los_Angeles' }, + ], +})) + +vi.mock('../utils/post-login-redirect', () => ({ + resolvePostLoginRedirect: vi.fn(() => null), +})) + +const mockReplace = vi.fn() +const mockRefetch = vi.fn() + +const mockUseLocale = useLocale as unknown as MockedFunction +const mockUseRouter = useRouter as unknown as MockedFunction +const mockUseSearchParams = useSearchParams as unknown as MockedFunction +const mockActivateMember = activateMember as unknown as MockedFunction +const mockUseInvitationCheck = useInvitationCheck as unknown as MockedFunction +const mockGetBrowserTimezone = getBrowserTimezone as unknown as MockedFunction + +describe('InviteSettingsPage', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseLocale.mockReturnValue('zh-Hans') + mockUseRouter.mockReturnValue({ replace: mockReplace } as unknown as ReturnType) + mockUseSearchParams.mockReturnValue( + new URLSearchParams('invite_token=invite-token') as unknown as ReturnType, + ) + mockUseInvitationCheck.mockReturnValue({ + data: { + is_valid: true, + data: { + workspace_name: 'Acme', + workspace_id: 'workspace-id', + email: 'invitee@example.com', + }, + }, + refetch: mockRefetch, + } as unknown as ReturnType) + mockGetBrowserTimezone.mockReturnValue('Asia/Shanghai') + mockActivateMember.mockResolvedValue({ result: 'success' }) + }) + + describe('Activation payload', () => { + it('should default language to the current UI locale', async () => { + render() + + fireEvent.change(screen.getByLabelText('login.name'), { + target: { value: 'Invitee' }, + }) + fireEvent.click(screen.getByRole('button', { name: 'login.join Acme' })) + + await waitFor(() => { + expect(mockActivateMember).toHaveBeenCalledWith({ + url: '/activate', + body: { + token: 'invite-token', + name: 'Invitee', + interface_language: 'zh-Hans', + timezone: 'Asia/Shanghai', + }, + }) + }) + }) + + it('should fall back to configured default locale when current locale is unsupported', async () => { + mockUseLocale.mockReturnValue('unsupported-locale' as ReturnType) + + render() + + fireEvent.change(screen.getByLabelText('login.name'), { + target: { value: 'Invitee' }, + }) + fireEvent.click(screen.getByRole('button', { name: 'login.join Acme' })) + + await waitFor(() => { + expect(mockActivateMember).toHaveBeenCalledWith({ + url: '/activate', + body: { + token: 'invite-token', + name: 'Invitee', + interface_language: 'en-US', + timezone: 'Asia/Shanghai', + }, + }) + }) + }) + }) +}) diff --git a/web/app/signin/invite-settings/page.tsx b/web/app/signin/invite-settings/page.tsx index af0ff5e07a..ddde11b940 100644 --- a/web/app/signin/invite-settings/page.tsx +++ b/web/app/signin/invite-settings/page.tsx @@ -11,14 +11,15 @@ import { useTranslation } from 'react-i18next' import Input from '@/app/components/base/input' import Loading from '@/app/components/base/loading' import { LICENSE_LINK } from '@/constants/link' -import { setLocaleOnClient } from '@/i18n-config' -import { languages, LanguagesSupported } from '@/i18n-config/language' +import { useLocale } from '@/context/i18n' +import { i18n, setLocaleOnClient } from '@/i18n-config' +import { languages } from '@/i18n-config/language' import Link from '@/next/link' import { useRouter, useSearchParams } from '@/next/navigation' import { activateMember } from '@/service/common' import { systemFeaturesQueryOptions } from '@/service/system-features' import { useInvitationCheck } from '@/service/use-common' -import { timezones } from '@/utils/timezone' +import { getBrowserTimezone, timezones } from '@/utils/timezone' import { resolvePostLoginRedirect } from '../utils/post-login-redirect' type LanguageSelectOption = { @@ -43,15 +44,23 @@ const TIMEZONE_OPTIONS: TimezoneSelectOption[] = timezones.map(item => ({ name: item.name, })) +const getInitialLanguage = (locale: Locale): Locale => { + if (LANGUAGE_OPTIONS.some(item => item.value === locale)) + return locale + + return i18n.defaultLocale +} + export default function InviteSettingsPage() { const { t } = useTranslation() const { data: systemFeatures } = useSuspenseQuery(systemFeaturesQueryOptions()) const router = useRouter() const searchParams = useSearchParams() const token = decodeURIComponent(searchParams.get('invite_token') as string) + const locale = useLocale() const [name, setName] = useState('') - const [language, setLanguage] = useState(LanguagesSupported[0]) - const [timezone, setTimezone] = useState(() => Intl.DateTimeFormat().resolvedOptions().timeZone || 'America/Los_Angeles') + const [language, setLanguage] = useState(() => getInitialLanguage(locale)) + const [timezone, setTimezone] = useState(() => getBrowserTimezone() || 'America/Los_Angeles') const selectedLanguage = LANGUAGE_OPTIONS.find(item => item.value === language) const selectedTimezone = TIMEZONE_OPTIONS.find(item => item.value === timezone) diff --git a/web/app/signup/set-password/__tests__/page.spec.tsx b/web/app/signup/set-password/__tests__/page.spec.tsx new file mode 100644 index 0000000000..e68694db15 --- /dev/null +++ b/web/app/signup/set-password/__tests__/page.spec.tsx @@ -0,0 +1,85 @@ +import type { MockedFunction } from 'vitest' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { useLocale } from '@/context/i18n' +import { useRouter, useSearchParams } from '@/next/navigation' +import { useMailRegister } from '@/service/use-common' +import { getBrowserTimezone } from '@/utils/timezone' +import ChangePasswordForm from '../page' + +vi.mock('@/context/i18n', () => ({ + useLocale: vi.fn(), +})) + +vi.mock('@/next/navigation', () => ({ + useRouter: vi.fn(), + useSearchParams: vi.fn(), +})) + +vi.mock('@/service/use-common', () => ({ + useMailRegister: vi.fn(), +})) + +vi.mock('@/utils/timezone', () => ({ + getBrowserTimezone: vi.fn(), +})) + +vi.mock('@/utils/gtag', () => ({ + sendGAEvent: vi.fn(), +})) + +vi.mock('@/app/components/base/amplitude', () => ({ + trackEvent: vi.fn(), +})) + +vi.mock('@/utils/create-app-tracking', () => ({ + rememberCreateAppExternalAttribution: vi.fn(), +})) + +const mockRegister = vi.fn() +const mockReplace = vi.fn() + +const mockUseLocale = useLocale as unknown as MockedFunction +const mockUseSearchParams = useSearchParams as unknown as MockedFunction +const mockUseRouter = useRouter as unknown as MockedFunction +const mockUseMailRegister = useMailRegister as unknown as MockedFunction +const mockGetBrowserTimezone = getBrowserTimezone as unknown as MockedFunction + +describe('Signup Set Password Page', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseLocale.mockReturnValue('zh-Hans') + mockUseSearchParams.mockReturnValue(new URLSearchParams('token=register-token') as unknown as ReturnType) + mockUseRouter.mockReturnValue({ replace: mockReplace } as unknown as ReturnType) + mockUseMailRegister.mockReturnValue({ + mutateAsync: mockRegister, + isPending: false, + } as unknown as ReturnType) + mockGetBrowserTimezone.mockReturnValue('Asia/Shanghai') + mockRegister.mockResolvedValue({ result: 'fail', data: {} }) + }) + + describe('Registration payload', () => { + it('should submit locale and browser timezone when setting password', async () => { + render() + + fireEvent.change(screen.getByLabelText('common.account.newPassword'), { + target: { value: 'ValidPass123!' }, + }) + fireEvent.change(screen.getByLabelText('common.account.confirmPassword'), { + target: { value: 'ValidPass123!' }, + }) + fireEvent.click(screen.getByRole('button', { name: 'login.changePasswordBtn' })) + + await waitFor(() => { + expect(mockRegister).toHaveBeenCalledWith({ + token: 'register-token', + new_password: 'ValidPass123!', + password_confirm: 'ValidPass123!', + language: 'zh-Hans', + timezone: 'Asia/Shanghai', + }) + }) + }) + }) +}) diff --git a/web/app/signup/set-password/page.tsx b/web/app/signup/set-password/page.tsx index a8eb883078..534b6b55d4 100644 --- a/web/app/signup/set-password/page.tsx +++ b/web/app/signup/set-password/page.tsx @@ -9,10 +9,12 @@ import { useTranslation } from 'react-i18next' import { trackEvent } from '@/app/components/base/amplitude' import Input from '@/app/components/base/input' import { validPassword } from '@/config' +import { useLocale } from '@/context/i18n' import { useRouter, useSearchParams } from '@/next/navigation' import { useMailRegister } from '@/service/use-common' import { rememberCreateAppExternalAttribution } from '@/utils/create-app-tracking' import { sendGAEvent } from '@/utils/gtag' +import { getBrowserTimezone } from '@/utils/timezone' const parseUtmInfo = () => { const utmInfoStr = Cookies.get('utm_info') @@ -32,6 +34,7 @@ const ChangePasswordForm = () => { const router = useRouter() const searchParams = useSearchParams() const token = decodeURIComponent(searchParams.get('token') || '') + const locale = useLocale() const [password, setPassword] = useState('') const [confirmPassword, setConfirmPassword] = useState('') @@ -65,6 +68,8 @@ const ChangePasswordForm = () => { token, new_password: password, password_confirm: confirmPassword, + language: locale, + timezone: getBrowserTimezone(), }) const { result } = res as MailRegisterResponse if (result === 'success') { @@ -88,7 +93,7 @@ const ChangePasswordForm = () => { catch (error) { console.error(error) } - }, [password, token, valid, confirmPassword, register]) + }, [password, token, valid, confirmPassword, register, locale]) return (
=> post('/email-code-login', { body: { email, language } }) -export const emailLoginWithCode = (data: { email: string, code: string, token: string, language: string }): Promise => +export const emailLoginWithCode = (data: { + email: string + code: string + token: string + language: string + timezone?: string +}): Promise => post('/email-code-login/validity', { body: data }) export const sendResetPasswordCode = (email: string, language = 'en-US'): Promise => diff --git a/web/service/use-common.ts b/web/service/use-common.ts index 0154be09ff..9d17585bfa 100644 --- a/web/service/use-common.ts +++ b/web/service/use-common.ts @@ -178,7 +178,13 @@ export type MailRegisterResponse = { result: string, data: {} } export const useMailRegister = () => { return useMutation({ mutationKey: [NAME_SPACE, 'mail-register'], - mutationFn: (body: { token: string, new_password: string, password_confirm: string }) => { + mutationFn: (body: { + token: string + new_password: string + password_confirm: string + language?: string + timezone?: string + }) => { return post('/email-register', { body }) }, }) diff --git a/web/utils/timezone.ts b/web/utils/timezone.ts index e854ae7d5a..38e2e04d50 100644 --- a/web/utils/timezone.ts +++ b/web/utils/timezone.ts @@ -5,3 +5,10 @@ type Item = { name: string } export const timezones: Item[] = tz + +export const getBrowserTimezone = () => { + if (typeof Intl === 'undefined') + return undefined + + return Intl.DateTimeFormat().resolvedOptions().timeZone || undefined +}