diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py similarity index 91% rename from api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py rename to api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index a3c0592d76..c1f3122c2b 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -1,7 +1,13 @@ +"""Testcontainers integration tests for rag_pipeline_workflow controller endpoints.""" + +from __future__ import annotations + from datetime import datetime from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, HTTPException, NotFound import services @@ -38,6 +44,10 @@ def unwrap(func): class TestDraftWorkflowApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_get_draft_success(self, app): api = DraftRagPipelineApi() method = unwrap(api.get) @@ -200,6 +210,10 @@ class TestDraftWorkflowApi: class TestDraftRunNodes: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_iteration_node_success(self, app): api = RagPipelineDraftRunIterationNodeApi() method = unwrap(api.post) @@ -275,6 +289,10 @@ class TestDraftRunNodes: class TestPipelineRunApis: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_draft_run_success(self, app): api = DraftRagPipelineRunApi() method = unwrap(api.post) @@ -337,6 +355,10 @@ class TestPipelineRunApis: class TestDraftNodeRun: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_execution_not_found(self, app): api = RagPipelineDraftNodeRunApi() method = unwrap(api.post) @@ -364,45 +386,43 @@ class TestDraftNodeRun: class TestPublishedPipelineApis: - def test_publish_success(self, app): + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + + def test_publish_success(self, app, db_session_with_containers: Session): + from models.dataset import Pipeline + api = PublishedRagPipelineApi() method = unwrap(api.post) - pipeline = MagicMock() + tenant_id = str(uuid4()) + pipeline = Pipeline( + tenant_id=tenant_id, + name="test-pipeline", + description="test", + created_by=str(uuid4()), + ) + db_session_with_containers.add(pipeline) + db_session_with_containers.commit() + db_session_with_containers.expire_all() + user = MagicMock(id="u1") workflow = MagicMock( - id="w1", + id=str(uuid4()), created_at=naive_utc_now(), ) - session = MagicMock() - session.merge.return_value = pipeline - - session_ctx = MagicMock() - session_ctx.__enter__.return_value = session - session_ctx.__exit__.return_value = None - service = MagicMock() service.publish_workflow.return_value = workflow - fake_db = MagicMock() - fake_db.engine = MagicMock() - with ( app.test_request_context("/"), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", return_value=(user, "t"), ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", return_value=service, @@ -415,6 +435,10 @@ class TestPublishedPipelineApis: class TestMiscApis: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_task_stop(self, app): api = RagPipelineTaskStopApi() method = unwrap(api.post) @@ -471,6 +495,10 @@ class TestMiscApis: class TestPublishedRagPipelineRunApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_published_run_success(self, app): api = PublishedRagPipelineRunApi() method = unwrap(api.post) @@ -536,6 +564,10 @@ class TestPublishedRagPipelineRunApi: class TestDefaultBlockConfigApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_get_block_config_success(self, app): api = DefaultRagPipelineBlockConfigApi() method = unwrap(api.get) @@ -567,6 +599,10 @@ class TestDefaultBlockConfigApi: class TestPublishedAllRagPipelineApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_get_published_workflows_success(self, app): api = PublishedAllRagPipelineApi() method = unwrap(api.get) @@ -577,28 +613,12 @@ class TestPublishedAllRagPipelineApi: service = MagicMock() service.get_all_published_workflow.return_value = ([{"id": "w1"}], False) - session = MagicMock() - session_ctx = MagicMock() - session_ctx.__enter__.return_value = session - session_ctx.__exit__.return_value = None - - fake_db = MagicMock() - fake_db.engine = MagicMock() - with ( app.test_request_context("/"), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", return_value=(user, "t"), ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", return_value=service, @@ -628,6 +648,10 @@ class TestPublishedAllRagPipelineApi: class TestRagPipelineByIdApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_patch_success(self, app): api = RagPipelineByIdApi() method = unwrap(api.patch) @@ -640,14 +664,6 @@ class TestRagPipelineByIdApi: service = MagicMock() service.update_workflow.return_value = workflow - session = MagicMock() - session_ctx = MagicMock() - session_ctx.__enter__.return_value = session - session_ctx.__exit__.return_value = None - - fake_db = MagicMock() - fake_db.engine = MagicMock() - payload = {"marked_name": "test"} with ( @@ -657,14 +673,6 @@ class TestRagPipelineByIdApi: "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", return_value=(user, "t"), ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", return_value=service, @@ -700,24 +708,8 @@ class TestRagPipelineByIdApi: workflow_service = MagicMock() - session = MagicMock() - session_ctx = MagicMock() - session_ctx.__enter__.return_value = session - session_ctx.__exit__.return_value = None - - fake_db = MagicMock() - fake_db.engine = MagicMock() - with ( app.test_request_context("/", method="DELETE"), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db", - fake_db, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session", - return_value=session_ctx, - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.WorkflowService", return_value=workflow_service, @@ -725,12 +717,7 @@ class TestRagPipelineByIdApi: ): result = method(api, pipeline, "old-workflow") - workflow_service.delete_workflow.assert_called_once_with( - session=session, - workflow_id="old-workflow", - tenant_id="t1", - ) - session.commit.assert_called_once() + workflow_service.delete_workflow.assert_called_once() assert result == (None, 204) def test_delete_active_workflow_rejected(self, app): @@ -745,6 +732,10 @@ class TestRagPipelineByIdApi: class TestRagPipelineWorkflowLastRunApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_last_run_success(self, app): api = RagPipelineWorkflowLastRunApi() method = unwrap(api.get) @@ -788,6 +779,10 @@ class TestRagPipelineWorkflowLastRunApi: class TestRagPipelineDatasourceVariableApi: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + def test_set_datasource_variables_success(self, app): api = RagPipelineDatasourceVariableApi() method = unwrap(api.post)