diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index d6e3ebfbcd..ed0d490aad 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -6,7 +6,7 @@ from flask import current_app, request from flask_login import user_logged_in from pydantic import BaseModel from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from extensions.ext_database import db from libs.login import current_user @@ -33,7 +33,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID try: - with Session(db.engine) as session: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: user_model = None if is_anonymous: @@ -56,7 +56,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: session_id=user_id, ) session.add(user_model) - session.commit() + session.flush() session.refresh(user_model) except Exception: diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index edbf011656..8c9a3eb5e9 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -3,7 +3,7 @@ from typing import Any, Literal from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, NotFound import services @@ -116,7 +116,7 @@ class ConversationApi(Resource): last_id = str(query_args.last_id) if query_args.last_id else None try: - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: pagination = ConversationService.pagination_by_last_id( session=session, app_model=app_model, diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 1759075139..d7992a2a3a 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -8,7 +8,7 @@ from graphon.enums import WorkflowExecutionStatus from graphon.graph_engine.manager import GraphEngineManager from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from controllers.common.schema import register_schema_models @@ -314,7 +314,7 @@ class WorkflowAppLogApi(Resource): # get paginate workflow app logs workflow_app_service = WorkflowAppService() - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs( session=session, app_model=app_model, diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py index eac57fe4b7..957d7fbd9b 100644 --- a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py +++ b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py @@ -41,15 +41,15 @@ class TestGetUser: """Test get_user function""" @patch("controllers.inner_api.plugin.wraps.EndUser") - @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") - def test_should_return_existing_user_by_id(self, mock_db, mock_session_class, mock_enduser_class, app: Flask): + def test_should_return_existing_user_by_id(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask): """Test returning existing user when found by ID""" # Arrange mock_user = MagicMock() mock_user.id = "user123" mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session mock_session.get.return_value = mock_user # Act @@ -61,17 +61,17 @@ class TestGetUser: mock_session.get.assert_called_once() @patch("controllers.inner_api.plugin.wraps.EndUser") - @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") def test_should_return_existing_anonymous_user_by_session_id( - self, mock_db, mock_session_class, mock_enduser_class, app: Flask + self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask ): """Test returning existing anonymous user by session_id""" # Arrange mock_user = MagicMock() mock_user.session_id = "anonymous_session" mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session # non-anonymous path uses session.get(); anonymous uses session.scalar() mock_session.get.return_value = mock_user @@ -83,13 +83,13 @@ class TestGetUser: assert result == mock_user @patch("controllers.inner_api.plugin.wraps.EndUser") - @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") - def test_should_create_new_user_when_not_found(self, mock_db, mock_session_class, mock_enduser_class, app: Flask): + def test_should_create_new_user_when_not_found(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask): """Test creating new user when not found in database""" # Arrange mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session mock_session.get.return_value = None mock_new_user = MagicMock() mock_enduser_class.return_value = mock_new_user @@ -101,21 +101,20 @@ class TestGetUser: # Assert assert result == mock_new_user mock_session.add.assert_called_once() - mock_session.commit.assert_called_once() mock_session.refresh.assert_called_once() @patch("controllers.inner_api.plugin.wraps.select") @patch("controllers.inner_api.plugin.wraps.EndUser") - @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") def test_should_use_default_session_id_when_user_id_none( - self, mock_db, mock_session_class, mock_enduser_class, mock_select, app: Flask + self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask ): """Test using default session ID when user_id is None""" # Arrange mock_user = MagicMock() mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session # When user_id is None, is_anonymous=True, so session.scalar() is used mock_session.scalar.return_value = mock_user @@ -127,15 +126,13 @@ class TestGetUser: assert result == mock_user @patch("controllers.inner_api.plugin.wraps.EndUser") - @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") - def test_should_raise_error_on_database_exception( - self, mock_db, mock_session_class, mock_enduser_class, app: Flask - ): + def test_should_raise_error_on_database_exception(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask): """Test raising ValueError when database operation fails""" # Arrange mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session mock_session.get.side_effect = Exception("Database error") # Act & Assert diff --git a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py index 81c45dcdb7..dbd06677d8 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py @@ -433,13 +433,20 @@ class TestConversationApiController: handler(api, app_model=app_model, end_user=end_user) def test_list_last_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: - class _SessionStub: + class _BeginStub: def __enter__(self): return SimpleNamespace() def __exit__(self, exc_type, exc, tb): return False + class _SessionMakerStub: + def __init__(self, *args, **kwargs): + pass + + def begin(self): + return _BeginStub() + monkeypatch.setattr( ConversationService, "pagination_by_last_id", @@ -447,7 +454,7 @@ class TestConversationApiController: ) conversation_module = sys.modules["controllers.service_api.app.conversation"] monkeypatch.setattr(conversation_module, "db", SimpleNamespace(engine=object())) - monkeypatch.setattr(conversation_module, "Session", lambda *_args, **_kwargs: _SessionStub()) + monkeypatch.setattr(conversation_module, "sessionmaker", _SessionMakerStub) api = ConversationApi() handler = _unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py index b1f036c6f3..cfa21bf2dd 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py @@ -470,16 +470,23 @@ class TestWorkflowTaskStopApi: class TestWorkflowAppLogApi: def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: - class _SessionStub: + class _BeginStub: def __enter__(self): return SimpleNamespace() def __exit__(self, exc_type, exc, tb): return False + class _SessionMakerStub: + def __init__(self, *args, **kwargs): + pass + + def begin(self): + return _BeginStub() + workflow_module = sys.modules["controllers.service_api.app.workflow"] monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) - monkeypatch.setattr(workflow_module, "Session", lambda *_args, **_kwargs: _SessionStub()) + monkeypatch.setattr(workflow_module, "sessionmaker", _SessionMakerStub) monkeypatch.setattr( WorkflowAppService, "get_paginate_workflow_app_logs", @@ -635,11 +642,14 @@ class TestWorkflowAppLogApiGet: mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination mock_wf_svc_cls.return_value = mock_svc_instance - # Mock Session context manager + # Mock sessionmaker(...).begin() context manager mock_session = Mock() mock_db.engine = Mock() - mock_session.__enter__ = Mock(return_value=mock_session) - mock_session.__exit__ = Mock(return_value=False) + mock_begin = Mock() + mock_begin.__enter__ = Mock(return_value=mock_session) + mock_begin.__exit__ = Mock(return_value=False) + mock_session_factory = Mock() + mock_session_factory.begin.return_value = mock_begin from controllers.service_api.app.workflow import WorkflowAppLogApi @@ -647,7 +657,7 @@ class TestWorkflowAppLogApiGet: "/workflows/logs?page=1&limit=20", method="GET", ): - with patch("controllers.service_api.app.workflow.Session", return_value=mock_session): + with patch("controllers.service_api.app.workflow.sessionmaker", return_value=mock_session_factory): api = WorkflowAppLogApi() result = _unwrap(api.get)(api, app_model=mock_workflow_app)