From 4c70bfa8b8c5eb3d6714d2a3b98adea1dbe900bb Mon Sep 17 00:00:00 2001 From: carlos4s <71615127+carlos4s@users.noreply.github.com> Date: Wed, 8 Apr 2026 18:13:38 -0500 Subject: [PATCH] =?UTF-8?q?refactor(api):=20use=20sessionmaker=20in=20trig?= =?UTF-8?q?ger=20provider=20service=20&=20dataset=E2=80=A6=20(#34774)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/core/rag/retrieval/dataset_retrieval.py | 5 ++--- .../trigger/trigger_provider_service.py | 22 ++++++------------- .../rag/retrieval/test_dataset_retrieval.py | 6 +++-- .../services/test_trigger_provider_service.py | 16 ++++++++------ 4 files changed, 22 insertions(+), 27 deletions(-) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 4e9b53b83e..0f3351fd68 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -15,7 +15,7 @@ from graphon.model_runtime.entities.message_entities import PromptMessage, Promp from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import and_, func, literal, or_, select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from core.app.app_config.entities import ( DatasetEntity, @@ -884,7 +884,7 @@ class DatasetRetrieval: self._send_trace_task(message_id, documents, timer) return - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: # Collect all document_ids and batch fetch DatasetDocuments document_ids = { doc.metadata["document_id"] @@ -975,7 +975,6 @@ class DatasetRetrieval: {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False, ) - session.commit() self._send_trace_task(message_id, documents, timer) diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 008d8bdb8a..ae74f7a8cd 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -6,7 +6,7 @@ from collections.abc import Mapping from typing import Any from sqlalchemy import desc, func -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from constants import HIDDEN_VALUE, UNKNOWN_VALUE @@ -146,7 +146,7 @@ class TriggerProviderService: """ try: provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: # Use distributed lock to prevent race conditions lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}" with redis_client.lock(lock_key, timeout=20): @@ -205,7 +205,6 @@ class TriggerProviderService: subscription.id = subscription_id or str(uuid.uuid4()) session.add(subscription) - session.commit() return { "result": "success", @@ -241,7 +240,7 @@ class TriggerProviderService: :param expires_at: Optional new expiration timestamp :return: Success response with updated subscription info """ - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: # Use distributed lock to prevent race conditions on the same subscription lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}" with redis_client.lock(lock_key, timeout=20): @@ -302,8 +301,6 @@ class TriggerProviderService: if expires_at is not None: subscription.expires_at = expires_at - session.commit() - # Clear subscription cache delete_cache_for_subscription( tenant_id=tenant_id, @@ -404,7 +401,7 @@ class TriggerProviderService: :param subscription_id: Subscription instance ID :return: New token info """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() if not subscription: @@ -448,7 +445,6 @@ class TriggerProviderService: # Update credentials subscription.credentials = dict(encrypter.encrypt(dict(refreshed_credentials.credentials))) subscription.credential_expires_at = refreshed_credentials.expires_at - session.commit() # Clear cache cache.delete() @@ -478,7 +474,7 @@ class TriggerProviderService: """ now_ts: int = int(now if now is not None else _time.time()) - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: subscription: TriggerSubscription | None = ( session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() ) @@ -531,7 +527,6 @@ class TriggerProviderService: # Persist refreshed properties and expires_at subscription.properties = dict(properties_encrypter.encrypt(dict(refreshed.properties))) subscription.expires_at = int(refreshed.expires_at) - session.commit() properties_cache.delete() logger.info( @@ -639,7 +634,7 @@ class TriggerProviderService: tenant_id=tenant_id, provider_id=provider_id ) - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: # Find existing custom client params custom_client = ( session.query(TriggerOAuthTenantClient) @@ -683,8 +678,6 @@ class TriggerProviderService: if enabled is not None: custom_client.enabled = enabled - session.commit() - return {"result": "success"} @classmethod @@ -733,13 +726,12 @@ class TriggerProviderService: :param provider_id: Provider identifier :return: Success response """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: session.query(TriggerOAuthTenantClient).filter_by( tenant_id=tenant_id, provider=provider_id.provider_name, plugin_id=provider_id.plugin_id, ).delete() - session.commit() return {"result": "success"} diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index 40d138df90..b98fec3854 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -4909,15 +4909,17 @@ class TestInternalHooksCoverage: session_ctx.__enter__.return_value = session session_ctx.__exit__.return_value = False + sessionmaker_ctx = MagicMock() + sessionmaker_ctx.begin.return_value = session_ctx + with ( patch("core.rag.retrieval.dataset_retrieval.db", SimpleNamespace(engine=Mock())), - patch("core.rag.retrieval.dataset_retrieval.Session", return_value=session_ctx), + patch("core.rag.retrieval.dataset_retrieval.sessionmaker", return_value=sessionmaker_ctx), patch.object(retrieval, "_send_trace_task") as mock_trace, ): retrieval._on_retrieval_end(flask_app=app, documents=docs, message_id="m1", timer={"cost": 1}) query.update.assert_called_once() - session.commit.assert_called_once() mock_trace.assert_called_once() def test_retriever_variants(self, retrieval: DatasetRetrieval) -> None: diff --git a/api/tests/unit_tests/services/test_trigger_provider_service.py b/api/tests/unit_tests/services/test_trigger_provider_service.py index 81a3b181fd..350ff718c1 100644 --- a/api/tests/unit_tests/services/test_trigger_provider_service.py +++ b/api/tests/unit_tests/services/test_trigger_provider_service.py @@ -63,6 +63,12 @@ def mock_session(mocker: MockerFixture) -> MagicMock: mock_session_cm.__enter__.return_value = mock_session_instance mock_session_cm.__exit__.return_value = False mocker.patch("services.trigger.trigger_provider_service.Session", return_value=mock_session_cm) + mock_begin_cm = MagicMock() + mock_begin_cm.__enter__.return_value = mock_session_instance + mock_begin_cm.__exit__.return_value = False + mock_sessionmaker_instance = MagicMock() + mock_sessionmaker_instance.begin.return_value = mock_begin_cm + mocker.patch("services.trigger.trigger_provider_service.sessionmaker", return_value=mock_sessionmaker_instance) return mock_session_instance @@ -212,7 +218,6 @@ def test_add_trigger_subscription_should_create_subscription_successfully_for_ap # Assert assert result["result"] == "success" mock_session.add.assert_called_once() - mock_session.commit.assert_called_once() def test_add_trigger_subscription_should_store_empty_credentials_for_unauthorized_type( @@ -406,7 +411,7 @@ def test_update_trigger_subscription_should_update_fields_and_clear_cache( assert subscription.credentials == {"api_key": "new-key"} assert subscription.credential_expires_at == 100 assert subscription.expires_at == 200 - mock_session.commit.assert_called_once() + mock_delete_cache.assert_called_once() @@ -593,7 +598,7 @@ def test_refresh_oauth_token_should_refresh_and_persist_new_credentials( assert result == {"result": "success", "expires_at": 12345} assert subscription.credentials == {"access_token": "new"} assert subscription.credential_expires_at == 12345 - mock_session.commit.assert_called_once() + cache.delete.assert_called_once() @@ -664,7 +669,7 @@ def test_refresh_subscription_should_refresh_and_persist_properties( assert result == {"result": "success", "expires_at": 999} assert subscription.properties == {"p": "new-enc"} assert subscription.expires_at == 999 - mock_session.commit.assert_called_once() + prop_cache.delete.assert_called_once() @@ -838,7 +843,6 @@ def test_save_custom_oauth_client_params_should_create_record_and_clear_params_w assert fake_model.encrypted_oauth_params == "{}" assert fake_model.enabled is True mock_session.add.assert_called_once_with(fake_model) - mock_session.commit.assert_called_once() def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_cache( @@ -870,7 +874,6 @@ def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_c assert result == {"result": "success"} assert json.loads(custom_client.encrypted_oauth_params) == {"client_id": "new-id"} cache.delete.assert_called_once() - mock_session.commit.assert_called_once() def test_get_custom_oauth_client_params_should_return_empty_when_record_missing( @@ -921,7 +924,6 @@ def test_delete_custom_oauth_client_params_should_delete_record_and_commit( # Assert assert result == {"result": "success"} - mock_session.commit.assert_called_once() @pytest.mark.parametrize("exists", [True, False])