From 2ce2fbc2d4ad7f56627e12bf65e211bf95f32edd Mon Sep 17 00:00:00 2001 From: Desel72 Date: Sat, 21 Mar 2026 05:54:56 -0500 Subject: [PATCH] =?UTF-8?q?refactor:=20migrate=20workflow=20run=20reposito?= =?UTF-8?q?ry=20unit=20tests=20from=20mocks=20to=20te=E2=80=A6=20(#33843)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ..._sqlalchemy_api_workflow_run_repository.py | 224 +++++++++++++++++- ..._sqlalchemy_api_workflow_run_repository.py | 135 ----------- 2 files changed, 223 insertions(+), 136 deletions(-) delete mode 100644 api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index 76e586e65f..c3ed79656f 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -2,6 +2,7 @@ from __future__ import annotations +import secrets from dataclasses import dataclass, field from datetime import datetime, timedelta from unittest.mock import Mock @@ -12,15 +13,26 @@ from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session, sessionmaker from dify_graph.entities import WorkflowExecution -from dify_graph.entities.pause_reason import PauseReasonType +from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction +from dify_graph.nodes.human_input.enums import DeliveryMethodType, FormInputType, HumanInputFormStatus from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.human_input import ( + BackstageRecipientPayload, + HumanInputDelivery, + HumanInputForm, + HumanInputFormRecipient, + RecipientType, +) from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.sqlalchemy_api_workflow_run_repository import ( DifyAPISQLAlchemyWorkflowRunRepository, + _build_human_input_required_reason, + _PrivateWorkflowPauseEntity, _WorkflowRunError, ) @@ -90,6 +102,19 @@ def _cleanup_scope_data(session: Session, scope: _TestScope) -> None: WorkflowRun.app_id == scope.app_id, ) ) + + form_ids_subquery = select(HumanInputForm.id).where( + HumanInputForm.tenant_id == scope.tenant_id, + HumanInputForm.app_id == scope.app_id, + ) + session.execute(delete(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids_subquery))) + session.execute(delete(HumanInputDelivery).where(HumanInputDelivery.form_id.in_(form_ids_subquery))) + session.execute( + delete(HumanInputForm).where( + HumanInputForm.tenant_id == scope.tenant_id, + HumanInputForm.app_id == scope.app_id, + ) + ) session.commit() for state_key in scope.state_keys: @@ -504,3 +529,200 @@ class TestDeleteWorkflowPause: with pytest.raises(_WorkflowRunError, match="WorkflowPause not found"): repository.delete_workflow_pause(pause_entity=pause_entity) + + +class TestPrivateWorkflowPauseEntity: + """Integration tests for _PrivateWorkflowPauseEntity using real DB models.""" + + def test_properties( + self, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Entity properties delegate to the persisted WorkflowPause model.""" + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + pause = WorkflowPause( + id=str(uuid4()), + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=f"workflow-state-{uuid4()}.json", + ) + db_session_with_containers.add(pause) + db_session_with_containers.commit() + db_session_with_containers.refresh(pause) + test_scope.state_keys.add(pause.state_object_key) + + entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[]) + + assert entity.id == pause.id + assert entity.workflow_execution_id == workflow_run.id + assert entity.resumed_at is None + + def test_get_state( + self, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """get_state loads state data from storage using the persisted state_object_key.""" + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + state_key = f"workflow-state-{uuid4()}.json" + pause = WorkflowPause( + id=str(uuid4()), + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=state_key, + ) + db_session_with_containers.add(pause) + db_session_with_containers.commit() + db_session_with_containers.refresh(pause) + test_scope.state_keys.add(state_key) + + expected_state = b'{"test": "state"}' + storage.save(state_key, expected_state) + + entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[]) + result = entity.get_state() + + assert result == expected_state + + def test_get_state_caching( + self, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """get_state caches the result so storage is only accessed once.""" + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + state_key = f"workflow-state-{uuid4()}.json" + pause = WorkflowPause( + id=str(uuid4()), + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=state_key, + ) + db_session_with_containers.add(pause) + db_session_with_containers.commit() + db_session_with_containers.refresh(pause) + test_scope.state_keys.add(state_key) + + expected_state = b'{"test": "state"}' + storage.save(state_key, expected_state) + + entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[]) + result1 = entity.get_state() + # Delete from storage to prove second call uses cache + storage.delete(state_key) + test_scope.state_keys.discard(state_key) + result2 = entity.get_state() + + assert result1 == expected_state + assert result2 == expected_state + + +class TestBuildHumanInputRequiredReason: + """Integration tests for _build_human_input_required_reason using real DB models.""" + + def test_prefers_backstage_token_when_available( + self, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Use backstage token when multiple recipient types may exist.""" + + expiration_time = naive_utc_now() + form_definition = FormDefinition( + form_content="content", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], + user_actions=[UserAction(id="approve", title="Approve")], + rendered_content="rendered", + expiration_time=expiration_time, + default_values={"name": "Alice"}, + node_title="Ask Name", + display_in_ui=True, + ) + + form_model = HumanInputForm( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + workflow_run_id=str(uuid4()), + node_id="node-1", + form_definition=form_definition.model_dump_json(), + rendered_content="rendered", + status=HumanInputFormStatus.WAITING, + expiration_time=expiration_time, + ) + db_session_with_containers.add(form_model) + db_session_with_containers.flush() + + delivery = HumanInputDelivery( + form_id=form_model.id, + delivery_method_type=DeliveryMethodType.WEBAPP, + channel_payload="{}", + ) + db_session_with_containers.add(delivery) + db_session_with_containers.flush() + + access_token = secrets.token_urlsafe(8) + recipient = HumanInputFormRecipient( + form_id=form_model.id, + delivery_id=delivery.id, + recipient_type=RecipientType.BACKSTAGE, + recipient_payload=BackstageRecipientPayload().model_dump_json(), + access_token=access_token, + ) + db_session_with_containers.add(recipient) + db_session_with_containers.flush() + + # Create a pause so the reason has a valid pause_id + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + pause = WorkflowPause( + id=str(uuid4()), + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=f"workflow-state-{uuid4()}.json", + ) + db_session_with_containers.add(pause) + db_session_with_containers.flush() + test_scope.state_keys.add(pause.state_object_key) + + reason_model = WorkflowPauseReason( + pause_id=pause.id, + type_=PauseReasonType.HUMAN_INPUT_REQUIRED, + form_id=form_model.id, + node_id="node-1", + message="", + ) + db_session_with_containers.add(reason_model) + db_session_with_containers.commit() + + # Refresh to ensure we have DB-round-tripped objects + db_session_with_containers.refresh(form_model) + db_session_with_containers.refresh(reason_model) + db_session_with_containers.refresh(recipient) + + reason = _build_human_input_required_reason(reason_model, form_model, [recipient]) + + assert isinstance(reason, HumanInputRequired) + assert reason.form_token == access_token + assert reason.node_title == "Ask Name" + assert reason.form_content == "content" + assert reason.inputs[0].output_variable_name == "name" + assert reason.actions[0].id == "approve" diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py deleted file mode 100644 index 3707ed90be..0000000000 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Unit tests for non-SQL helper logic in workflow run repository.""" - -import secrets -from datetime import UTC, datetime -from unittest.mock import Mock, patch - -import pytest - -from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType -from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction -from dify_graph.nodes.human_input.enums import FormInputType, HumanInputFormStatus -from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType -from models.workflow import WorkflowPause as WorkflowPauseModel -from models.workflow import WorkflowPauseReason -from repositories.sqlalchemy_api_workflow_run_repository import ( - _build_human_input_required_reason, - _PrivateWorkflowPauseEntity, -) - - -@pytest.fixture -def sample_workflow_pause() -> Mock: - """Create a sample WorkflowPause model.""" - pause = Mock(spec=WorkflowPauseModel) - pause.id = "pause-123" - pause.workflow_id = "workflow-123" - pause.workflow_run_id = "workflow-run-123" - pause.state_object_key = "workflow-state-123.json" - pause.resumed_at = None - pause.created_at = datetime.now(UTC) - return pause - - -class TestPrivateWorkflowPauseEntity: - """Test _PrivateWorkflowPauseEntity class.""" - - def test_properties(self, sample_workflow_pause: Mock) -> None: - """Test entity properties.""" - # Arrange - entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[]) - - # Assert - assert entity.id == sample_workflow_pause.id - assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id - assert entity.resumed_at == sample_workflow_pause.resumed_at - - def test_get_state(self, sample_workflow_pause: Mock) -> None: - """Test getting state from storage.""" - # Arrange - entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[]) - expected_state = b'{"test": "state"}' - - with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage: - mock_storage.load.return_value = expected_state - - # Act - result = entity.get_state() - - # Assert - assert result == expected_state - mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key) - - def test_get_state_caching(self, sample_workflow_pause: Mock) -> None: - """Test state caching in get_state method.""" - # Arrange - entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[]) - expected_state = b'{"test": "state"}' - - with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage: - mock_storage.load.return_value = expected_state - - # Act - result1 = entity.get_state() - result2 = entity.get_state() - - # Assert - assert result1 == expected_state - assert result2 == expected_state - mock_storage.load.assert_called_once() - - -class TestBuildHumanInputRequiredReason: - """Test helper that builds HumanInputRequired pause reasons.""" - - def test_prefers_backstage_token_when_available(self) -> None: - """Use backstage token when multiple recipient types may exist.""" - # Arrange - expiration_time = datetime.now(UTC) - form_definition = FormDefinition( - form_content="content", - inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], - user_actions=[UserAction(id="approve", title="Approve")], - rendered_content="rendered", - expiration_time=expiration_time, - default_values={"name": "Alice"}, - node_title="Ask Name", - display_in_ui=True, - ) - form_model = HumanInputForm( - id="form-1", - tenant_id="tenant-1", - app_id="app-1", - workflow_run_id="run-1", - node_id="node-1", - form_definition=form_definition.model_dump_json(), - rendered_content="rendered", - status=HumanInputFormStatus.WAITING, - expiration_time=expiration_time, - ) - reason_model = WorkflowPauseReason( - pause_id="pause-1", - type_=PauseReasonType.HUMAN_INPUT_REQUIRED, - form_id="form-1", - node_id="node-1", - message="", - ) - access_token = secrets.token_urlsafe(8) - backstage_recipient = HumanInputFormRecipient( - form_id="form-1", - delivery_id="delivery-1", - recipient_type=RecipientType.BACKSTAGE, - recipient_payload=BackstageRecipientPayload().model_dump_json(), - access_token=access_token, - ) - - # Act - reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient]) - - # Assert - assert isinstance(reason, HumanInputRequired) - assert reason.form_token == access_token - assert reason.node_title == "Ask Name" - assert reason.form_content == "content" - assert reason.inputs[0].output_variable_name == "name" - assert reason.actions[0].id == "approve"