diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index 5c6636f31e..c7bb90f019 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -7,7 +7,7 @@ from graphon.nodes import BuiltinNodeTypes from graphon.variables.segments import StringSegment from graphon.variables.types import SegmentType from graphon.variables.variables import StringVariable -from sqlalchemy import delete +from sqlalchemy import delete, func, select from sqlalchemy.orm import Session from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID @@ -38,21 +38,25 @@ class TestWorkflowDraftVariableService(unittest.TestCase): def setUp(self): self._test_app_id = str(uuid.uuid4()) + self._test_user_id = str(uuid.uuid4()) self._session: Session = db.session() sys_var = WorkflowDraftVariable.new_sys_variable( app_id=self._test_app_id, + user_id=self._test_user_id, name="sys_var", value=build_segment("sys_value"), node_execution_id=self._node_exec_id, ) conv_var = WorkflowDraftVariable.new_conversation_variable( app_id=self._test_app_id, + user_id=self._test_user_id, name="conv_var", value=build_segment("conv_value"), ) node2_vars = [ WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, + user_id=self._test_user_id, node_id=self._node2_id, name="int_var", value=build_segment(1), @@ -61,6 +65,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase): ), WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, + user_id=self._test_user_id, node_id=self._node2_id, name="str_var", value=build_segment("str_value"), @@ -70,6 +75,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase): ] node1_var = WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, + user_id=self._test_user_id, node_id=self._node1_id, name="str_var", value=build_segment("str_value"), @@ -141,24 +147,27 @@ class TestWorkflowDraftVariableService(unittest.TestCase): def test_delete_node_variables(self): srv = self._get_test_srv() srv.delete_node_variables(self._test_app_id, self._node2_id, user_id=self._test_user_id) - node2_var_count = ( - self._session.query(WorkflowDraftVariable) + node2_var_count = self._session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariable) .where( WorkflowDraftVariable.app_id == self._test_app_id, WorkflowDraftVariable.node_id == self._node2_id, + WorkflowDraftVariable.user_id == self._test_user_id, ) - .count() ) assert node2_var_count == 0 def test_delete_variable(self): srv = self._get_test_srv() - node_1_var = ( - self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).one() - ) + node_1_var = self._session.scalars( + select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id) + ).one() srv.delete_variable(node_1_var) exists = bool( - self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).first() + self._session.scalars( + select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id) + ).first() ) assert exists is False @@ -248,9 +257,7 @@ class TestDraftVariableLoader(unittest.TestCase): def tearDown(self): with Session(bind=db.engine, expire_on_commit=False) as session: - session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id).delete( - synchronize_session=False - ) + session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id)) session.commit() def test_variable_loader_with_empty_selector(self): @@ -431,9 +438,11 @@ class TestDraftVariableLoader(unittest.TestCase): # Clean up with Session(bind=db.engine) as session: # Query and delete by ID to ensure they're tracked in this session - session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete() - session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete() - session.query(UploadFile).filter_by(id=upload_file.id).delete() + session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.id == offloaded_var.id)) + session.execute( + delete(WorkflowDraftVariableFile).where(WorkflowDraftVariableFile.id == variable_file.id) + ) + session.execute(delete(UploadFile).where(UploadFile.id == upload_file.id)) session.commit() # Clean up storage try: @@ -534,9 +543,11 @@ class TestDraftVariableLoader(unittest.TestCase): # Clean up with Session(bind=db.engine) as session: # Query and delete by ID to ensure they're tracked in this session - session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete() - session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete() - session.query(UploadFile).filter_by(id=upload_file.id).delete() + session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.id == offloaded_var.id)) + session.execute( + delete(WorkflowDraftVariableFile).where(WorkflowDraftVariableFile.id == variable_file.id) + ) + session.execute(delete(UploadFile).where(UploadFile.id == upload_file.id)) session.commit() # Clean up storage try: