mirror of
https://github.com/langgenius/dify.git
synced 2026-04-10 12:00:26 -04:00
refactor(api): use sessionmaker in trigger provider service & dataset… (#34774)
This commit is contained in:
@@ -15,7 +15,7 @@ from graphon.model_runtime.entities.message_entities import PromptMessage, Promp
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import and_, func, literal, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
@@ -884,7 +884,7 @@ class DatasetRetrieval:
|
||||
self._send_trace_task(message_id, documents, timer)
|
||||
return
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# Collect all document_ids and batch fetch DatasetDocuments
|
||||
document_ids = {
|
||||
doc.metadata["document_id"]
|
||||
@@ -975,7 +975,6 @@ class DatasetRetrieval:
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||
synchronize_session=False,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
self._send_trace_task(message_id, documents, timer)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import desc, func
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
@@ -146,7 +146,7 @@ class TriggerProviderService:
|
||||
"""
|
||||
try:
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
# Use distributed lock to prevent race conditions
|
||||
lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}"
|
||||
with redis_client.lock(lock_key, timeout=20):
|
||||
@@ -205,7 +205,6 @@ class TriggerProviderService:
|
||||
subscription.id = subscription_id or str(uuid.uuid4())
|
||||
|
||||
session.add(subscription)
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
@@ -241,7 +240,7 @@ class TriggerProviderService:
|
||||
:param expires_at: Optional new expiration timestamp
|
||||
:return: Success response with updated subscription info
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
# Use distributed lock to prevent race conditions on the same subscription
|
||||
lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}"
|
||||
with redis_client.lock(lock_key, timeout=20):
|
||||
@@ -302,8 +301,6 @@ class TriggerProviderService:
|
||||
if expires_at is not None:
|
||||
subscription.expires_at = expires_at
|
||||
|
||||
session.commit()
|
||||
|
||||
# Clear subscription cache
|
||||
delete_cache_for_subscription(
|
||||
tenant_id=tenant_id,
|
||||
@@ -404,7 +401,7 @@ class TriggerProviderService:
|
||||
:param subscription_id: Subscription instance ID
|
||||
:return: New token info
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
|
||||
if not subscription:
|
||||
@@ -448,7 +445,6 @@ class TriggerProviderService:
|
||||
# Update credentials
|
||||
subscription.credentials = dict(encrypter.encrypt(dict(refreshed_credentials.credentials)))
|
||||
subscription.credential_expires_at = refreshed_credentials.expires_at
|
||||
session.commit()
|
||||
|
||||
# Clear cache
|
||||
cache.delete()
|
||||
@@ -478,7 +474,7 @@ class TriggerProviderService:
|
||||
"""
|
||||
now_ts: int = int(now if now is not None else _time.time())
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
subscription: TriggerSubscription | None = (
|
||||
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
)
|
||||
@@ -531,7 +527,6 @@ class TriggerProviderService:
|
||||
# Persist refreshed properties and expires_at
|
||||
subscription.properties = dict(properties_encrypter.encrypt(dict(refreshed.properties)))
|
||||
subscription.expires_at = int(refreshed.expires_at)
|
||||
session.commit()
|
||||
properties_cache.delete()
|
||||
|
||||
logger.info(
|
||||
@@ -639,7 +634,7 @@ class TriggerProviderService:
|
||||
tenant_id=tenant_id, provider_id=provider_id
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# Find existing custom client params
|
||||
custom_client = (
|
||||
session.query(TriggerOAuthTenantClient)
|
||||
@@ -683,8 +678,6 @@ class TriggerProviderService:
|
||||
if enabled is not None:
|
||||
custom_client.enabled = enabled
|
||||
|
||||
session.commit()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
@@ -733,13 +726,12 @@ class TriggerProviderService:
|
||||
:param provider_id: Provider identifier
|
||||
:return: Success response
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
session.query(TriggerOAuthTenantClient).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
).delete()
|
||||
session.commit()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@@ -4909,15 +4909,17 @@ class TestInternalHooksCoverage:
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = False
|
||||
|
||||
sessionmaker_ctx = MagicMock()
|
||||
sessionmaker_ctx.begin.return_value = session_ctx
|
||||
|
||||
with (
|
||||
patch("core.rag.retrieval.dataset_retrieval.db", SimpleNamespace(engine=Mock())),
|
||||
patch("core.rag.retrieval.dataset_retrieval.Session", return_value=session_ctx),
|
||||
patch("core.rag.retrieval.dataset_retrieval.sessionmaker", return_value=sessionmaker_ctx),
|
||||
patch.object(retrieval, "_send_trace_task") as mock_trace,
|
||||
):
|
||||
retrieval._on_retrieval_end(flask_app=app, documents=docs, message_id="m1", timer={"cost": 1})
|
||||
|
||||
query.update.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
mock_trace.assert_called_once()
|
||||
|
||||
def test_retriever_variants(self, retrieval: DatasetRetrieval) -> None:
|
||||
|
||||
@@ -63,6 +63,12 @@ def mock_session(mocker: MockerFixture) -> MagicMock:
|
||||
mock_session_cm.__enter__.return_value = mock_session_instance
|
||||
mock_session_cm.__exit__.return_value = False
|
||||
mocker.patch("services.trigger.trigger_provider_service.Session", return_value=mock_session_cm)
|
||||
mock_begin_cm = MagicMock()
|
||||
mock_begin_cm.__enter__.return_value = mock_session_instance
|
||||
mock_begin_cm.__exit__.return_value = False
|
||||
mock_sessionmaker_instance = MagicMock()
|
||||
mock_sessionmaker_instance.begin.return_value = mock_begin_cm
|
||||
mocker.patch("services.trigger.trigger_provider_service.sessionmaker", return_value=mock_sessionmaker_instance)
|
||||
return mock_session_instance
|
||||
|
||||
|
||||
@@ -212,7 +218,6 @@ def test_add_trigger_subscription_should_create_subscription_successfully_for_ap
|
||||
# Assert
|
||||
assert result["result"] == "success"
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_add_trigger_subscription_should_store_empty_credentials_for_unauthorized_type(
|
||||
@@ -406,7 +411,7 @@ def test_update_trigger_subscription_should_update_fields_and_clear_cache(
|
||||
assert subscription.credentials == {"api_key": "new-key"}
|
||||
assert subscription.credential_expires_at == 100
|
||||
assert subscription.expires_at == 200
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
mock_delete_cache.assert_called_once()
|
||||
|
||||
|
||||
@@ -593,7 +598,7 @@ def test_refresh_oauth_token_should_refresh_and_persist_new_credentials(
|
||||
assert result == {"result": "success", "expires_at": 12345}
|
||||
assert subscription.credentials == {"access_token": "new"}
|
||||
assert subscription.credential_expires_at == 12345
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
cache.delete.assert_called_once()
|
||||
|
||||
|
||||
@@ -664,7 +669,7 @@ def test_refresh_subscription_should_refresh_and_persist_properties(
|
||||
assert result == {"result": "success", "expires_at": 999}
|
||||
assert subscription.properties == {"p": "new-enc"}
|
||||
assert subscription.expires_at == 999
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
prop_cache.delete.assert_called_once()
|
||||
|
||||
|
||||
@@ -838,7 +843,6 @@ def test_save_custom_oauth_client_params_should_create_record_and_clear_params_w
|
||||
assert fake_model.encrypted_oauth_params == "{}"
|
||||
assert fake_model.enabled is True
|
||||
mock_session.add.assert_called_once_with(fake_model)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_cache(
|
||||
@@ -870,7 +874,6 @@ def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_c
|
||||
assert result == {"result": "success"}
|
||||
assert json.loads(custom_client.encrypted_oauth_params) == {"client_id": "new-id"}
|
||||
cache.delete.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_custom_oauth_client_params_should_return_empty_when_record_missing(
|
||||
@@ -921,7 +924,6 @@ def test_delete_custom_oauth_client_params_should_delete_record_and_commit(
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exists", [True, False])
|
||||
|
||||
Reference in New Issue
Block a user