mirror of
https://github.com/langgenius/dify.git
synced 2026-04-13 03:00:16 -04:00
refactor(api): migrate tools, account, workflow and plugin services to SQLAlchemy 2.0 (#34966)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -682,7 +682,7 @@ class ToolManager:
|
||||
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
|
||||
return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
|
||||
return list(session.scalars(select(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids))))
|
||||
|
||||
@classmethod
|
||||
def list_providers_from_api(
|
||||
|
||||
@@ -809,11 +809,11 @@ class AccountService:
|
||||
rest of the system gradually normalizes new inputs.
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none()
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
|
||||
return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
|
||||
return session.execute(select(Account).where(Account.email == email.lower())).scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
|
||||
|
||||
@@ -13,6 +13,7 @@ import sqlalchemy as sa
|
||||
import tqdm
|
||||
from flask import Flask, current_app
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
@@ -66,7 +67,7 @@ class PluginMigration:
|
||||
current_time = started_at
|
||||
|
||||
with Session(db.engine) as session:
|
||||
total_tenant_count = session.query(Tenant.id).count()
|
||||
total_tenant_count = session.scalar(select(func.count(Tenant.id))) or 0
|
||||
|
||||
click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
|
||||
|
||||
@@ -123,9 +124,12 @@ class PluginMigration:
|
||||
tenant_count = 0
|
||||
for test_interval in test_intervals:
|
||||
tenant_count = (
|
||||
session.query(Tenant.id)
|
||||
.where(Tenant.created_at.between(current_time, current_time + test_interval))
|
||||
.count()
|
||||
session.scalar(
|
||||
select(func.count(Tenant.id)).where(
|
||||
Tenant.created_at.between(current_time, current_time + test_interval)
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
if tenant_count <= 100:
|
||||
interval = test_interval
|
||||
@@ -147,8 +151,8 @@ class PluginMigration:
|
||||
|
||||
batch_end = min(current_time + interval, ended_at)
|
||||
|
||||
rs = (
|
||||
session.query(Tenant.id)
|
||||
rs = session.execute(
|
||||
select(Tenant.id)
|
||||
.where(Tenant.created_at.between(current_time, batch_end))
|
||||
.order_by(Tenant.created_at)
|
||||
)
|
||||
@@ -235,7 +239,7 @@ class PluginMigration:
|
||||
Extract tool tables.
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
rs = session.query(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id).all()
|
||||
rs = session.scalars(select(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id)).all()
|
||||
result = []
|
||||
for row in rs:
|
||||
result.append(ToolProviderID(row.provider).plugin_id)
|
||||
@@ -249,7 +253,7 @@ class PluginMigration:
|
||||
"""
|
||||
|
||||
with Session(db.engine) as session:
|
||||
rs = session.query(Workflow).where(Workflow.tenant_id == tenant_id).all()
|
||||
rs = session.scalars(select(Workflow).where(Workflow.tenant_id == tenant_id)).all()
|
||||
result = []
|
||||
for row in rs:
|
||||
graph = row.graph_dict
|
||||
@@ -272,7 +276,7 @@ class PluginMigration:
|
||||
Extract app tables.
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
apps = session.query(App).where(App.tenant_id == tenant_id).all()
|
||||
apps = session.scalars(select(App).where(App.tenant_id == tenant_id)).all()
|
||||
if not apps:
|
||||
return []
|
||||
|
||||
@@ -280,7 +284,7 @@ class PluginMigration:
|
||||
app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT
|
||||
]
|
||||
|
||||
rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
|
||||
rs = session.scalars(select(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids))).all()
|
||||
result = []
|
||||
for row in rs:
|
||||
agent_config = row.agent_mode_dict
|
||||
|
||||
@@ -283,7 +283,9 @@ class RagPipelineDslService:
|
||||
):
|
||||
raise ValueError("Chunk structure is not compatible with the published pipeline")
|
||||
if not dataset:
|
||||
datasets = self._session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all()
|
||||
datasets = self._session.scalars(
|
||||
select(Dataset).where(Dataset.tenant_id == account.current_tenant_id)
|
||||
).all()
|
||||
names = [dataset.name for dataset in datasets]
|
||||
generate_name = generate_incremental_name(names, name)
|
||||
dataset = Dataset(
|
||||
@@ -303,8 +305,8 @@ class RagPipelineDslService:
|
||||
chunk_structure=knowledge_configuration.chunk_structure,
|
||||
)
|
||||
if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
dataset_collection_binding = (
|
||||
self._session.query(DatasetCollectionBinding)
|
||||
dataset_collection_binding = self._session.scalar(
|
||||
select(DatasetCollectionBinding)
|
||||
.where(
|
||||
DatasetCollectionBinding.provider_name
|
||||
== knowledge_configuration.embedding_model_provider,
|
||||
@@ -312,7 +314,7 @@ class RagPipelineDslService:
|
||||
DatasetCollectionBinding.type == CollectionBindingType.DATASET,
|
||||
)
|
||||
.order_by(DatasetCollectionBinding.created_at)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not dataset_collection_binding:
|
||||
@@ -440,8 +442,8 @@ class RagPipelineDslService:
|
||||
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
|
||||
dataset.chunk_structure = knowledge_configuration.chunk_structure
|
||||
if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
dataset_collection_binding = (
|
||||
self._session.query(DatasetCollectionBinding)
|
||||
dataset_collection_binding = self._session.scalar(
|
||||
select(DatasetCollectionBinding)
|
||||
.where(
|
||||
DatasetCollectionBinding.provider_name
|
||||
== knowledge_configuration.embedding_model_provider,
|
||||
@@ -449,7 +451,7 @@ class RagPipelineDslService:
|
||||
DatasetCollectionBinding.type == CollectionBindingType.DATASET,
|
||||
)
|
||||
.order_by(DatasetCollectionBinding.created_at)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not dataset_collection_binding:
|
||||
@@ -591,14 +593,14 @@ class RagPipelineDslService:
|
||||
IMPORT_INFO_REDIS_EXPIRY,
|
||||
CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(),
|
||||
)
|
||||
workflow = (
|
||||
self._session.query(Workflow)
|
||||
workflow = self._session.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.version == "draft",
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# create draft workflow if not found
|
||||
@@ -665,14 +667,12 @@ class RagPipelineDslService:
|
||||
:param pipeline: Pipeline instance
|
||||
"""
|
||||
|
||||
workflow = (
|
||||
self._session.query(Workflow)
|
||||
.where(
|
||||
workflow = self._session.scalar(
|
||||
select(Workflow).where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.version == "draft",
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not workflow:
|
||||
raise ValueError("Missing draft workflow configuration, please check.")
|
||||
@@ -904,15 +904,16 @@ class RagPipelineDslService:
|
||||
):
|
||||
if rag_pipeline_dataset_create_entity.name:
|
||||
# check if dataset name already exists
|
||||
if (
|
||||
self._session.query(Dataset)
|
||||
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
|
||||
.first()
|
||||
if self._session.scalar(
|
||||
select(Dataset).where(
|
||||
Dataset.name == rag_pipeline_dataset_create_entity.name,
|
||||
Dataset.tenant_id == tenant_id,
|
||||
)
|
||||
):
|
||||
raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.")
|
||||
else:
|
||||
# generate a random name as Untitled 1 2 3 ...
|
||||
datasets = self._session.query(Dataset).filter_by(tenant_id=tenant_id).all()
|
||||
datasets = self._session.scalars(select(Dataset).where(Dataset.tenant_id == tenant_id)).all()
|
||||
names = [dataset.name for dataset in datasets]
|
||||
rag_pipeline_dataset_create_entity.name = generate_incremental_name(
|
||||
names,
|
||||
|
||||
@@ -19,7 +19,7 @@ from graphon.variables.segments import (
|
||||
)
|
||||
from graphon.variables.types import SegmentType
|
||||
from graphon.variables.utils import dumps_with_segments
|
||||
from sqlalchemy import Engine, orm, select
|
||||
from sqlalchemy import Engine, delete, orm, select
|
||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
@@ -222,11 +222,10 @@ class WorkflowDraftVariableService:
|
||||
)
|
||||
|
||||
def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None:
|
||||
return (
|
||||
self._session.query(WorkflowDraftVariable)
|
||||
return self._session.scalar(
|
||||
select(WorkflowDraftVariable)
|
||||
.options(orm.selectinload(WorkflowDraftVariable.variable_file))
|
||||
.where(WorkflowDraftVariable.id == variable_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_draft_variables_by_selectors(
|
||||
@@ -254,20 +253,21 @@ class WorkflowDraftVariableService:
|
||||
# Alternatively, a `SELECT` statement could be constructed for each selector and
|
||||
# combined using `UNION` to fetch all rows.
|
||||
# Benchmarking indicates that both approaches yield comparable performance.
|
||||
query = (
|
||||
self._session.query(WorkflowDraftVariable)
|
||||
.options(
|
||||
orm.selectinload(WorkflowDraftVariable.variable_file).selectinload(
|
||||
WorkflowDraftVariableFile.upload_file
|
||||
return list(
|
||||
self._session.scalars(
|
||||
select(WorkflowDraftVariable)
|
||||
.options(
|
||||
orm.selectinload(WorkflowDraftVariable.variable_file).selectinload(
|
||||
WorkflowDraftVariableFile.upload_file
|
||||
)
|
||||
)
|
||||
.where(
|
||||
WorkflowDraftVariable.app_id == app_id,
|
||||
WorkflowDraftVariable.user_id == user_id,
|
||||
or_(*ors),
|
||||
)
|
||||
)
|
||||
.where(
|
||||
WorkflowDraftVariable.app_id == app_id,
|
||||
WorkflowDraftVariable.user_id == user_id,
|
||||
or_(*ors),
|
||||
)
|
||||
)
|
||||
return query.all()
|
||||
|
||||
def list_variables_without_values(
|
||||
self, app_id: str, page: int, limit: int, user_id: str
|
||||
@@ -277,18 +277,21 @@ class WorkflowDraftVariableService:
|
||||
WorkflowDraftVariable.user_id == user_id,
|
||||
]
|
||||
total = None
|
||||
query = self._session.query(WorkflowDraftVariable).where(*criteria)
|
||||
base_stmt = select(WorkflowDraftVariable).where(*criteria)
|
||||
if page == 1:
|
||||
total = query.count()
|
||||
variables = (
|
||||
# Do not load the `value` field
|
||||
query.options(
|
||||
orm.defer(WorkflowDraftVariable.value, raiseload=True),
|
||||
from sqlalchemy import func as sa_func
|
||||
|
||||
total = self._session.scalar(select(sa_func.count()).select_from(base_stmt.subquery()))
|
||||
variables = list(
|
||||
self._session.scalars(
|
||||
# Do not load the `value` field
|
||||
base_stmt.options(
|
||||
orm.defer(WorkflowDraftVariable.value, raiseload=True),
|
||||
)
|
||||
.order_by(WorkflowDraftVariable.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset((page - 1) * limit)
|
||||
)
|
||||
.order_by(WorkflowDraftVariable.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset((page - 1) * limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
return WorkflowDraftVariableList(variables=variables, total=total)
|
||||
@@ -299,11 +302,13 @@ class WorkflowDraftVariableService:
|
||||
WorkflowDraftVariable.node_id == node_id,
|
||||
WorkflowDraftVariable.user_id == user_id,
|
||||
]
|
||||
query = self._session.query(WorkflowDraftVariable).where(*criteria)
|
||||
variables = (
|
||||
query.options(orm.selectinload(WorkflowDraftVariable.variable_file))
|
||||
.order_by(WorkflowDraftVariable.created_at.desc())
|
||||
.all()
|
||||
variables = list(
|
||||
self._session.scalars(
|
||||
select(WorkflowDraftVariable)
|
||||
.options(orm.selectinload(WorkflowDraftVariable.variable_file))
|
||||
.where(*criteria)
|
||||
.order_by(WorkflowDraftVariable.created_at.desc())
|
||||
)
|
||||
)
|
||||
return WorkflowDraftVariableList(variables=variables)
|
||||
|
||||
@@ -326,8 +331,8 @@ class WorkflowDraftVariableService:
|
||||
return self._get_variable(app_id, node_id, name, user_id=user_id)
|
||||
|
||||
def _get_variable(self, app_id: str, node_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None:
|
||||
return (
|
||||
self._session.query(WorkflowDraftVariable)
|
||||
return self._session.scalar(
|
||||
select(WorkflowDraftVariable)
|
||||
.options(orm.selectinload(WorkflowDraftVariable.variable_file))
|
||||
.where(
|
||||
WorkflowDraftVariable.app_id == app_id,
|
||||
@@ -335,7 +340,6 @@ class WorkflowDraftVariableService:
|
||||
WorkflowDraftVariable.name == name,
|
||||
WorkflowDraftVariable.user_id == user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def update_variable(
|
||||
@@ -488,20 +492,20 @@ class WorkflowDraftVariableService:
|
||||
self._session.delete(variable)
|
||||
|
||||
def delete_user_workflow_variables(self, app_id: str, user_id: str):
|
||||
(
|
||||
self._session.query(WorkflowDraftVariable)
|
||||
self._session.execute(
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(
|
||||
WorkflowDraftVariable.app_id == app_id,
|
||||
WorkflowDraftVariable.user_id == user_id,
|
||||
)
|
||||
.delete(synchronize_session=False)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
def delete_app_workflow_variables(self, app_id: str):
|
||||
(
|
||||
self._session.query(WorkflowDraftVariable)
|
||||
self._session.execute(
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(WorkflowDraftVariable.app_id == app_id)
|
||||
.delete(synchronize_session=False)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
def delete_workflow_draft_variable_file(self, deletions: list[DraftVarFileDeletion]):
|
||||
@@ -540,14 +544,14 @@ class WorkflowDraftVariableService:
|
||||
return self._delete_node_variables(app_id, node_id, user_id=user_id)
|
||||
|
||||
def _delete_node_variables(self, app_id: str, node_id: str, user_id: str):
|
||||
(
|
||||
self._session.query(WorkflowDraftVariable)
|
||||
self._session.execute(
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(
|
||||
WorkflowDraftVariable.app_id == app_id,
|
||||
WorkflowDraftVariable.node_id == node_id,
|
||||
WorkflowDraftVariable.user_id == user_id,
|
||||
)
|
||||
.delete(synchronize_session=False)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
def _get_conversation_id_from_draft_variable(self, app_id: str, user_id: str) -> str | None:
|
||||
@@ -588,13 +592,11 @@ class WorkflowDraftVariableService:
|
||||
conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id, account_id)
|
||||
|
||||
if conv_id is not None:
|
||||
conversation = (
|
||||
self._session.query(Conversation)
|
||||
.where(
|
||||
conversation = self._session.scalar(
|
||||
select(Conversation).where(
|
||||
Conversation.id == conv_id,
|
||||
Conversation.app_id == workflow.app_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
# Only return the conversation ID if it exists and is valid (has a correspond conversation record in DB).
|
||||
if conversation is not None:
|
||||
|
||||
@@ -1512,14 +1512,12 @@ class WorkflowService:
|
||||
|
||||
# Don't use workflow.tool_published as it's not accurate for specific workflow versions
|
||||
# Check if there's a tool provider using this specific workflow version
|
||||
tool_provider = (
|
||||
session.query(WorkflowToolProvider)
|
||||
.where(
|
||||
tool_provider = session.scalar(
|
||||
select(WorkflowToolProvider).where(
|
||||
WorkflowToolProvider.tenant_id == workflow.tenant_id,
|
||||
WorkflowToolProvider.app_id == workflow.app_id,
|
||||
WorkflowToolProvider.version == workflow.version,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if tool_provider:
|
||||
|
||||
@@ -637,7 +637,7 @@ def test_list_default_builtin_providers_for_postgres_and_mysql():
|
||||
for scheme in ("postgresql", "mysql"):
|
||||
session = Mock()
|
||||
session.execute.return_value.all.return_value = [SimpleNamespace(id="id-1"), SimpleNamespace(id="id-2")]
|
||||
session.query.return_value.where.return_value.all.return_value = provider_records
|
||||
session.scalars.return_value = iter(provider_records)
|
||||
|
||||
with patch("core.tools.tool_manager.dify_config", SimpleNamespace(SQLALCHEMY_DATABASE_URI_SCHEME=scheme)):
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
|
||||
@@ -247,10 +247,11 @@ workflow:
|
||||
dataset_mock = Mock()
|
||||
dataset_mock.id = "d1"
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset_mock)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock())
|
||||
|
||||
session = cast(MagicMock, Mock())
|
||||
service = RagPipelineDslService(session=cast(Session, session))
|
||||
session.query.return_value.filter_by.return_value.all.return_value = []
|
||||
session.scalars.return_value.all.return_value = []
|
||||
account = Mock(current_tenant_id="t1")
|
||||
|
||||
result = service.import_rag_pipeline(account=account, import_mode="yaml-content", yaml_content=yaml_content)
|
||||
@@ -320,6 +321,7 @@ workflow:
|
||||
dataset_mock.id = "d1"
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset_mock)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.DatasetCollectionBinding", return_value=Mock(id="b1"))
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock())
|
||||
|
||||
service = RagPipelineDslService(session=Mock())
|
||||
# Mocking self._session.scalar for the pipeline lookup
|
||||
@@ -406,12 +408,14 @@ def test_create_or_update_pipeline_create_new(mocker) -> None:
|
||||
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.current_user", SimpleNamespace(id="u1"))
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Workflow", return_value=Mock())
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock())
|
||||
pipeline_cls = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Pipeline")
|
||||
pipeline_instance = pipeline_cls.return_value
|
||||
pipeline_instance.tenant_id = "t1"
|
||||
pipeline_instance.id = "p1"
|
||||
pipeline_instance.name = "P"
|
||||
pipeline_instance.is_published = False
|
||||
session.scalar.return_value = None
|
||||
|
||||
result = service._create_or_update_pipeline(pipeline=None, data=data, account=account, dependencies=[])
|
||||
|
||||
@@ -447,8 +451,7 @@ def test_export_rag_pipeline_dsl_with_workflow(mocker) -> None:
|
||||
workflow.rag_pipeline_variables = []
|
||||
workflow.to_dict.return_value = {"graph": {"nodes": []}}
|
||||
|
||||
# Mocking single .where() call
|
||||
session.query.return_value.where.return_value.first.return_value = workflow
|
||||
session.scalar.return_value = workflow
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies",
|
||||
return_value=[],
|
||||
@@ -550,7 +553,7 @@ def test_append_workflow_export_data_filters_credentials(mocker) -> None:
|
||||
]
|
||||
}
|
||||
}
|
||||
session.query.return_value.where.return_value.first.return_value = workflow
|
||||
session.scalar.return_value = workflow
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies",
|
||||
return_value=[],
|
||||
@@ -568,7 +571,7 @@ def test_append_workflow_export_data_filters_credentials(mocker) -> None:
|
||||
def test_create_rag_pipeline_dataset_raises_when_name_conflicts(mocker) -> None:
|
||||
session = cast(MagicMock, Mock())
|
||||
service = RagPipelineDslService(session=cast(Session, session))
|
||||
session.query.return_value.filter_by.return_value.first.return_value = Mock()
|
||||
session.scalar.return_value = Mock()
|
||||
create_entity = RagPipelineDatasetCreateEntity(
|
||||
name="Existing Name",
|
||||
description="",
|
||||
@@ -584,8 +587,8 @@ def test_create_rag_pipeline_dataset_raises_when_name_conflicts(mocker) -> None:
|
||||
def test_create_rag_pipeline_dataset_generates_name_when_missing(mocker) -> None:
|
||||
session = cast(MagicMock, Mock())
|
||||
service = RagPipelineDslService(session=cast(Session, session))
|
||||
session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
session.query.return_value.filter_by.return_value.all.return_value = [Mock(name="Untitled")]
|
||||
session.scalar.return_value = None
|
||||
session.scalars.return_value.all.return_value = [Mock(name="Untitled")]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.generate_incremental_name", return_value="Untitled 2")
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.current_user", Mock(id="u1", current_tenant_id="t1"))
|
||||
mocker.patch.object(
|
||||
@@ -632,7 +635,7 @@ def test_append_workflow_export_data_encrypts_knowledge_retrieval_dataset_ids(mo
|
||||
]
|
||||
}
|
||||
}
|
||||
session.query.return_value.where.return_value.first.return_value = workflow
|
||||
session.scalar.return_value = workflow
|
||||
mocker.patch.object(service, "encrypt_dataset_id", side_effect=lambda dataset_id, tenant_id: f"enc-{dataset_id}")
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies",
|
||||
@@ -727,7 +730,7 @@ def test_create_or_update_pipeline_decrypts_knowledge_retrieval_dataset_ids(mock
|
||||
},
|
||||
}
|
||||
draft_workflow = Mock(id="wf1")
|
||||
session.query.return_value.where.return_value.first.return_value = draft_workflow
|
||||
session.scalar.return_value = draft_workflow
|
||||
mocker.patch.object(service, "decrypt_dataset_id", side_effect=["d1", None])
|
||||
|
||||
result = service._create_or_update_pipeline(pipeline=pipeline, data=data, account=account)
|
||||
@@ -743,7 +746,8 @@ def test_create_or_update_pipeline_creates_draft_when_missing(mocker) -> None:
|
||||
account = Mock(id="u1", current_tenant_id="t1")
|
||||
pipeline = Mock(id="p1", tenant_id="t1", name="N", description="D")
|
||||
data = {"rag_pipeline": {"name": "N2", "description": "D2"}, "workflow": {"graph": {"nodes": []}}}
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock())
|
||||
workflow_cls = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Workflow")
|
||||
workflow_cls.return_value.id = "wf-new"
|
||||
|
||||
@@ -817,7 +821,7 @@ def test_import_rag_pipeline_fails_for_non_string_version_type() -> None:
|
||||
def test_append_workflow_export_data_raises_when_draft_workflow_missing() -> None:
|
||||
session = cast(MagicMock, Mock())
|
||||
service = RagPipelineDslService(session=cast(Session, session))
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="Missing draft workflow configuration"):
|
||||
service._append_workflow_export_data(export_data={}, pipeline=Mock(tenant_id="t1"), include_secret=False)
|
||||
@@ -841,7 +845,7 @@ def test_append_workflow_export_data_keeps_secret_fields_when_include_secret_tru
|
||||
]
|
||||
}
|
||||
}
|
||||
session.query.return_value.where.return_value.first.return_value = workflow
|
||||
session.scalar.return_value = workflow
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies",
|
||||
return_value=[],
|
||||
@@ -1003,7 +1007,8 @@ def test_import_rag_pipeline_sets_default_version_and_kind(mocker) -> None:
|
||||
)
|
||||
dataset = Mock(id="d1")
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset)
|
||||
session.query.return_value.filter_by.return_value.all.return_value = []
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock())
|
||||
session.scalars.return_value.all.return_value = []
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.generate_incremental_name", return_value="P")
|
||||
|
||||
result = service.import_rag_pipeline(
|
||||
@@ -1061,7 +1066,7 @@ def test_append_workflow_export_data_skips_empty_node_data(mocker) -> None:
|
||||
workflow = Mock()
|
||||
workflow.graph_dict = {"nodes": []}
|
||||
workflow.to_dict.return_value = {"graph": {"nodes": [{"data": {}}, {}]}}
|
||||
session.query.return_value.where.return_value.first.return_value = workflow
|
||||
session.scalar.return_value = workflow
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies",
|
||||
return_value=[],
|
||||
@@ -1246,11 +1251,12 @@ def test_create_or_update_pipeline_saves_dependencies_to_redis(mocker) -> None:
|
||||
account = Mock(id="u1", current_tenant_id="t1")
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.current_user", SimpleNamespace(id="u1"))
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Workflow", return_value=Mock(id="wf-1"))
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock())
|
||||
pipeline_cls = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Pipeline")
|
||||
pipeline = pipeline_cls.return_value
|
||||
pipeline.tenant_id = "t1"
|
||||
pipeline.id = "p1"
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
setex = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.redis_client.setex")
|
||||
dependency = PluginDependency(
|
||||
type=PluginDependency.Type.Marketplace,
|
||||
|
||||
@@ -969,8 +969,7 @@ class TestWorkflowService:
|
||||
# 1. Workflow exists
|
||||
# 2. No app is currently using it
|
||||
# 3. Not published as a tool
|
||||
mock_session.scalar.side_effect = [mock_workflow, None] # workflow exists, no app using it
|
||||
mock_session.query.return_value.where.return_value.first.return_value = None # no tool provider
|
||||
mock_session.scalar.side_effect = [mock_workflow, None, None] # workflow, no app using it, no tool provider
|
||||
|
||||
with patch("services.workflow_service.select") as mock_select:
|
||||
mock_stmt = MagicMock()
|
||||
@@ -1045,8 +1044,7 @@ class TestWorkflowService:
|
||||
mock_tool_provider = MagicMock()
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.scalar.side_effect = [mock_workflow, None] # workflow exists, no app using it
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_tool_provider
|
||||
mock_session.scalar.side_effect = [mock_workflow, None, mock_tool_provider] # workflow, no app, tool provider
|
||||
|
||||
with patch("services.workflow_service.select") as mock_select:
|
||||
mock_stmt = MagicMock()
|
||||
|
||||
Reference in New Issue
Block a user