refactor(api): use sessionmaker in trigger provider service & dataset… (#34774)

This commit is contained in:
carlos4s
2026-04-08 18:13:38 -05:00
committed by GitHub
parent 3a4756449a
commit 4c70bfa8b8
4 changed files with 22 additions and 27 deletions

View File

@@ -15,7 +15,7 @@ from graphon.model_runtime.entities.message_entities import PromptMessage, Promp
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from sqlalchemy import and_, func, literal, or_, select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from core.app.app_config.entities import (
DatasetEntity,
@@ -884,7 +884,7 @@ class DatasetRetrieval:
self._send_trace_task(message_id, documents, timer)
return
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
# Collect all document_ids and batch fetch DatasetDocuments
document_ids = {
doc.metadata["document_id"]
@@ -975,7 +975,6 @@ class DatasetRetrieval:
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
session.commit()
self._send_trace_task(message_id, documents, timer)

View File

@@ -6,7 +6,7 @@ from collections.abc import Mapping
from typing import Any
from sqlalchemy import desc, func
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
@@ -146,7 +146,7 @@ class TriggerProviderService:
"""
try:
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
# Use distributed lock to prevent race conditions
lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}"
with redis_client.lock(lock_key, timeout=20):
@@ -205,7 +205,6 @@ class TriggerProviderService:
subscription.id = subscription_id or str(uuid.uuid4())
session.add(subscription)
session.commit()
return {
"result": "success",
@@ -241,7 +240,7 @@ class TriggerProviderService:
:param expires_at: Optional new expiration timestamp
:return: Success response with updated subscription info
"""
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
# Use distributed lock to prevent race conditions on the same subscription
lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}"
with redis_client.lock(lock_key, timeout=20):
@@ -302,8 +301,6 @@ class TriggerProviderService:
if expires_at is not None:
subscription.expires_at = expires_at
session.commit()
# Clear subscription cache
delete_cache_for_subscription(
tenant_id=tenant_id,
@@ -404,7 +401,7 @@ class TriggerProviderService:
:param subscription_id: Subscription instance ID
:return: New token info
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
if not subscription:
@@ -448,7 +445,6 @@ class TriggerProviderService:
# Update credentials
subscription.credentials = dict(encrypter.encrypt(dict(refreshed_credentials.credentials)))
subscription.credential_expires_at = refreshed_credentials.expires_at
session.commit()
# Clear cache
cache.delete()
@@ -478,7 +474,7 @@ class TriggerProviderService:
"""
now_ts: int = int(now if now is not None else _time.time())
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
)
@@ -531,7 +527,6 @@ class TriggerProviderService:
# Persist refreshed properties and expires_at
subscription.properties = dict(properties_encrypter.encrypt(dict(refreshed.properties)))
subscription.expires_at = int(refreshed.expires_at)
session.commit()
properties_cache.delete()
logger.info(
@@ -639,7 +634,7 @@ class TriggerProviderService:
tenant_id=tenant_id, provider_id=provider_id
)
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
# Find existing custom client params
custom_client = (
session.query(TriggerOAuthTenantClient)
@@ -683,8 +678,6 @@ class TriggerProviderService:
if enabled is not None:
custom_client.enabled = enabled
session.commit()
return {"result": "success"}
@classmethod
@@ -733,13 +726,12 @@ class TriggerProviderService:
:param provider_id: Provider identifier
:return: Success response
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
session.query(TriggerOAuthTenantClient).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
).delete()
session.commit()
return {"result": "success"}

View File

@@ -4909,15 +4909,17 @@ class TestInternalHooksCoverage:
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = False
sessionmaker_ctx = MagicMock()
sessionmaker_ctx.begin.return_value = session_ctx
with (
patch("core.rag.retrieval.dataset_retrieval.db", SimpleNamespace(engine=Mock())),
patch("core.rag.retrieval.dataset_retrieval.Session", return_value=session_ctx),
patch("core.rag.retrieval.dataset_retrieval.sessionmaker", return_value=sessionmaker_ctx),
patch.object(retrieval, "_send_trace_task") as mock_trace,
):
retrieval._on_retrieval_end(flask_app=app, documents=docs, message_id="m1", timer={"cost": 1})
query.update.assert_called_once()
session.commit.assert_called_once()
mock_trace.assert_called_once()
def test_retriever_variants(self, retrieval: DatasetRetrieval) -> None:

View File

@@ -63,6 +63,12 @@ def mock_session(mocker: MockerFixture) -> MagicMock:
mock_session_cm.__enter__.return_value = mock_session_instance
mock_session_cm.__exit__.return_value = False
mocker.patch("services.trigger.trigger_provider_service.Session", return_value=mock_session_cm)
mock_begin_cm = MagicMock()
mock_begin_cm.__enter__.return_value = mock_session_instance
mock_begin_cm.__exit__.return_value = False
mock_sessionmaker_instance = MagicMock()
mock_sessionmaker_instance.begin.return_value = mock_begin_cm
mocker.patch("services.trigger.trigger_provider_service.sessionmaker", return_value=mock_sessionmaker_instance)
return mock_session_instance
@@ -212,7 +218,6 @@ def test_add_trigger_subscription_should_create_subscription_successfully_for_ap
# Assert
assert result["result"] == "success"
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
def test_add_trigger_subscription_should_store_empty_credentials_for_unauthorized_type(
@@ -406,7 +411,7 @@ def test_update_trigger_subscription_should_update_fields_and_clear_cache(
assert subscription.credentials == {"api_key": "new-key"}
assert subscription.credential_expires_at == 100
assert subscription.expires_at == 200
mock_session.commit.assert_called_once()
mock_delete_cache.assert_called_once()
@@ -593,7 +598,7 @@ def test_refresh_oauth_token_should_refresh_and_persist_new_credentials(
assert result == {"result": "success", "expires_at": 12345}
assert subscription.credentials == {"access_token": "new"}
assert subscription.credential_expires_at == 12345
mock_session.commit.assert_called_once()
cache.delete.assert_called_once()
@@ -664,7 +669,7 @@ def test_refresh_subscription_should_refresh_and_persist_properties(
assert result == {"result": "success", "expires_at": 999}
assert subscription.properties == {"p": "new-enc"}
assert subscription.expires_at == 999
mock_session.commit.assert_called_once()
prop_cache.delete.assert_called_once()
@@ -838,7 +843,6 @@ def test_save_custom_oauth_client_params_should_create_record_and_clear_params_w
assert fake_model.encrypted_oauth_params == "{}"
assert fake_model.enabled is True
mock_session.add.assert_called_once_with(fake_model)
mock_session.commit.assert_called_once()
def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_cache(
@@ -870,7 +874,6 @@ def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_c
assert result == {"result": "success"}
assert json.loads(custom_client.encrypted_oauth_params) == {"client_id": "new-id"}
cache.delete.assert_called_once()
mock_session.commit.assert_called_once()
def test_get_custom_oauth_client_params_should_return_empty_when_record_missing(
@@ -921,7 +924,6 @@ def test_delete_custom_oauth_client_params_should_delete_record_and_commit(
# Assert
assert result == {"result": "success"}
mock_session.commit.assert_called_once()
@pytest.mark.parametrize("exists", [True, False])