diff --git a/api/services/plugin/plugin_auto_upgrade_service.py b/api/services/plugin/plugin_auto_upgrade_service.py index a58bede8db..9bb0ab6ae2 100644 --- a/api/services/plugin/plugin_auto_upgrade_service.py +++ b/api/services/plugin/plugin_auto_upgrade_service.py @@ -1,14 +1,13 @@ from sqlalchemy import select -from sqlalchemy.orm import sessionmaker -from extensions.ext_database import db +from core.db.session_factory import session_factory from models.account import TenantPluginAutoUpgradeStrategy class PluginAutoUpgradeService: @staticmethod def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None: - with sessionmaker(bind=db.engine).begin() as session: + with session_factory.create_session() as session: return session.scalar( select(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) @@ -24,7 +23,7 @@ class PluginAutoUpgradeService: exclude_plugins: list[str], include_plugins: list[str], ) -> bool: - with sessionmaker(bind=db.engine).begin() as session: + with session_factory.create_session() as session: exist_strategy = session.scalar( select(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) @@ -51,7 +50,7 @@ class PluginAutoUpgradeService: @staticmethod def exclude_plugin(tenant_id: str, plugin_id: str) -> bool: - with sessionmaker(bind=db.engine).begin() as session: + with session_factory.create_session() as session: exist_strategy = session.scalar( select(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) diff --git a/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py b/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py index bc2f1c6ecc..021bebceff 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py +++ b/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py @@ -6,23 +6,23 @@ MODULE = "services.plugin.plugin_auto_upgrade_service" def _patched_session(): - """Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager.""" + """Patch session_factory.create_session() to return a mock session as context manager.""" session = MagicMock() - mock_sessionmaker = MagicMock() - mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session) - mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) - patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker) - db_patcher = patch(f"{MODULE}.db") - return patcher, db_patcher, session + session.__enter__ = MagicMock(return_value=session) + session.__exit__ = MagicMock(return_value=False) + mock_factory = MagicMock() + mock_factory.create_session.return_value = session + patcher = patch(f"{MODULE}.session_factory", mock_factory) + return patcher, session class TestGetStrategy: def test_returns_strategy_when_found(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() strategy = MagicMock() session.scalar.return_value = strategy - with p1, p2: + with p1: from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService result = PluginAutoUpgradeService.get_strategy("t1") @@ -30,10 +30,10 @@ class TestGetStrategy: assert result is strategy def test_returns_none_when_not_found(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() session.scalar.return_value = None - with p1, p2: + with p1: from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService result = PluginAutoUpgradeService.get_strategy("t1") @@ -43,10 +43,10 @@ class TestGetStrategy: class TestChangeStrategy: def test_creates_new_strategy(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() session.scalar.return_value = None - with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: + with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: strat_cls.return_value = MagicMock() from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService @@ -63,11 +63,11 @@ class TestChangeStrategy: session.add.assert_called_once() def test_updates_existing_strategy(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() existing = MagicMock() session.scalar.return_value = existing - with p1, p2: + with p1: from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService result = PluginAutoUpgradeService.change_strategy( @@ -89,12 +89,11 @@ class TestChangeStrategy: class TestExcludePlugin: def test_creates_default_strategy_when_none_exists(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() session.scalar.return_value = None with ( p1, - p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls, patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs, @@ -110,13 +109,13 @@ class TestExcludePlugin: cs.assert_called_once() def test_appends_to_exclude_list_in_exclude_mode(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() existing = MagicMock() existing.upgrade_mode = "exclude" existing.exclude_plugins = ["p-existing"] session.scalar.return_value = existing - with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: + with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: strat_cls.UpgradeMode.EXCLUDE = "exclude" strat_cls.UpgradeMode.PARTIAL = "partial" strat_cls.UpgradeMode.ALL = "all" @@ -128,13 +127,13 @@ class TestExcludePlugin: assert existing.exclude_plugins == ["p-existing", "p-new"] def test_removes_from_include_list_in_partial_mode(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() existing = MagicMock() existing.upgrade_mode = "partial" existing.include_plugins = ["p1", "p2"] session.scalar.return_value = existing - with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: + with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: strat_cls.UpgradeMode.EXCLUDE = "exclude" strat_cls.UpgradeMode.PARTIAL = "partial" strat_cls.UpgradeMode.ALL = "all" @@ -146,12 +145,12 @@ class TestExcludePlugin: assert existing.include_plugins == ["p2"] def test_switches_to_exclude_mode_from_all(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() existing = MagicMock() existing.upgrade_mode = "all" session.scalar.return_value = existing - with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: + with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: strat_cls.UpgradeMode.EXCLUDE = "exclude" strat_cls.UpgradeMode.PARTIAL = "partial" strat_cls.UpgradeMode.ALL = "all" @@ -164,13 +163,13 @@ class TestExcludePlugin: assert existing.exclude_plugins == ["p1"] def test_no_duplicate_in_exclude_list(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() existing = MagicMock() existing.upgrade_mode = "exclude" existing.exclude_plugins = ["p1"] session.scalar.return_value = existing - with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: + with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: strat_cls.UpgradeMode.EXCLUDE = "exclude" strat_cls.UpgradeMode.PARTIAL = "partial" strat_cls.UpgradeMode.ALL = "all"