feat: initialize user timezone and language from browser (#36170)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
yyh
2026-05-15 16:12:52 +08:00
committed by GitHub
parent ff02636a4b
commit a252fbddfa
24 changed files with 1050 additions and 87 deletions

View File

@@ -3,7 +3,7 @@ from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from configs import dify_config
from constants.languages import languages
from constants.languages import get_valid_language, languages
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.auth.error import (
@@ -15,11 +15,12 @@ from controllers.console.auth.error import (
PasswordMismatchError,
)
from libs.helper import EmailStr, extract_remote_ip
from libs.helper import timezone as validate_timezone_string
from libs.password import valid_password
from models import Account
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError
from services.errors.account import AccountRegisterError
from ..error import AccountInFreezeError, EmailSendIpLimitError
from ..wraps import email_password_login_enabled, email_register_enabled, setup_required
@@ -40,12 +41,21 @@ class EmailRegisterResetPayload(BaseModel):
token: str = Field(...)
new_password: str = Field(...)
password_confirm: str = Field(...)
language: str | None = Field(default=None)
timezone: str | None = Field(default=None)
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
@field_validator("timezone")
@classmethod
def validate_timezone(cls, value: str | None) -> str | None:
if value is None:
return None
return validate_timezone_string(value)
register_schema_models(console_ns, EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload)
@@ -144,26 +154,32 @@ class EmailRegisterResetApi(Resource):
if account:
raise EmailAlreadyInUseError()
else:
account = self._create_new_account(normalized_email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(normalized_email)
account = self._create_new_account(
email=normalized_email,
password=args.password_confirm,
timezone=args.timezone,
language=args.language,
)
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(normalized_email)
return {"result": "success", "data": token_pair.model_dump()}
def _create_new_account(self, email: str, password: str) -> Account | None:
# Create new account if allowed
account = None
def _create_new_account(
self,
email: str,
password: str,
timezone: str | None = None,
language: str | None = None,
) -> Account:
try:
account = AccountService.create_account_and_tenant(
return AccountService.create_account_and_tenant(
email=email,
name=email,
password=password,
interface_language=languages[0],
interface_language=get_valid_language(language),
timezone=timezone,
)
except AccountRegisterError:
raise AccountInFreezeError()
return account

View File

@@ -3,7 +3,7 @@ import logging
import flask_login
from flask import make_response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import Unauthorized
import services
@@ -34,6 +34,7 @@ from controllers.console.wraps import (
)
from events.tenant_event import tenant_was_created
from libs.helper import EmailStr, extract_remote_ip
from libs.helper import timezone as validate_timezone_string
from libs.login import current_account_with_tenant
from libs.token import (
clear_access_token_from_cookie,
@@ -69,6 +70,14 @@ class EmailCodeLoginPayload(BaseModel):
code: str = Field(...)
token: str = Field(...)
language: str | None = Field(default=None)
timezone: str | None = Field(default=None)
@field_validator("timezone")
@classmethod
def validate_timezone(cls, value: str | None) -> str | None:
if value is None:
return None
return validate_timezone_string(value)
register_schema_models(console_ns, LoginPayload, EmailPayload, EmailCodeLoginPayload)
@@ -288,6 +297,7 @@ class EmailCodeLoginApi(Resource):
email=user_email,
name=user_email,
interface_language=get_valid_language(language),
timezone=args.timezone,
)
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()

View File

@@ -12,7 +12,8 @@ from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from libs.helper import timezone as validate_timezone_string
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo, decode_oauth_state
from libs.token import (
set_access_token_to_cookie,
set_csrf_token_to_cookie,
@@ -53,6 +54,31 @@ def get_oauth_providers():
return OAUTH_PROVIDERS
def _validated_timezone(value: str | None) -> str | None:
if not value:
return None
try:
return validate_timezone_string(value)
except ValueError:
return None
def _validated_language(value: str | None) -> str | None:
if value and value in languages:
return value
return None
def _preferred_interface_language(language: str | None = None) -> str:
if language:
return language
preferred_lang = request.accept_languages.best_match(languages)
if preferred_lang and preferred_lang in languages:
return preferred_lang
return languages[0]
@console_ns.route("/oauth/login/<provider>")
class OAuthLogin(Resource):
@console_ns.doc("oauth_login")
@@ -64,13 +90,19 @@ class OAuthLogin(Resource):
@console_ns.response(400, "Invalid provider")
def get(self, provider: str):
invite_token = request.args.get("invite_token") or None
timezone = _validated_timezone(request.args.get("timezone") or None)
language = _validated_language(request.args.get("language") or None)
OAUTH_PROVIDERS = get_oauth_providers()
with current_app.app_context():
oauth_provider = OAUTH_PROVIDERS.get(provider)
if not oauth_provider:
return {"error": "Invalid provider"}, 400
auth_url = oauth_provider.get_authorization_url(invite_token=invite_token)
auth_url = oauth_provider.get_authorization_url(
invite_token=invite_token,
timezone=timezone,
language=language,
)
return redirect(auth_url)
@@ -96,9 +128,10 @@ class OAuthCallback(Resource):
code = request.args.get("code")
state = request.args.get("state")
invite_token = None
if state:
invite_token = state
oauth_state = decode_oauth_state(state)
invite_token = oauth_state.get("invite_token")
timezone = _validated_timezone(oauth_state.get("timezone"))
language = _validated_language(oauth_state.get("language"))
if not code:
return {"error": "Authorization code is required"}, 400
@@ -129,7 +162,7 @@ class OAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
try:
account, oauth_new_user = _generate_account(provider, user_info)
account, oauth_new_user = _generate_account(provider, user_info, timezone=timezone, language=language)
except AccountNotFoundError:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
@@ -184,7 +217,12 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
return account
def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, bool]:
def _generate_account(
provider: str,
user_info: OAuthUserInfo,
timezone: str | None = None,
language: str | None = None,
) -> tuple[Account, bool]:
# Get account by openid or email.
account = _get_account_by_openid_or_email(provider, user_info)
oauth_new_user = False
@@ -211,26 +249,19 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
"30 days and is temporarily unavailable for new account registration"
)
)
else:
raise AccountRegisterError(description=("Invalid email or password"))
raise AccountRegisterError(description=("Invalid email or password"))
account_name = user_info.name or "Dify"
interface_language = _preferred_interface_language(language)
account = RegisterService.register(
email=normalized_email,
name=account_name,
password=None,
open_id=user_info.id,
provider=provider,
language=interface_language,
timezone=timezone,
)
# Set interface language
preferred_lang = request.accept_languages.best_match(languages)
if preferred_lang and preferred_lang in languages:
interface_language = preferred_lang
else:
interface_language = languages[0]
account.interface_language = interface_language
db.session.commit()
# Link account
AccountService.link_account_integrate(provider, user_info.id, account)

View File

@@ -1,3 +1,6 @@
import base64
import binascii
import json
import logging
import urllib.parse
from dataclasses import dataclass
@@ -27,6 +30,12 @@ class AccessTokenResponse(TypedDict, total=False):
access_token: str
class OAuthState(TypedDict, total=False):
invite_token: str
timezone: str
language: str
class GitHubEmailRecord(TypedDict, total=False):
email: str
primary: bool
@@ -46,6 +55,7 @@ class GoogleRawUserInfo(TypedDict):
ACCESS_TOKEN_RESPONSE_ADAPTER = TypeAdapter(AccessTokenResponse)
OAUTH_STATE_ADAPTER = TypeAdapter(OAuthState)
GITHUB_RAW_USER_INFO_ADAPTER = TypeAdapter(GitHubRawUserInfo)
GITHUB_EMAIL_RECORDS_ADAPTER = TypeAdapter(list[GitHubEmailRecord])
GOOGLE_RAW_USER_INFO_ADAPTER = TypeAdapter(GoogleRawUserInfo)
@@ -58,6 +68,37 @@ class OAuthUserInfo:
email: str
def encode_oauth_state(
invite_token: str | None = None,
timezone: str | None = None,
language: str | None = None,
) -> str | None:
state: OAuthState = {}
if invite_token:
state["invite_token"] = invite_token
if timezone:
state["timezone"] = timezone
if language:
state["language"] = language
if not state:
return None
raw_state = json.dumps(state, separators=(",", ":")).encode("utf-8")
return base64.urlsafe_b64encode(raw_state).decode("ascii").rstrip("=")
def decode_oauth_state(state: str | None) -> OAuthState:
if not state:
return {}
try:
padded_state = state + "=" * (-len(state) % 4)
raw_state = base64.urlsafe_b64decode(padded_state.encode("ascii")).decode("utf-8")
return OAUTH_STATE_ADAPTER.validate_python(json.loads(raw_state))
except (binascii.Error, ValueError, UnicodeDecodeError, json.JSONDecodeError, ValidationError):
return {}
def _json_object(response: httpx.Response) -> JsonObject:
return JSON_OBJECT_ADAPTER.validate_python(response.json())
@@ -76,7 +117,12 @@ class OAuth:
self.client_secret = client_secret
self.redirect_uri = redirect_uri
def get_authorization_url(self, invite_token: str | None = None) -> str:
def get_authorization_url(
self,
invite_token: str | None = None,
timezone: str | None = None,
language: str | None = None,
) -> str:
raise NotImplementedError()
def get_access_token(self, code: str) -> str:
@@ -99,14 +145,20 @@ class GitHubOAuth(OAuth):
_USER_INFO_URL = "https://api.github.com/user"
_EMAIL_INFO_URL = "https://api.github.com/user/emails"
def get_authorization_url(self, invite_token: str | None = None) -> str:
def get_authorization_url(
self,
invite_token: str | None = None,
timezone: str | None = None,
language: str | None = None,
) -> str:
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": "user:email", # Request only basic user information
}
if invite_token:
params["state"] = invite_token
state = encode_oauth_state(invite_token=invite_token, timezone=timezone, language=language)
if state:
params["state"] = state
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str) -> str:
@@ -186,15 +238,21 @@ class GoogleOAuth(OAuth):
_TOKEN_URL = "https://oauth2.googleapis.com/token"
_USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
def get_authorization_url(self, invite_token: str | None = None) -> str:
def get_authorization_url(
self,
invite_token: str | None = None,
timezone: str | None = None,
language: str | None = None,
) -> str:
params = {
"client_id": self.client_id,
"response_type": "code",
"redirect_uri": self.redirect_uri,
"scope": "openid email",
}
if invite_token:
params["state"] = invite_token
state = encode_oauth_state(invite_token=invite_token, timezone=timezone, language=language)
if state:
params["state"] = state
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str) -> str:

View File

@@ -11596,6 +11596,7 @@ Request payload for bulk downloading documents as a zip archive.
| code | string | | Yes |
| email | string | | Yes |
| language | | | No |
| timezone | | | No |
| token | string | | Yes |
#### EmailPayload
@@ -11609,8 +11610,10 @@ Request payload for bulk downloading documents as a zip archive.
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| language | | | No |
| new_password | string | | Yes |
| password_confirm | string | | Yes |
| timezone | | | No |
| token | string | | Yes |
#### EmailRegisterSendPayload

View File

@@ -29,6 +29,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client, redis_fallback
from libs.datetime_utils import naive_utc_now
from libs.helper import RateLimiter, TokenManager
from libs.helper import timezone as validate_timezone
from libs.passport import PassportService
from libs.password import compare_password, hash_password, valid_password
from libs.rsa import generate_key_pair
@@ -271,8 +272,9 @@ class AccountService:
password: str | None = None,
interface_theme: str = "light",
is_setup: bool | None = False,
timezone: str | None = None,
) -> Account:
"""create account"""
"""Create an account, preferring explicit user timezone over language-derived defaults."""
if not FeatureService.get_system_features().is_allow_register and not is_setup:
from controllers.console.error import AccountNotFound
@@ -302,6 +304,10 @@ class AccountService:
password_to_set = base64_password_hashed
salt_to_set = base64_salt
resolved_timezone = language_timezone_mapping.get(interface_language, "UTC")
if timezone is not None:
resolved_timezone = validate_timezone(timezone)
account = Account(
name=name,
email=email,
@@ -309,7 +315,7 @@ class AccountService:
password_salt=salt_to_set,
interface_language=interface_language,
interface_theme=interface_theme,
timezone=language_timezone_mapping.get(interface_language, "UTC"),
timezone=resolved_timezone,
)
db.session.add(account)
@@ -318,11 +324,15 @@ class AccountService:
@staticmethod
def create_account_and_tenant(
email: str, name: str, interface_language: str, password: str | None = None
email: str, name: str, interface_language: str, password: str | None = None, timezone: str | None = None
) -> Account:
"""create account"""
"""Create an account and owner workspace."""
account = AccountService.create_account(
email=email, name=name, interface_language=interface_language, password=password
email=email,
name=name,
interface_language=interface_language,
password=password,
timezone=timezone,
)
try:
@@ -1474,8 +1484,8 @@ class RegisterService:
@classmethod
def register(
cls,
email,
name,
email: str,
name: str,
password: str | None = None,
open_id: str | None = None,
provider: str | None = None,
@@ -1483,16 +1493,19 @@ class RegisterService:
status: AccountStatus | None = None,
is_setup: bool | None = False,
create_workspace_required: bool | None = True,
timezone: str | None = None,
) -> Account:
db.session.begin_nested()
"""Register account"""
db.session.begin_nested()
try:
interface_language = get_valid_language(language)
account = AccountService.create_account(
email=email,
name=name,
interface_language=get_valid_language(language),
interface_language=interface_language,
password=password,
is_setup=is_setup,
timezone=timezone,
)
account.status = status or AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()

View File

@@ -143,7 +143,118 @@ class TestEmailRegisterResetApi:
response = EmailRegisterResetApi().post()
assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}}
mock_create_account.assert_called_once_with("invitee@example.com", "ValidPass123!")
mock_create_account.assert_called_once_with(
email="invitee@example.com",
password="ValidPass123!",
timezone=None,
language=None,
)
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
mock_revoke_token.assert_called_once_with("token-123")
mock_extract_ip.assert_called_once()
@patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit")
@patch("controllers.console.auth.email_register.AccountService.login")
@patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
def test_reset_passes_timezone_to_new_account(
self,
mock_extract_ip,
mock_get_data,
mock_revoke_token,
mock_get_account,
mock_create_account,
mock_login,
mock_reset_login_rate,
app: Flask,
):
mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"}
mock_create_account.return_value = MagicMock()
token_pair = MagicMock()
token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
mock_login.return_value = token_pair
mock_get_account.return_value = None
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
with (
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
):
with app.test_request_context(
"/email-register",
method="POST",
json={
"token": "token-123",
"new_password": "ValidPass123!",
"password_confirm": "ValidPass123!",
"timezone": "Asia/Shanghai",
},
):
response = EmailRegisterResetApi().post()
assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}}
mock_create_account.assert_called_once_with(
email="invitee@example.com",
password="ValidPass123!",
timezone="Asia/Shanghai",
language=None,
)
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
mock_revoke_token.assert_called_once_with("token-123")
mock_extract_ip.assert_called_once()
@patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit")
@patch("controllers.console.auth.email_register.AccountService.login")
@patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
def test_reset_passes_language_to_new_account(
self,
mock_extract_ip,
mock_get_data,
mock_revoke_token,
mock_get_account,
mock_create_account,
mock_login,
mock_reset_login_rate,
app: Flask,
):
mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"}
mock_create_account.return_value = MagicMock()
token_pair = MagicMock()
token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
mock_login.return_value = token_pair
mock_get_account.return_value = None
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
with (
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
):
with app.test_request_context(
"/email-register",
method="POST",
json={
"token": "token-123",
"new_password": "ValidPass123!",
"password_confirm": "ValidPass123!",
"language": "zh-Hans",
},
):
response = EmailRegisterResetApi().post()
assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}}
mock_create_account.assert_called_once_with(
email="invitee@example.com",
password="ValidPass123!",
timezone=None,
language="zh-Hans",
)
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
mock_revoke_token.assert_called_once_with("token-123")
mock_extract_ip.assert_called_once()

View File

@@ -14,7 +14,7 @@ from controllers.console.auth.oauth import (
_get_account_by_openid_or_email,
get_oauth_providers,
)
from libs.oauth import OAuthUserInfo
from libs.oauth import OAuthUserInfo, encode_oauth_state
from models.account import AccountStatus
from services.account_service import AccountService
from services.errors.account import AccountRegisterError
@@ -101,7 +101,55 @@ class TestOAuthLogin:
with app.test_request_context(f"/auth/oauth/github?{query_string}"):
resource.get("github")
mock_oauth_provider.get_authorization_url.assert_called_once_with(invite_token=expected_token)
mock_oauth_provider.get_authorization_url.assert_called_once_with(
invite_token=expected_token,
timezone=None,
language=None,
)
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth.redirect")
def test_should_pass_timezone_to_oauth_state(
self,
mock_redirect,
mock_get_providers,
resource,
app: Flask,
mock_oauth_provider,
):
mock_get_providers.return_value = {"github": mock_oauth_provider, "google": None}
with app.test_request_context("/auth/oauth/github?timezone=Asia/Shanghai"):
resource.get("github")
mock_oauth_provider.get_authorization_url.assert_called_once_with(
invite_token=None,
timezone="Asia/Shanghai",
language=None,
)
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth.redirect")
def test_should_pass_language_to_oauth_state(
self,
mock_redirect,
mock_get_providers,
resource,
app: Flask,
mock_oauth_provider,
):
mock_get_providers.return_value = {"github": mock_oauth_provider, "google": None}
with app.test_request_context("/auth/oauth/github?language=zh-Hans"):
resource.get("github")
mock_oauth_provider.get_authorization_url.assert_called_once_with(
invite_token=None,
timezone=None,
language="zh-Hans",
)
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...")
@pytest.mark.parametrize(
@@ -229,7 +277,8 @@ class TestOAuthCallback:
mock_register_service.is_valid_invite_token.return_value = True
mock_register_service.get_invitation_by_token.return_value = {"email": "user@example.com"}
with app.test_request_context("/auth/oauth/github/callback?code=test_code&state=invite123"):
state = encode_oauth_state(invite_token="invite123", timezone="Asia/Shanghai")
with app.test_request_context(f"/auth/oauth/github/callback?code=test_code&state={state}"):
resource.get("github")
mock_register_service.get_invitation_by_token.assert_called_once_with(token="invite123")
@@ -488,7 +537,13 @@ class TestAccountGeneration:
if should_create:
mock_register_service.register.assert_called_once_with(
email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
email="test@example.com",
name="Test User",
password=None,
open_id="123",
provider="github",
language="en-US",
timezone=None,
)
else:
mock_register_service.register.assert_not_called()
@@ -515,7 +570,75 @@ class TestAccountGeneration:
_generate_account("github", user_info)
mock_register_service.register.assert_called_once_with(
email="upper@example.com", name="Test User", password=None, open_id="123", provider="github"
email="upper@example.com",
name="Test User",
password=None,
open_id="123",
provider="github",
language="en-US",
timezone=None,
)
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
@patch("controllers.console.auth.oauth.FeatureService")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
def test_should_register_with_browser_timezone(
self,
mock_tenant_service,
mock_account_service,
mock_register_service,
mock_feature_service,
mock_get_account,
app: Flask,
user_info,
):
mock_feature_service.get_system_features.return_value.is_allow_register = True
mock_register_service.register.return_value = MagicMock()
with app.test_request_context(headers={"Accept-Language": "zh-Hans,zh;q=0.9"}):
_generate_account("github", user_info, timezone="Asia/Shanghai")
mock_register_service.register.assert_called_once_with(
email="test@example.com",
name="Test User",
password=None,
open_id="123",
provider="github",
language="zh-Hans",
timezone="Asia/Shanghai",
)
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
@patch("controllers.console.auth.oauth.FeatureService")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
def test_should_register_with_state_language(
self,
mock_tenant_service,
mock_account_service,
mock_register_service,
mock_feature_service,
mock_get_account,
app: Flask,
user_info,
):
mock_feature_service.get_system_features.return_value.is_allow_register = True
mock_register_service.register.return_value = MagicMock()
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
_generate_account("github", user_info, language="zh-Hans")
mock_register_service.register.assert_called_once_with(
email="test@example.com",
name="Test User",
password=None,
open_id="123",
provider="github",
language="zh-Hans",
timezone=None,
)
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")

View File

@@ -0,0 +1,40 @@
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from controllers.console.auth.email_register import EmailRegisterResetApi, EmailRegisterResetPayload
@patch("controllers.console.auth.email_register.AccountService.create_account_and_tenant")
def test_create_new_account_uses_requested_language(mock_create_account):
account = MagicMock()
mock_create_account.return_value = account
result = EmailRegisterResetApi()._create_new_account(
"invitee@example.com",
"ValidPass123!",
timezone="Asia/Shanghai",
language="zh-Hans",
)
assert result is account
mock_create_account.assert_called_once_with(
email="invitee@example.com",
name="invitee@example.com",
password="ValidPass123!",
interface_language="zh-Hans",
timezone="Asia/Shanghai",
)
def test_reset_payload_rejects_invalid_timezone():
with pytest.raises(ValidationError):
EmailRegisterResetPayload.model_validate(
{
"token": "token-123",
"new_password": "ValidPass123!",
"password_confirm": "ValidPass123!",
"timezone": "",
}
)

View File

@@ -13,9 +13,10 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from pydantic import ValidationError
from controllers.console.auth.error import EmailCodeError, InvalidEmailError, InvalidTokenError
from controllers.console.auth.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi
from controllers.console.auth.login import EmailCodeLoginApi, EmailCodeLoginPayload, EmailCodeLoginSendEmailApi
from controllers.console.error import (
AccountInFreezeError,
AccountNotFound,
@@ -31,6 +32,18 @@ def encode_code(code: str) -> str:
return base64.b64encode(code.encode("utf-8")).decode()
def test_email_code_login_payload_rejects_invalid_timezone():
with pytest.raises(ValidationError):
EmailCodeLoginPayload.model_validate(
{
"email": "newuser@example.com",
"code": "123456",
"token": "token-123",
"timezone": "",
}
)
class TestEmailCodeLoginSendEmailApi:
"""Test cases for sending email verification codes."""
@@ -342,6 +355,7 @@ class TestEmailCodeLoginApi:
"code": encode_code("123456"),
"token": "valid_token",
"language": "en-US",
"timezone": "Asia/Shanghai",
},
):
api = EmailCodeLoginApi()
@@ -349,7 +363,12 @@ class TestEmailCodeLoginApi:
# Assert
assert response.json["result"] == "success"
mock_create_account.assert_called_once()
mock_create_account.assert_called_once_with(
email="newuser@example.com",
name="newuser@example.com",
interface_language="en-US",
timezone="Asia/Shanghai",
)
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")

View File

@@ -0,0 +1,123 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.oauth import OAuthLogin, _generate_account
from libs.oauth import OAuthUserInfo
from services.errors.account import AccountRegisterError
@pytest.fixture
def app() -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
return app
@patch("controllers.console.auth.oauth.redirect")
@patch("controllers.console.auth.oauth.get_oauth_providers")
def test_oauth_login_passes_language_and_timezone_to_authorization_url(
mock_get_oauth_providers,
mock_redirect,
app: Flask,
):
oauth_provider = MagicMock()
oauth_provider.get_authorization_url.return_value = "https://github.com/login/oauth/authorize?state=..."
mock_get_oauth_providers.return_value = {"github": oauth_provider}
with app.test_request_context("/oauth/login/github?language=zh-Hans&timezone=Asia/Shanghai"):
OAuthLogin().get("github")
oauth_provider.get_authorization_url.assert_called_once_with(
invite_token=None,
timezone="Asia/Shanghai",
language="zh-Hans",
)
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?state=...")
@patch("controllers.console.auth.oauth.AccountService.link_account_integrate")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.FeatureService")
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
def test_generate_account_registers_with_browser_timezone(
mock_get_account,
mock_feature_service,
mock_register_service,
mock_link_account,
app: Flask,
):
account = MagicMock()
mock_register_service.register.return_value = account
mock_feature_service.get_system_features.return_value.is_allow_register = True
user_info = OAuthUserInfo(id="github-123", name="Test User", email="User@Example.com")
with app.test_request_context(headers={"Accept-Language": "zh-Hans,zh;q=0.9"}):
result, oauth_new_user = _generate_account("github", user_info, timezone="Asia/Shanghai")
assert result is account
assert oauth_new_user is True
mock_register_service.register.assert_called_once_with(
email="user@example.com",
name="Test User",
password=None,
open_id="github-123",
provider="github",
language="zh-Hans",
timezone="Asia/Shanghai",
)
mock_link_account.assert_called_once_with("github", "github-123", account)
@patch("controllers.console.auth.oauth.AccountService.link_account_integrate")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.FeatureService")
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
def test_generate_account_prefers_state_language_over_accept_language(
mock_get_account,
mock_feature_service,
mock_register_service,
mock_link_account,
app: Flask,
):
account = MagicMock()
mock_register_service.register.return_value = account
mock_feature_service.get_system_features.return_value.is_allow_register = True
user_info = OAuthUserInfo(id="github-123", name="Test User", email="User@Example.com")
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
_generate_account("github", user_info, language="zh-Hans")
mock_register_service.register.assert_called_once_with(
email="user@example.com",
name="Test User",
password=None,
open_id="github-123",
provider="github",
language="zh-Hans",
timezone=None,
)
mock_link_account.assert_called_once_with("github", "github-123", account)
@patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.FeatureService")
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
def test_generate_account_rejects_new_user_when_registration_disabled(
mock_get_account,
mock_feature_service,
mock_register_service,
mock_config,
app: Flask,
):
mock_feature_service.get_system_features.return_value.is_allow_register = False
mock_config.BILLING_ENABLED = False
user_info = OAuthUserInfo(id="github-123", name="Test User", email="user@example.com")
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
with pytest.raises(AccountRegisterError):
_generate_account("github", user_info)
mock_register_service.register.assert_not_called()

View File

@@ -1,6 +1,6 @@
import pytest
from libs.oauth import OAuth
from libs.oauth import OAuth, decode_oauth_state, encode_oauth_state
def test_oauth_base_methods_raise_not_implemented():
@@ -17,3 +17,17 @@ def test_oauth_base_methods_raise_not_implemented():
with pytest.raises(NotImplementedError):
oauth._transform_user_info({})
def test_oauth_state_round_trips_invite_token_timezone_and_language():
state = encode_oauth_state(invite_token="invite-123", timezone="Asia/Shanghai", language="zh-Hans")
assert decode_oauth_state(state) == {
"invite_token": "invite-123",
"timezone": "Asia/Shanghai",
"language": "zh-Hans",
}
def test_oauth_state_returns_empty_payload_for_invalid_state():
assert decode_oauth_state("invalid-state") == {}

View File

@@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch
import httpx
import pytest
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo, decode_oauth_state
class BaseOAuthTest:
@@ -37,15 +37,25 @@ class TestGitHubOAuth(BaseOAuthTest):
return GitHubOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
@pytest.mark.parametrize(
("invite_token", "expected_state"),
("invite_token", "timezone", "language", "expected_state"),
[
(None, None),
("test_invite_token", "test_invite_token"),
("", None),
(None, None, None, None),
("test_invite_token", None, None, {"invite_token": "test_invite_token"}),
("", None, None, None),
(None, "Asia/Shanghai", None, {"timezone": "Asia/Shanghai"}),
(None, None, "zh-Hans", {"language": "zh-Hans"}),
(
"test_invite_token",
"Asia/Shanghai",
"zh-Hans",
{"invite_token": "test_invite_token", "timezone": "Asia/Shanghai", "language": "zh-Hans"},
),
],
)
def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
url = oauth.get_authorization_url(invite_token)
def test_should_generate_authorization_url_correctly(
self, oauth, oauth_config, invite_token, timezone, language, expected_state
):
url = oauth.get_authorization_url(invite_token, timezone=timezone, language=language)
parsed, params = self.parse_auth_url(url)
assert parsed.scheme == "https"
@@ -56,7 +66,7 @@ class TestGitHubOAuth(BaseOAuthTest):
assert params["scope"][0] == "user:email"
if expected_state:
assert params["state"][0] == expected_state
assert decode_oauth_state(params["state"][0]) == expected_state
else:
assert "state" not in params
@@ -208,15 +218,25 @@ class TestGoogleOAuth(BaseOAuthTest):
return GoogleOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
@pytest.mark.parametrize(
("invite_token", "expected_state"),
("invite_token", "timezone", "language", "expected_state"),
[
(None, None),
("test_invite_token", "test_invite_token"),
("", None),
(None, None, None, None),
("test_invite_token", None, None, {"invite_token": "test_invite_token"}),
("", None, None, None),
(None, "Asia/Shanghai", None, {"timezone": "Asia/Shanghai"}),
(None, None, "zh-Hans", {"language": "zh-Hans"}),
(
"test_invite_token",
"Asia/Shanghai",
"zh-Hans",
{"invite_token": "test_invite_token", "timezone": "Asia/Shanghai", "language": "zh-Hans"},
),
],
)
def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
url = oauth.get_authorization_url(invite_token)
def test_should_generate_authorization_url_correctly(
self, oauth, oauth_config, invite_token, timezone, language, expected_state
):
url = oauth.get_authorization_url(invite_token, timezone=timezone, language=language)
parsed, params = self.parse_auth_url(url)
assert parsed.scheme == "https"
@@ -228,7 +248,7 @@ class TestGoogleOAuth(BaseOAuthTest):
assert params["scope"][0] == "openid email"
if expected_state:
assert params["state"][0] == expected_state
assert decode_oauth_state(params["state"][0]) == expected_state
else:
assert "state" not in params

View File

@@ -260,7 +260,7 @@ class TestAccountService:
assert result.interface_theme == "light"
assert result.password is not None
assert result.password_salt is not None
assert result.timezone is not None
assert result.timezone == "America/New_York"
# Verify database operations
mock_db_dependencies["db"].session.add.assert_called_once()
@@ -271,7 +271,28 @@ class TestAccountService:
assert added_account.interface_theme == "light"
assert added_account.password is not None
assert added_account.password_salt is not None
assert added_account.timezone is not None
assert added_account.timezone == "America/New_York"
self._assert_database_operations_called(mock_db_dependencies["db"])
def test_create_account_uses_explicit_timezone(
self, mock_db_dependencies, mock_password_dependencies, mock_external_service_dependencies
):
"""Test account creation prefers explicit browser timezone."""
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
mock_password_dependencies["hash_password"].return_value = b"hashed_password"
result = AccountService.create_account(
email="test@example.com",
name="Test User",
interface_language="en-US",
password="password123",
timezone="Asia/Shanghai",
)
assert result.timezone == "Asia/Shanghai"
added_account = mock_db_dependencies["db"].session.add.call_args[0][0]
assert added_account.timezone == "Asia/Shanghai"
self._assert_database_operations_called(mock_db_dependencies["db"])
def test_create_account_registration_disabled(self, mock_external_service_dependencies):
@@ -1221,6 +1242,7 @@ class TestRegisterService:
interface_language="en-US",
password="password123",
is_setup=False,
timezone=None,
)
mock_create_tenant.assert_called_once_with("Test User's Workspace")
mock_create_member.assert_called_once_with(mock_tenant, mock_account, role="owner")

View File

@@ -13,6 +13,7 @@ import { useLocale } from '@/context/i18n'
import { useRouter, useSearchParams } from '@/next/navigation'
import { emailLoginWithCode, sendEMailLoginCode } from '@/service/common'
import { encryptVerificationCode } from '@/utils/encryption'
import { getBrowserTimezone } from '@/utils/timezone'
import { resolvePostLoginRedirect } from '../utils/post-login-redirect'
export default function CheckCode() {
@@ -39,7 +40,13 @@ export default function CheckCode() {
return
}
setIsLoading(true)
const ret = await emailLoginWithCode({ email, code: encryptVerificationCode(code), token, language })
const ret = await emailLoginWithCode({
email,
code: encryptVerificationCode(code),
token,
language,
timezone: getBrowserTimezone(),
})
if (ret.result === 'success') {
// Track login success event
trackEvent('user_login_success', {

View File

@@ -0,0 +1,86 @@
import { render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { useLocale } from '@/context/i18n'
import { useSearchParams } from '@/next/navigation'
import { getBrowserTimezone } from '@/utils/timezone'
import SocialAuth from '../social-auth'
vi.mock('@/next/navigation', () => ({
useSearchParams: vi.fn(),
}))
vi.mock('@/context/i18n', () => ({
useLocale: vi.fn(),
}))
vi.mock('@/utils/timezone', () => ({
getBrowserTimezone: vi.fn(),
}))
const mockUseSearchParams = vi.mocked(useSearchParams)
const mockUseLocale = vi.mocked(useLocale)
const mockGetBrowserTimezone = vi.mocked(getBrowserTimezone)
describe('SocialAuth', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseSearchParams.mockReturnValue(new URLSearchParams() as unknown as ReturnType<typeof useSearchParams>)
mockUseLocale.mockReturnValue('zh-Hans')
mockGetBrowserTimezone.mockReturnValue('Asia/Shanghai')
})
describe('Rendering', () => {
it('should render oauth provider links', () => {
render(<SocialAuth />)
expect(screen.getByRole('link', { name: 'login.withGitHub' })).toBeInTheDocument()
expect(screen.getByRole('link', { name: 'login.withGoogle' })).toBeInTheDocument()
})
})
describe('OAuth params', () => {
it('should include browser timezone and locale in oauth links', () => {
render(<SocialAuth />)
expect(screen.getByRole('link', { name: 'login.withGitHub' })).toHaveAttribute(
'href',
expect.stringContaining('timezone=Asia%2FShanghai'),
)
expect(screen.getByRole('link', { name: 'login.withGitHub' })).toHaveAttribute(
'href',
expect.stringContaining('language=zh-Hans'),
)
expect(screen.getByRole('link', { name: 'login.withGoogle' })).toHaveAttribute(
'href',
expect.stringContaining('timezone=Asia%2FShanghai'),
)
expect(screen.getByRole('link', { name: 'login.withGoogle' })).toHaveAttribute(
'href',
expect.stringContaining('language=zh-Hans'),
)
})
it('should preserve invite token when adding timezone', () => {
mockUseSearchParams.mockReturnValue(
new URLSearchParams('invite_token=invite-123') as unknown as ReturnType<typeof useSearchParams>,
)
render(<SocialAuth />)
const githubLink = screen.getByRole('link', { name: 'login.withGitHub' })
expect(githubLink).toHaveAttribute('href', expect.stringContaining('invite_token=invite-123'))
expect(githubLink).toHaveAttribute('href', expect.stringContaining('timezone=Asia%2FShanghai'))
expect(githubLink).toHaveAttribute('href', expect.stringContaining('language=zh-Hans'))
})
})
describe('Edge Cases', () => {
it('should omit timezone when browser timezone is unavailable', () => {
mockGetBrowserTimezone.mockReturnValue(undefined)
render(<SocialAuth />)
expect(screen.getByRole('link', { name: 'login.withGitHub' }).getAttribute('href')).not.toContain('timezone=')
})
})
})

View File

@@ -2,8 +2,10 @@ import { Button } from '@langgenius/dify-ui/button'
import { cn } from '@langgenius/dify-ui/cn'
import { useTranslation } from 'react-i18next'
import { API_PREFIX } from '@/config'
import { useLocale } from '@/context/i18n'
import { useSearchParams } from '@/next/navigation'
import { getPurifyHref } from '@/utils'
import { getBrowserTimezone } from '@/utils/timezone'
import style from '../page.module.css'
type SocialAuthProps = {
@@ -13,11 +15,19 @@ type SocialAuthProps = {
export default function SocialAuth(props: SocialAuthProps) {
const { t } = useTranslation()
const searchParams = useSearchParams()
const locale = useLocale()
const getOAuthLink = (href: string) => {
const url = getPurifyHref(`${API_PREFIX}${href}`)
if (searchParams.has('invite_token'))
return `${url}?${searchParams.toString()}`
const params = new URLSearchParams(searchParams.toString())
const timezone = getBrowserTimezone()
if (timezone)
params.set('timezone', timezone)
params.set('language', locale)
const query = params.toString()
if (query)
return `${url}?${query}`
return url
}

View File

@@ -0,0 +1,139 @@
import type { MockedFunction } from 'vitest'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { useLocale } from '@/context/i18n'
import { useRouter, useSearchParams } from '@/next/navigation'
import { activateMember } from '@/service/common'
import { useInvitationCheck } from '@/service/use-common'
import { getBrowserTimezone } from '@/utils/timezone'
import InviteSettingsPage from '../page'
vi.mock('@tanstack/react-query', async () => {
const actual = await vi.importActual<typeof import('@tanstack/react-query')>('@tanstack/react-query')
return {
...actual,
useSuspenseQuery: vi.fn(() => ({
data: {
branding: {
enabled: true,
},
},
})),
}
})
vi.mock('@/context/i18n', () => ({
useLocale: vi.fn(),
}))
vi.mock('@/i18n-config', () => ({
i18n: {
defaultLocale: 'en-US',
},
setLocaleOnClient: vi.fn(() => Promise.resolve()),
}))
vi.mock('@/next/navigation', () => ({
useRouter: vi.fn(),
useSearchParams: vi.fn(),
}))
vi.mock('@/service/common', () => ({
activateMember: vi.fn(),
}))
vi.mock('@/service/use-common', () => ({
useInvitationCheck: vi.fn(),
}))
vi.mock('@/utils/timezone', () => ({
getBrowserTimezone: vi.fn(),
timezones: [
{ value: 'Asia/Shanghai', name: 'Asia/Shanghai' },
{ value: 'America/Los_Angeles', name: 'America/Los_Angeles' },
],
}))
vi.mock('../utils/post-login-redirect', () => ({
resolvePostLoginRedirect: vi.fn(() => null),
}))
const mockReplace = vi.fn()
const mockRefetch = vi.fn()
const mockUseLocale = useLocale as unknown as MockedFunction<typeof useLocale>
const mockUseRouter = useRouter as unknown as MockedFunction<typeof useRouter>
const mockUseSearchParams = useSearchParams as unknown as MockedFunction<typeof useSearchParams>
const mockActivateMember = activateMember as unknown as MockedFunction<typeof activateMember>
const mockUseInvitationCheck = useInvitationCheck as unknown as MockedFunction<typeof useInvitationCheck>
const mockGetBrowserTimezone = getBrowserTimezone as unknown as MockedFunction<typeof getBrowserTimezone>
describe('InviteSettingsPage', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseLocale.mockReturnValue('zh-Hans')
mockUseRouter.mockReturnValue({ replace: mockReplace } as unknown as ReturnType<typeof useRouter>)
mockUseSearchParams.mockReturnValue(
new URLSearchParams('invite_token=invite-token') as unknown as ReturnType<typeof useSearchParams>,
)
mockUseInvitationCheck.mockReturnValue({
data: {
is_valid: true,
data: {
workspace_name: 'Acme',
workspace_id: 'workspace-id',
email: 'invitee@example.com',
},
},
refetch: mockRefetch,
} as unknown as ReturnType<typeof useInvitationCheck>)
mockGetBrowserTimezone.mockReturnValue('Asia/Shanghai')
mockActivateMember.mockResolvedValue({ result: 'success' })
})
describe('Activation payload', () => {
it('should default language to the current UI locale', async () => {
render(<InviteSettingsPage />)
fireEvent.change(screen.getByLabelText('login.name'), {
target: { value: 'Invitee' },
})
fireEvent.click(screen.getByRole('button', { name: 'login.join Acme' }))
await waitFor(() => {
expect(mockActivateMember).toHaveBeenCalledWith({
url: '/activate',
body: {
token: 'invite-token',
name: 'Invitee',
interface_language: 'zh-Hans',
timezone: 'Asia/Shanghai',
},
})
})
})
it('should fall back to configured default locale when current locale is unsupported', async () => {
mockUseLocale.mockReturnValue('unsupported-locale' as ReturnType<typeof useLocale>)
render(<InviteSettingsPage />)
fireEvent.change(screen.getByLabelText('login.name'), {
target: { value: 'Invitee' },
})
fireEvent.click(screen.getByRole('button', { name: 'login.join Acme' }))
await waitFor(() => {
expect(mockActivateMember).toHaveBeenCalledWith({
url: '/activate',
body: {
token: 'invite-token',
name: 'Invitee',
interface_language: 'en-US',
timezone: 'Asia/Shanghai',
},
})
})
})
})
})

View File

@@ -11,14 +11,15 @@ import { useTranslation } from 'react-i18next'
import Input from '@/app/components/base/input'
import Loading from '@/app/components/base/loading'
import { LICENSE_LINK } from '@/constants/link'
import { setLocaleOnClient } from '@/i18n-config'
import { languages, LanguagesSupported } from '@/i18n-config/language'
import { useLocale } from '@/context/i18n'
import { i18n, setLocaleOnClient } from '@/i18n-config'
import { languages } from '@/i18n-config/language'
import Link from '@/next/link'
import { useRouter, useSearchParams } from '@/next/navigation'
import { activateMember } from '@/service/common'
import { systemFeaturesQueryOptions } from '@/service/system-features'
import { useInvitationCheck } from '@/service/use-common'
import { timezones } from '@/utils/timezone'
import { getBrowserTimezone, timezones } from '@/utils/timezone'
import { resolvePostLoginRedirect } from '../utils/post-login-redirect'
type LanguageSelectOption = {
@@ -43,15 +44,23 @@ const TIMEZONE_OPTIONS: TimezoneSelectOption[] = timezones.map(item => ({
name: item.name,
}))
const getInitialLanguage = (locale: Locale): Locale => {
if (LANGUAGE_OPTIONS.some(item => item.value === locale))
return locale
return i18n.defaultLocale
}
export default function InviteSettingsPage() {
const { t } = useTranslation()
const { data: systemFeatures } = useSuspenseQuery(systemFeaturesQueryOptions())
const router = useRouter()
const searchParams = useSearchParams()
const token = decodeURIComponent(searchParams.get('invite_token') as string)
const locale = useLocale()
const [name, setName] = useState('')
const [language, setLanguage] = useState(LanguagesSupported[0])
const [timezone, setTimezone] = useState(() => Intl.DateTimeFormat().resolvedOptions().timeZone || 'America/Los_Angeles')
const [language, setLanguage] = useState(() => getInitialLanguage(locale))
const [timezone, setTimezone] = useState(() => getBrowserTimezone() || 'America/Los_Angeles')
const selectedLanguage = LANGUAGE_OPTIONS.find(item => item.value === language)
const selectedTimezone = TIMEZONE_OPTIONS.find(item => item.value === timezone)

View File

@@ -0,0 +1,85 @@
import type { MockedFunction } from 'vitest'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { useLocale } from '@/context/i18n'
import { useRouter, useSearchParams } from '@/next/navigation'
import { useMailRegister } from '@/service/use-common'
import { getBrowserTimezone } from '@/utils/timezone'
import ChangePasswordForm from '../page'
vi.mock('@/context/i18n', () => ({
useLocale: vi.fn(),
}))
vi.mock('@/next/navigation', () => ({
useRouter: vi.fn(),
useSearchParams: vi.fn(),
}))
vi.mock('@/service/use-common', () => ({
useMailRegister: vi.fn(),
}))
vi.mock('@/utils/timezone', () => ({
getBrowserTimezone: vi.fn(),
}))
vi.mock('@/utils/gtag', () => ({
sendGAEvent: vi.fn(),
}))
vi.mock('@/app/components/base/amplitude', () => ({
trackEvent: vi.fn(),
}))
vi.mock('@/utils/create-app-tracking', () => ({
rememberCreateAppExternalAttribution: vi.fn(),
}))
const mockRegister = vi.fn()
const mockReplace = vi.fn()
const mockUseLocale = useLocale as unknown as MockedFunction<typeof useLocale>
const mockUseSearchParams = useSearchParams as unknown as MockedFunction<typeof useSearchParams>
const mockUseRouter = useRouter as unknown as MockedFunction<typeof useRouter>
const mockUseMailRegister = useMailRegister as unknown as MockedFunction<typeof useMailRegister>
const mockGetBrowserTimezone = getBrowserTimezone as unknown as MockedFunction<typeof getBrowserTimezone>
describe('Signup Set Password Page', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseLocale.mockReturnValue('zh-Hans')
mockUseSearchParams.mockReturnValue(new URLSearchParams('token=register-token') as unknown as ReturnType<typeof useSearchParams>)
mockUseRouter.mockReturnValue({ replace: mockReplace } as unknown as ReturnType<typeof useRouter>)
mockUseMailRegister.mockReturnValue({
mutateAsync: mockRegister,
isPending: false,
} as unknown as ReturnType<typeof useMailRegister>)
mockGetBrowserTimezone.mockReturnValue('Asia/Shanghai')
mockRegister.mockResolvedValue({ result: 'fail', data: {} })
})
describe('Registration payload', () => {
it('should submit locale and browser timezone when setting password', async () => {
render(<ChangePasswordForm />)
fireEvent.change(screen.getByLabelText('common.account.newPassword'), {
target: { value: 'ValidPass123!' },
})
fireEvent.change(screen.getByLabelText('common.account.confirmPassword'), {
target: { value: 'ValidPass123!' },
})
fireEvent.click(screen.getByRole('button', { name: 'login.changePasswordBtn' }))
await waitFor(() => {
expect(mockRegister).toHaveBeenCalledWith({
token: 'register-token',
new_password: 'ValidPass123!',
password_confirm: 'ValidPass123!',
language: 'zh-Hans',
timezone: 'Asia/Shanghai',
})
})
})
})
})

View File

@@ -9,10 +9,12 @@ import { useTranslation } from 'react-i18next'
import { trackEvent } from '@/app/components/base/amplitude'
import Input from '@/app/components/base/input'
import { validPassword } from '@/config'
import { useLocale } from '@/context/i18n'
import { useRouter, useSearchParams } from '@/next/navigation'
import { useMailRegister } from '@/service/use-common'
import { rememberCreateAppExternalAttribution } from '@/utils/create-app-tracking'
import { sendGAEvent } from '@/utils/gtag'
import { getBrowserTimezone } from '@/utils/timezone'
const parseUtmInfo = () => {
const utmInfoStr = Cookies.get('utm_info')
@@ -32,6 +34,7 @@ const ChangePasswordForm = () => {
const router = useRouter()
const searchParams = useSearchParams()
const token = decodeURIComponent(searchParams.get('token') || '')
const locale = useLocale()
const [password, setPassword] = useState('')
const [confirmPassword, setConfirmPassword] = useState('')
@@ -65,6 +68,8 @@ const ChangePasswordForm = () => {
token,
new_password: password,
password_confirm: confirmPassword,
language: locale,
timezone: getBrowserTimezone(),
})
const { result } = res as MailRegisterResponse
if (result === 'success') {
@@ -88,7 +93,7 @@ const ChangePasswordForm = () => {
catch (error) {
console.error(error)
}
}, [password, token, valid, confirmPassword, register])
}, [password, token, valid, confirmPassword, register, locale])
return (
<div className={

View File

@@ -339,7 +339,13 @@ export const uploadRemoteFileInfo = (url: string, isPublic?: boolean, silent?: b
export const sendEMailLoginCode = (email: string, language = 'en-US'): Promise<CommonResponse & { data: string }> =>
post<CommonResponse & { data: string }>('/email-code-login', { body: { email, language } })
export const emailLoginWithCode = (data: { email: string, code: string, token: string, language: string }): Promise<LoginResponse> =>
export const emailLoginWithCode = (data: {
email: string
code: string
token: string
language: string
timezone?: string
}): Promise<LoginResponse> =>
post<LoginResponse>('/email-code-login/validity', { body: data })
export const sendResetPasswordCode = (email: string, language = 'en-US'): Promise<CommonResponse & { data: string, message?: string, code?: string }> =>

View File

@@ -178,7 +178,13 @@ export type MailRegisterResponse = { result: string, data: {} }
export const useMailRegister = () => {
return useMutation({
mutationKey: [NAME_SPACE, 'mail-register'],
mutationFn: (body: { token: string, new_password: string, password_confirm: string }) => {
mutationFn: (body: {
token: string
new_password: string
password_confirm: string
language?: string
timezone?: string
}) => {
return post<MailRegisterResponse>('/email-register', { body })
},
})

View File

@@ -5,3 +5,10 @@ type Item = {
name: string
}
export const timezones: Item[] = tz
export const getBrowserTimezone = () => {
if (typeof Intl === 'undefined')
return undefined
return Intl.DateTimeFormat().resolvedOptions().timeZone || undefined
}