From bcd738d2e6301e82881ee2da965236c853b65241 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Fri, 10 Apr 2026 15:13:59 +0800 Subject: [PATCH] fix: fix orm_exc.DetachedInstanceError (#34904) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/commands/account.py | 67 ++++++++-------- .../console/auth/email_register.py | 24 +++--- .../console/auth/forgot_password.py | 19 +++-- api/controllers/console/auth/oauth.py | 4 +- api/controllers/console/workspace/account.py | 4 +- api/controllers/web/forgot_password.py | 18 ++--- api/services/account_service.py | 18 ++--- .../console/auth/test_email_register.py | 5 +- .../console/auth/test_forgot_password.py | 8 +- .../controllers/console/auth/test_oauth.py | 5 +- .../console/auth/test_password_reset.py | 3 + .../web/test_web_forgot_password.py | 77 ++++++++----------- .../services/test_webapp_auth_service.py | 14 ++-- .../console/test_workspace_account.py | 13 +++- .../services/test_account_service.py | 30 +------- 15 files changed, 142 insertions(+), 167 deletions(-) diff --git a/api/commands/account.py b/api/commands/account.py index 84af7a5ae6..6a2a2e0428 100644 --- a/api/commands/account.py +++ b/api/commands/account.py @@ -2,7 +2,6 @@ import base64 import secrets import click -from sqlalchemy.orm import sessionmaker from constants.languages import languages from extensions.ext_database import db @@ -25,30 +24,31 @@ def reset_password(email, new_password, password_confirm): return normalized_email = email.strip().lower() - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) + account = AccountService.get_account_by_email_with_case_fallback(email.strip()) - if not account: - click.echo(click.style(f"Account not found for email: {email}", fg="red")) - return + if not account: + click.echo(click.style(f"Account not found for email: {email}", fg="red")) + return - try: - valid_password(new_password) - except: - click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red")) - return + try: + valid_password(new_password) + except: + click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red")) + return - # generate password salt - salt = secrets.token_bytes(16) - base64_salt = base64.b64encode(salt).decode() + # generate password salt + salt = secrets.token_bytes(16) + base64_salt = base64.b64encode(salt).decode() - # encrypt password with salt - password_hashed = hash_password(new_password, salt) - base64_password_hashed = base64.b64encode(password_hashed).decode() - account.password = base64_password_hashed - account.password_salt = base64_salt - AccountService.reset_login_error_rate_limit(normalized_email) - click.echo(click.style("Password reset successfully.", fg="green")) + # encrypt password with salt + password_hashed = hash_password(new_password, salt) + base64_password_hashed = base64.b64encode(password_hashed).decode() + account = db.session.merge(account) + account.password = base64_password_hashed + account.password_salt = base64_salt + db.session.commit() + AccountService.reset_login_error_rate_limit(normalized_email) + click.echo(click.style("Password reset successfully.", fg="green")) @click.command("reset-email", help="Reset the account email.") @@ -65,21 +65,22 @@ def reset_email(email, new_email, email_confirm): return normalized_new_email = new_email.strip().lower() - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) + account = AccountService.get_account_by_email_with_case_fallback(email.strip()) - if not account: - click.echo(click.style(f"Account not found for email: {email}", fg="red")) - return + if not account: + click.echo(click.style(f"Account not found for email: {email}", fg="red")) + return - try: - email_validate(normalized_new_email) - except: - click.echo(click.style(f"Invalid email: {new_email}", fg="red")) - return + try: + email_validate(normalized_new_email) + except: + click.echo(click.style(f"Invalid email: {new_email}", fg="red")) + return - account.email = normalized_new_email - click.echo(click.style("Email updated successfully.", fg="green")) + account = db.session.merge(account) + account.email = normalized_new_email + db.session.commit() + click.echo(click.style("Email updated successfully.", fg="green")) @click.command("create-tenant", help="Create account and tenant.") diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index 9e7faa09c5..1fd781b4fc 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -1,7 +1,6 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy.orm import sessionmaker from configs import dify_config from constants.languages import languages @@ -14,7 +13,6 @@ from controllers.console.auth.error import ( InvalidTokenError, PasswordMismatchError, ) -from extensions.ext_database import db from libs.helper import EmailStr, extract_remote_ip from libs.password import valid_password from models import Account @@ -73,8 +71,7 @@ class EmailRegisterSendEmailApi(Resource): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountInFreezeError() - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(args.email) token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language) return {"result": "success", "data": token} @@ -145,17 +142,16 @@ class EmailRegisterResetApi(Resource): email = register_data.get("email", "") normalized_email = email.lower() - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) - if account: - raise EmailAlreadyInUseError() - else: - account = self._create_new_account(normalized_email, args.password_confirm) - if not account: - raise AccountNotFoundError() - token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) - AccountService.reset_login_error_rate_limit(normalized_email) + if account: + raise EmailAlreadyInUseError() + else: + account = self._create_new_account(normalized_email, args.password_confirm) + if not account: + raise AccountNotFoundError() + token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) + AccountService.reset_login_error_rate_limit(normalized_email) return {"result": "success", "data": token_pair.model_dump()} diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 63bc98b53f..ed390a5f89 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -4,7 +4,6 @@ import secrets from flask import request from flask_restx import Resource from pydantic import BaseModel, Field -from sqlalchemy.orm import sessionmaker from controllers.common.schema import register_schema_models from controllers.console import console_ns @@ -85,8 +84,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(args.email) token = AccountService.send_reset_password_email( account=account, @@ -184,17 +182,18 @@ class ForgotPasswordResetApi(Resource): password_hashed = hash_password(args.new_password, salt) email = reset_data.get("email", "") - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) - if account: - self._update_existing_account(account, password_hashed, salt, session) - else: - raise AccountNotFound() + if account: + account = db.session.merge(account) + self._update_existing_account(account, password_hashed, salt) + db.session.commit() + else: + raise AccountNotFound() return {"result": "success"} - def _update_existing_account(self, account, password_hashed, salt, session): + def _update_existing_account(self, account, password_hashed, salt): # Update existing account credentials account.password = base64.b64encode(password_hashed).decode() account.password_salt = base64.b64encode(salt).decode() diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 5c7011fd22..d31fb4a46c 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -4,7 +4,6 @@ import urllib.parse import httpx from flask import current_app, redirect, request from flask_restx import Resource -from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Unauthorized from configs import dify_config @@ -180,8 +179,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> account: Account | None = Account.get_by_openid(provider, user_info.id) if not account: - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(user_info.email) return account diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 626d330e9d..af25669ae0 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -8,7 +8,6 @@ from flask import request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field, field_validator, model_validator from sqlalchemy import select -from sqlalchemy.orm import sessionmaker from configs import dify_config from constants.languages import supported_language @@ -562,8 +561,7 @@ class ChangeEmailSendEmailApi(Resource): user_email = current_user.email else: - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(args.email) if account is None: raise AccountNotFound() email_for_sending = account.email diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index 80c3289fb4..61fd794c22 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -3,7 +3,6 @@ import secrets from flask import request from flask_restx import Resource -from sqlalchemy.orm import sessionmaker from controllers.common.schema import register_schema_models from controllers.console.auth.error import ( @@ -62,9 +61,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session) - token = None + account = AccountService.get_account_by_email_with_case_fallback(request_email) if account is None: raise AuthenticationFailedError() else: @@ -161,13 +158,14 @@ class ForgotPasswordResetApi(Resource): email = reset_data.get("email", "") - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) - if account: - self._update_existing_account(account, password_hashed, salt) - else: - raise AuthenticationFailedError() + if account: + account = db.session.merge(account) + self._update_existing_account(account, password_hashed, salt) + db.session.commit() + else: + raise AuthenticationFailedError() return {"result": "success"} diff --git a/api/services/account_service.py b/api/services/account_service.py index 4b58b3b697..1f5f81e5bd 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -9,7 +9,8 @@ from typing import Any, TypedDict, cast from pydantic import BaseModel, TypeAdapter from sqlalchemy import delete, func, select, update -from sqlalchemy.orm import Session, sessionmaker + +from core.db.session_factory import session_factory class InvitationData(TypedDict): @@ -800,19 +801,19 @@ class AccountService: return token @staticmethod - def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None: + def get_account_by_email_with_case_fallback(email: str) -> Account | None: """ Retrieve an account by email and fall back to the lowercase email if the original lookup fails. This keeps backward compatibility for older records that stored uppercase emails while the rest of the system gradually normalizes new inputs. """ - query_session = session or db.session - account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() - if account or email == email.lower(): - return account + with session_factory.create_session() as session: + account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + if account or email == email.lower(): + return account - return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() + return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() @classmethod def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None: @@ -1516,8 +1517,7 @@ class RegisterService: check_workspace_member_invite_permission(tenant.id) - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) if not account: TenantService.check_member_permission(tenant, inviter, None, "add") diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py index 879c337319..320da85b60 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py @@ -158,7 +158,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + with patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py index 7b7393dade..d2703ed5cc 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py @@ -113,12 +113,14 @@ class TestForgotPasswordCheckApi: class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") def test_reset_fetches_account_with_original_email( self, mock_get_reset_data, mock_revoke_token, + mock_db, mock_get_account, mock_update_account, app, @@ -126,6 +128,7 @@ class TestForgotPasswordResetApi: mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"} mock_account = MagicMock() mock_get_account.return_value = mock_account + mock_db.session.merge.return_value = mock_account wraps_features = SimpleNamespace(enable_email_password_login=True) with ( @@ -161,7 +164,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) + with patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index a2f1328579..1eabb45422 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -437,7 +437,10 @@ class TestAccountGeneration: second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + with patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py index 8f9db287e3..50249bcd74 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py @@ -335,10 +335,12 @@ class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants") def test_reset_password_success( self, mock_get_tenants, + mock_db, mock_get_account, mock_revoke_token, mock_get_data, @@ -356,6 +358,7 @@ class TestForgotPasswordResetApi: # Arrange mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} mock_get_account.return_value = mock_account + mock_db.session.merge.return_value = mock_account mock_get_tenants.return_value = [MagicMock()] # Act diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py index 04ad143103..f14b2c0ae5 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py @@ -37,10 +37,8 @@ class TestForgotPasswordSendEmailApi: @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) @patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1") - @patch("controllers.web.forgot_password.sessionmaker") def test_should_normalize_email_before_sending( self, - mock_session_cls, mock_extract_ip, mock_rate_limit, mock_get_account, @@ -50,19 +48,16 @@ class TestForgotPasswordSendEmailApi: mock_account = MagicMock() mock_get_account.return_value = mock_account mock_send_mail.return_value = "token-123" - mock_session = MagicMock() - mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session - with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): - with app.test_request_context( - "/web/forgot-password", - method="POST", - json={"email": "User@Example.com", "language": "zh-Hans"}, - ): - response = ForgotPasswordSendEmailApi().post() + with app.test_request_context( + "/web/forgot-password", + method="POST", + json={"email": "User@Example.com", "language": "zh-Hans"}, + ): + response = ForgotPasswordSendEmailApi().post() assert response == {"result": "success", "data": "token-123"} - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_get_account.assert_called_once_with("User@Example.com") mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans") mock_extract_ip.assert_called_once() mock_rate_limit.assert_called_once_with("127.0.0.1") @@ -153,14 +148,14 @@ class TestForgotPasswordResetApi: @patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account") @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") - @patch("controllers.web.forgot_password.sessionmaker") + @patch("controllers.web.forgot_password.db") @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") def test_should_fetch_account_with_fallback( self, mock_get_reset_data, mock_revoke_token, - mock_session_cls, + mock_db, mock_get_account, mock_update_account, app, @@ -168,29 +163,27 @@ class TestForgotPasswordResetApi: mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"} mock_account = MagicMock() mock_get_account.return_value = mock_account - mock_session = MagicMock() - mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session + mock_db.session.merge.return_value = mock_account - with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): - with app.test_request_context( - "/web/forgot-password/resets", - method="POST", - json={ - "token": "token-123", - "new_password": "ValidPass123!", - "password_confirm": "ValidPass123!", - }, - ): - response = ForgotPasswordResetApi().post() + with app.test_request_context( + "/web/forgot-password/resets", + method="POST", + json={ + "token": "token-123", + "new_password": "ValidPass123!", + "password_confirm": "ValidPass123!", + }, + ): + response = ForgotPasswordResetApi().post() assert response == {"result": "success"} - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_get_account.assert_called_once_with("User@Example.com") mock_update_account.assert_called_once() mock_revoke_token.assert_called_once_with("token-123") @patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value") @patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef") - @patch("controllers.web.forgot_password.sessionmaker") + @patch("controllers.web.forgot_password.db") @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") @@ -199,7 +192,7 @@ class TestForgotPasswordResetApi: mock_get_account, mock_get_reset_data, mock_revoke_token, - mock_session_cls, + mock_db, mock_token_bytes, mock_hash_password, app, @@ -207,20 +200,18 @@ class TestForgotPasswordResetApi: mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"} account = MagicMock() mock_get_account.return_value = account - mock_session = MagicMock() - mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session + mock_db.session.merge.return_value = account - with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): - with app.test_request_context( - "/web/forgot-password/resets", - method="POST", - json={ - "token": "reset-token", - "new_password": "StrongPass123!", - "password_confirm": "StrongPass123!", - }, - ): - response = ForgotPasswordResetApi().post() + with app.test_request_context( + "/web/forgot-password/resets", + method="POST", + json={ + "token": "reset-token", + "new_password": "StrongPass123!", + "password_confirm": "StrongPass123!", + }, + ): + response = ForgotPasswordResetApi().post() assert response == {"result": "success"} mock_get_reset_data.assert_called_once_with("reset-token") diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py index 4fe65d5803..7825f502f7 100644 --- a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -233,11 +233,10 @@ class TestWebAppAuthService: assert result.status == AccountStatus.ACTIVE # Verify database state - - db_session_with_containers.refresh(result) - assert result.id is not None - assert result.password is not None - assert result.password_salt is not None + refreshed = db_session_with_containers.get(Account, result.id) + assert refreshed is not None + assert refreshed.password is not None + assert refreshed.password_salt is not None def test_authenticate_account_not_found( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -414,9 +413,8 @@ class TestWebAppAuthService: assert result.status == AccountStatus.ACTIVE # Verify database state - - db_session_with_containers.refresh(result) - assert result.id is not None + refreshed = db_session_with_containers.get(Account, result.id) + assert refreshed is not None def test_get_user_through_email_not_found( self, db_session_with_containers: Session, mock_external_service_dependencies diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index 7f9fe9cbf9..dd643faac9 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -233,15 +233,20 @@ class TestCheckEmailUnique: def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): - session = MagicMock() + mock_session = MagicMock() first = MagicMock() first.scalar_one_or_none.return_value = None second = MagicMock() expected_account = MagicMock() second.scalar_one_or_none.return_value = expected_account - session.execute.side_effect = [first, second] + mock_session.execute.side_effect = [first, second] - result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session) + mock_factory = MagicMock() + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + + with patch("services.account_service.session_factory", mock_factory): + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com") assert result is expected_account - assert session.execute.call_count == 2 + assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index d15074e7a6..eeb5d178ec 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -1427,16 +1427,7 @@ class TestRegisterService: mock_tenant.name = "Test Workspace" mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter") - # Mock database queries - need to mock the sessionmaker query - mock_session = MagicMock() - mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account - - mock_sessionmaker = MagicMock() - mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None - with ( - patch("services.account_service.sessionmaker", mock_sessionmaker), patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, ): mock_lookup.return_value = None @@ -1475,7 +1466,7 @@ class TestRegisterService: status=AccountStatus.PENDING, is_setup=True, ) - mock_lookup.assert_called_once_with("newuser@example.com", session=mock_session) + mock_lookup.assert_called_once_with("newuser@example.com") def test_invite_new_member_normalizes_new_account_email( self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies @@ -1486,13 +1477,7 @@ class TestRegisterService: mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter") mixed_email = "Invitee@Example.com" - mock_session = MagicMock() - mock_sessionmaker = MagicMock() - mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None - with ( - patch("services.account_service.sessionmaker", mock_sessionmaker), patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, ): mock_lookup.return_value = None @@ -1525,7 +1510,7 @@ class TestRegisterService: status=AccountStatus.PENDING, is_setup=True, ) - mock_lookup.assert_called_once_with(mixed_email, session=mock_session) + mock_lookup.assert_called_once_with(mixed_email) mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add") mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal") mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id) @@ -1545,16 +1530,7 @@ class TestRegisterService: account_id="existing-user-456", email="existing@example.com", status="pending" ) - # Mock database queries - need to mock the sessionmaker query - mock_session = MagicMock() - mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account - - mock_sessionmaker = MagicMock() - mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None - with ( - patch("services.account_service.sessionmaker", mock_sessionmaker), patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, ): mock_lookup.return_value = mock_existing_account @@ -1584,7 +1560,7 @@ class TestRegisterService: mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal") mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account) mock_task_dependencies.delay.assert_called_once() - mock_lookup.assert_called_once_with("existing@example.com", session=mock_session) + mock_lookup.assert_called_once_with("existing@example.com") def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies): """Test inviting a member who is already in the tenant."""