From e70e4fa41d74e48a28a1b5cfec9ab0bcaef83f86 Mon Sep 17 00:00:00 2001 From: jerryzai Date: Fri, 17 Apr 2026 04:12:31 -0400 Subject: [PATCH] chore(api): migrate file factory builders and account commands to use Session(db.engine) (#35236) Co-authored-by: Asuka Minato Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/commands/account.py | 17 +- api/factories/file_factory/builders.py | 186 +++++++++--------- .../factories/test_build_from_mapping.py | 44 ++++- 3 files changed, 141 insertions(+), 106 deletions(-) diff --git a/api/commands/account.py b/api/commands/account.py index 6a2a2e0428..761323a73d 100644 --- a/api/commands/account.py +++ b/api/commands/account.py @@ -2,6 +2,7 @@ import base64 import secrets import click +from sqlalchemy.orm import Session from constants.languages import languages from extensions.ext_database import db @@ -43,10 +44,11 @@ def reset_password(email, new_password, password_confirm): # encrypt password with salt password_hashed = hash_password(new_password, salt) base64_password_hashed = base64.b64encode(password_hashed).decode() - account = db.session.merge(account) - account.password = base64_password_hashed - account.password_salt = base64_salt - db.session.commit() + with Session(db.engine) as session: + account = session.merge(account) + account.password = base64_password_hashed + account.password_salt = base64_salt + session.commit() AccountService.reset_login_error_rate_limit(normalized_email) click.echo(click.style("Password reset successfully.", fg="green")) @@ -77,9 +79,10 @@ def reset_email(email, new_email, email_confirm): click.echo(click.style(f"Invalid email: {new_email}", fg="red")) return - account = db.session.merge(account) - account.email = normalized_new_email - db.session.commit() + with Session(db.engine) as session: + account = session.merge(account) + account.email = normalized_new_email + session.commit() click.echo(click.style("Email updated successfully.", fg="green")) diff --git a/api/factories/file_factory/builders.py b/api/factories/file_factory/builders.py index 288d37d265..ce1fa441c2 100644 --- a/api/factories/file_factory/builders.py +++ b/api/factories/file_factory/builders.py @@ -10,8 +10,8 @@ from typing import Any from sqlalchemy import select from core.app.file_access import FileAccessControllerProtocol +from core.db.session_factory import session_factory from core.workflow.file_reference import build_file_reference -from extensions.ext_database import db from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type from models import ToolFile, UploadFile @@ -135,29 +135,30 @@ def _build_from_local_file( UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) - row = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) - if row is None: - raise ValueError("Invalid upload file") + with session_factory.create_session() as session: + row = session.scalar(access_controller.apply_upload_file_filters(stmt)) + if row is None: + raise ValueError("Invalid upload file") - detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type", "custom"), - strict_type_validation=strict_type_validation, - ) + detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type", "custom"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("id"), - filename=row.name, - extension="." + row.extension, - mime_type=row.mime_type, - type=file_type, - transfer_method=transfer_method, - remote_url=row.source_url, - reference=build_file_reference(record_id=str(row.id)), - size=row.size, - storage_key=row.key, - ) + return File( + id=mapping.get("id"), + filename=row.name, + extension="." + row.extension, + mime_type=row.mime_type, + type=file_type, + transfer_method=transfer_method, + remote_url=row.source_url, + reference=build_file_reference(record_id=str(row.id)), + size=row.size, + storage_key=row.key, + ) def _build_from_remote_url( @@ -179,32 +180,33 @@ def _build_from_remote_url( UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) - upload_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) - if upload_file is None: - raise ValueError("Invalid upload file") + with session_factory.create_session() as session: + upload_file = session.scalar(access_controller.apply_upload_file_filters(stmt)) + if upload_file is None: + raise ValueError("Invalid upload file") - detected_file_type = standardize_file_type( - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - ) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type"), - strict_type_validation=strict_type_validation, - ) + detected_file_type = standardize_file_type( + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + ) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("id"), - filename=upload_file.name, - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - type=file_type, - transfer_method=transfer_method, - remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), - reference=build_file_reference(record_id=str(upload_file.id)), - size=upload_file.size, - storage_key=upload_file.key, - ) + return File( + id=mapping.get("id"), + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + type=file_type, + transfer_method=transfer_method, + remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), + reference=build_file_reference(record_id=str(upload_file.id)), + size=upload_file.size, + storage_key=upload_file.key, + ) url = mapping.get("url") or mapping.get("remote_url") if not url: @@ -247,30 +249,31 @@ def _build_from_tool_file( ToolFile.id == tool_file_id, ToolFile.tenant_id == tenant_id, ) - tool_file = db.session.scalar(access_controller.apply_tool_file_filters(stmt)) - if tool_file is None: - raise ValueError(f"ToolFile {tool_file_id} not found") + with session_factory.create_session() as session: + tool_file = session.scalar(access_controller.apply_tool_file_filters(stmt)) + if tool_file is None: + raise ValueError(f"ToolFile {tool_file_id} not found") - extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" - detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type"), - strict_type_validation=strict_type_validation, - ) + extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" + detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("id"), - filename=tool_file.name, - type=file_type, - transfer_method=transfer_method, - remote_url=tool_file.original_url, - reference=build_file_reference(record_id=str(tool_file.id)), - extension=extension, - mime_type=tool_file.mimetype, - size=tool_file.size, - storage_key=tool_file.file_key, - ) + return File( + id=mapping.get("id"), + filename=tool_file.name, + type=file_type, + transfer_method=transfer_method, + remote_url=tool_file.original_url, + reference=build_file_reference(record_id=str(tool_file.id)), + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + storage_key=tool_file.file_key, + ) def _build_from_datasource_file( @@ -289,31 +292,32 @@ def _build_from_datasource_file( UploadFile.id == datasource_file_id, UploadFile.tenant_id == tenant_id, ) - datasource_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) - if datasource_file is None: - raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") + with session_factory.create_session() as session: + datasource_file = session.scalar(access_controller.apply_upload_file_filters(stmt)) + if datasource_file is None: + raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") - extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" - detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type"), - strict_type_validation=strict_type_validation, - ) + extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" + detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("datasource_file_id"), - filename=datasource_file.name, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - remote_url=datasource_file.source_url, - reference=build_file_reference(record_id=str(datasource_file.id)), - extension=extension, - mime_type=datasource_file.mime_type, - size=datasource_file.size, - storage_key=datasource_file.key, - url=datasource_file.source_url, - ) + return File( + id=mapping.get("datasource_file_id"), + filename=datasource_file.name, + type=file_type, + transfer_method=FileTransferMethod.TOOL_FILE, + remote_url=datasource_file.source_url, + reference=build_file_reference(record_id=str(datasource_file.id)), + extension=extension, + mime_type=datasource_file.mime_type, + size=datasource_file.size, + storage_key=datasource_file.key, + url=datasource_file.source_url, + ) def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool: diff --git a/api/tests/unit_tests/factories/test_build_from_mapping.py b/api/tests/unit_tests/factories/test_build_from_mapping.py index 511192001e..efafc8aa79 100644 --- a/api/tests/unit_tests/factories/test_build_from_mapping.py +++ b/api/tests/unit_tests/factories/test_build_from_mapping.py @@ -11,6 +11,21 @@ from factories.file_factory.builders import build_from_mapping as _build_from_ma from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig from models import ToolFile, UploadFile + +def _make_session_ctx_mock(scalar_return=None): + """Return a mock usable as the ``session_factory.create_session()`` context manager. + + Patch ``factories.file_factory.builders.session_factory`` and set + ``mock_sf.create_session.return_value = `` to intercept DB calls + without requiring a live Flask app or database engine. + """ + session = MagicMock() + session.__enter__.return_value = session + session.__exit__.return_value = False + session.scalar.return_value = scalar_return + return session + + # Test Data TEST_TENANT_ID = "test_tenant_id" TEST_UPLOAD_FILE_ID = str(uuid.uuid4()) @@ -49,8 +64,11 @@ def mock_upload_file(): mock.source_url = TEST_REMOTE_URL mock.size = 1024 mock.key = "test_key" - with patch("factories.file_factory.builders.db.session.scalar", return_value=mock, autospec=True) as m: - yield m + session = _make_session_ctx_mock(scalar_return=mock) + with patch("factories.file_factory.builders.session_factory") as mock_sf: + mock_sf.create_session.return_value = session + # yield session.scalar so callers can inspect call_args and mutate return_value + yield session.scalar @pytest.fixture @@ -63,7 +81,9 @@ def mock_tool_file(): mock.mimetype = "application/pdf" mock.original_url = "http://example.com/tool.pdf" mock.size = 2048 - with patch("factories.file_factory.builders.db.session.scalar", return_value=mock, autospec=True): + session = _make_session_ctx_mock(scalar_return=mock) + with patch("factories.file_factory.builders.session_factory") as mock_sf: + mock_sf.create_session.return_value = session yield mock @@ -231,7 +251,9 @@ def test_build_from_remote_url_without_strict_validation(mock_http_head): def test_tool_file_not_found(): """Test ToolFile not found in database.""" - with patch("factories.file_factory.builders.db.session.scalar", return_value=None, autospec=True): + session = _make_session_ctx_mock(scalar_return=None) + with patch("factories.file_factory.builders.session_factory") as mock_sf: + mock_sf.create_session.return_value = session mapping = tool_file_mapping() with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"): build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) @@ -239,7 +261,9 @@ def test_tool_file_not_found(): def test_local_file_not_found(): """Test UploadFile not found in database.""" - with patch("factories.file_factory.builders.db.session.scalar", return_value=None, autospec=True): + session = _make_session_ctx_mock(scalar_return=None) + with patch("factories.file_factory.builders.session_factory") as mock_sf: + mock_sf.create_session.return_value = session mapping = local_file_mapping() with pytest.raises(ValueError, match="Invalid upload file"): build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) @@ -311,7 +335,9 @@ def test_tenant_mismatch(): mock_file.key = "test_key" # Mock the database query to return None (no file found for this tenant) - with patch("factories.file_factory.builders.db.session.scalar", return_value=None, autospec=True): + session = _make_session_ctx_mock(scalar_return=None) + with patch("factories.file_factory.builders.session_factory") as mock_sf: + mock_sf.create_session.return_value = session mapping = local_file_mapping() with pytest.raises(ValueError, match="Invalid upload file"): build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) @@ -350,11 +376,13 @@ def test_build_from_mapping_scopes_tool_file_to_end_user(): invoke_from=InvokeFrom.WEB_APP, ) - with patch("factories.file_factory.builders.db.session.scalar", return_value=tool_file, autospec=True) as scalar: + session = _make_session_ctx_mock(scalar_return=tool_file) + with patch("factories.file_factory.builders.session_factory") as mock_sf: + mock_sf.create_session.return_value = session with bind_file_access_scope(scope): build_from_mapping(mapping=tool_file_mapping(), tenant_id=TEST_TENANT_ID) - stmt = scalar.call_args.args[0] + stmt = session.scalar.call_args.args[0] whereclause = str(stmt.whereclause) assert "tool_files.user_id" in whereclause