diff --git a/api/models/model.py b/api/models/model.py index d8dccba687..62a2a76c81 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -3,7 +3,7 @@ from __future__ import annotations import json import re import uuid -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from datetime import datetime from decimal import Decimal from enum import StrEnum, auto @@ -20,13 +20,13 @@ from typing_extensions import TypedDict from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.tools.signature import sign_tool_file -from core.workflow.file_reference import parse_file_reference from dify_graph.enums import WorkflowExecutionStatus from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from dify_graph.file import helpers as file_helpers from extensions.storage.storage_type import StorageType from libs.helper import generate_string # type: ignore[import-not-found] from libs.uuid_utils import uuidv7 +from models.utils.file_input_compat import build_file_from_input_mapping from .account import Account, Tenant from .base import Base, TypeBase, gen_uuidv4_string @@ -53,19 +53,23 @@ if TYPE_CHECKING: # --- TypedDict definitions for structured dict return types --- -def _resolve_file_record_id(file_mapping: Mapping[str, Any]) -> str | None: - reference = file_mapping.get("reference") - if isinstance(reference, str) and reference: - parsed_reference = parse_file_reference(reference) - if parsed_reference is not None: - return parsed_reference.record_id +def _resolve_app_tenant_id(app_id: str) -> str: + resolved_tenant_id = db.session.scalar(select(App.tenant_id).where(App.id == app_id)) + if not resolved_tenant_id: + raise ValueError(f"Unable to resolve tenant_id for app {app_id}") + return cast(str, resolved_tenant_id) - related_id = file_mapping.get("related_id") - if isinstance(related_id, str) and related_id: - parsed_reference = parse_file_reference(related_id) - if parsed_reference is not None: - return parsed_reference.record_id - return None + +def _build_app_tenant_resolver(app_id: str, owner_tenant_id: str | None = None) -> Callable[[], str]: + resolved_tenant_id = owner_tenant_id + + def resolve_owner_tenant_id() -> str: + nonlocal resolved_tenant_id + if resolved_tenant_id is None: + resolved_tenant_id = _resolve_app_tenant_id(app_id) + return resolved_tenant_id + + return resolve_owner_tenant_id class EnabledConfig(TypedDict): @@ -1062,24 +1066,26 @@ class Conversation(Base): @property def inputs(self) -> dict[str, Any]: inputs = self._inputs.copy() + # Compatibility bridge: stored input payloads may come from before or after the + # graph-layer file refactor. Newer rows may omit `tenant_id`, so keep tenant + # resolution at the SQLAlchemy model boundary instead of pushing ownership back + # into `dify_graph.file.File`. + tenant_resolver = _build_app_tenant_resolver( + app_id=self.app_id, + owner_tenant_id=cast(str | None, getattr(self, "_owner_tenant_id", None)), + ) # Convert file mapping to File object for key, value in inputs.items(): - # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. - from factories import file_factory - if ( isinstance(value, dict) and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY ): value_dict = cast(dict[str, Any], value) - record_id = _resolve_file_record_id(value_dict) - if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - value_dict["tool_file_id"] = record_id - elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: - value_dict["upload_file_id"] = record_id - tenant_id = cast(str, value_dict.get("tenant_id", "")) - inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) + inputs[key] = build_file_from_input_mapping( + file_mapping=value_dict, + tenant_resolver=tenant_resolver, + ) elif isinstance(value, list): value_list = cast(list[Any], value) if all( @@ -1092,16 +1098,12 @@ class Conversation(Base): if not isinstance(item, dict): continue item_dict = cast(dict[str, Any], item) - record_id = _resolve_file_record_id(item_dict) - if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - item_dict["tool_file_id"] = record_id - elif item_dict["transfer_method"] in [ - FileTransferMethod.LOCAL_FILE, - FileTransferMethod.REMOTE_URL, - ]: - item_dict["upload_file_id"] = record_id - tenant_id = cast(str, item_dict.get("tenant_id", "")) - file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) + file_list.append( + build_file_from_input_mapping( + file_mapping=item_dict, + tenant_resolver=tenant_resolver, + ) + ) inputs[key] = file_list return inputs @@ -1407,22 +1409,23 @@ class Message(Base): @property def inputs(self) -> dict[str, Any]: inputs = self._inputs.copy() + # Compatibility bridge: message inputs are persisted as JSON and must remain + # readable across file payload shape changes. Do not assume `tenant_id` + # is serialized into each file mapping going forward. + tenant_resolver = _build_app_tenant_resolver( + app_id=self.app_id, + owner_tenant_id=cast(str | None, getattr(self, "_owner_tenant_id", None)), + ) for key, value in inputs.items(): - # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. - from factories import file_factory - if ( isinstance(value, dict) and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY ): value_dict = cast(dict[str, Any], value) - record_id = _resolve_file_record_id(value_dict) - if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - value_dict["tool_file_id"] = record_id - elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: - value_dict["upload_file_id"] = record_id - tenant_id = cast(str, value_dict.get("tenant_id", "")) - inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) + inputs[key] = build_file_from_input_mapping( + file_mapping=value_dict, + tenant_resolver=tenant_resolver, + ) elif isinstance(value, list): value_list = cast(list[Any], value) if all( @@ -1435,16 +1438,12 @@ class Message(Base): if not isinstance(item, dict): continue item_dict = cast(dict[str, Any], item) - record_id = _resolve_file_record_id(item_dict) - if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - item_dict["tool_file_id"] = record_id - elif item_dict["transfer_method"] in [ - FileTransferMethod.LOCAL_FILE, - FileTransferMethod.REMOTE_URL, - ]: - item_dict["upload_file_id"] = record_id - tenant_id = cast(str, item_dict.get("tenant_id", "")) - file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) + file_list.append( + build_file_from_input_mapping( + file_mapping=item_dict, + tenant_resolver=tenant_resolver, + ) + ) inputs[key] = file_list return inputs diff --git a/api/models/utils/__init__.py b/api/models/utils/__init__.py new file mode 100644 index 0000000000..b390b8106b --- /dev/null +++ b/api/models/utils/__init__.py @@ -0,0 +1,3 @@ +from .file_input_compat import build_file_from_input_mapping + +__all__ = ["build_file_from_input_mapping"] diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py new file mode 100644 index 0000000000..c50f8c898a --- /dev/null +++ b/api/models/utils/file_input_compat.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from collections.abc import Callable, Mapping +from typing import Any, cast + +from core.workflow.file_reference import parse_file_reference +from dify_graph.file import File, FileTransferMethod + + +def resolve_file_record_id(file_mapping: Mapping[str, Any]) -> str | None: + reference = file_mapping.get("reference") + if isinstance(reference, str) and reference: + parsed_reference = parse_file_reference(reference) + if parsed_reference is not None: + return parsed_reference.record_id + + related_id = file_mapping.get("related_id") + if isinstance(related_id, str) and related_id: + parsed_reference = parse_file_reference(related_id) + if parsed_reference is not None: + return parsed_reference.record_id + + return None + + +def resolve_file_mapping_tenant_id( + *, + file_mapping: Mapping[str, Any], + tenant_resolver: Callable[[], str], +) -> str: + tenant_id = file_mapping.get("tenant_id") + if isinstance(tenant_id, str) and tenant_id: + return tenant_id + + return tenant_resolver() + + +def build_file_from_input_mapping( + *, + file_mapping: Mapping[str, Any], + tenant_resolver: Callable[[], str], +) -> File: + """ + Rehydrate persisted model input payloads into graph `File` objects. + + This compatibility layer exists because model JSON rows can outlive file payload + schema changes. Legacy rows may carry `related_id` and `tenant_id`, while newer + rows may only carry `reference`. Keep ownership resolution here, at the model + boundary, instead of pushing tenant data back into `dify_graph.file.File`. + """ + + # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. + from factories import file_factory + + mapping = dict(file_mapping) + record_id = resolve_file_record_id(mapping) + + if mapping["transfer_method"] == FileTransferMethod.TOOL_FILE: + mapping["tool_file_id"] = record_id + elif mapping["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: + mapping["upload_file_id"] = record_id + + tenant_id = resolve_file_mapping_tenant_id(file_mapping=mapping, tenant_resolver=tenant_resolver) + return cast(File, file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id)) diff --git a/api/tests/unit_tests/models/test_model.py b/api/tests/unit_tests/models/test_model.py index 1a2003a9cf..d4bf6499f6 100644 --- a/api/tests/unit_tests/models/test_model.py +++ b/api/tests/unit_tests/models/test_model.py @@ -3,7 +3,9 @@ import types import pytest -from models.model import Message +from core.workflow.file_reference import build_file_reference +from dify_graph.file import FILE_MODEL_IDENTITY, FileTransferMethod +from models.model import Conversation, Message @pytest.fixture(autouse=True) @@ -81,3 +83,131 @@ def test_image_preview_misspelled_not_replaced(): out = msg.re_sign_file_url_answer # Expect NO replacement, should not rewrite misspelled image-previe URL assert out == original + + +def _build_local_file_mapping(record_id: str, *, tenant_id: str | None = None) -> dict[str, object]: + mapping: dict[str, object] = { + "dify_model_identity": FILE_MODEL_IDENTITY, + "transfer_method": FileTransferMethod.LOCAL_FILE, + "reference": build_file_reference(record_id=record_id), + "type": "document", + "filename": "example.txt", + "extension": ".txt", + "mime_type": "text/plain", + "size": 1, + } + if tenant_id is not None: + mapping["tenant_id"] = tenant_id + return mapping + + +@pytest.mark.parametrize("owner_cls", [Conversation, Message]) +def test_inputs_resolve_owner_tenant_for_single_file_mapping( + monkeypatch: pytest.MonkeyPatch, + owner_cls: type[Conversation] | type[Message], +): + model_module = importlib.import_module("models.model") + build_calls: list[tuple[dict[str, object], str]] = [] + + monkeypatch.setattr(model_module.db.session, "scalar", lambda _: "tenant-from-app") + + def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False): + _ = config, strict_type_validation + build_calls.append((dict(mapping), tenant_id)) + return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")} + + monkeypatch.setattr("factories.file_factory.build_from_mapping", fake_build_from_mapping) + + owner = owner_cls(app_id="app-1") + owner.inputs = {"file": _build_local_file_mapping("upload-1")} + + restored_inputs = owner.inputs + + assert restored_inputs["file"] == {"tenant_id": "tenant-from-app", "upload_file_id": "upload-1"} + assert build_calls == [ + ( + { + **_build_local_file_mapping("upload-1"), + "upload_file_id": "upload-1", + }, + "tenant-from-app", + ) + ] + + +@pytest.mark.parametrize("owner_cls", [Conversation, Message]) +def test_inputs_resolve_owner_tenant_for_file_list_mapping( + monkeypatch: pytest.MonkeyPatch, + owner_cls: type[Conversation] | type[Message], +): + model_module = importlib.import_module("models.model") + build_calls: list[tuple[dict[str, object], str]] = [] + + monkeypatch.setattr(model_module.db.session, "scalar", lambda _: "tenant-from-app") + + def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False): + _ = config, strict_type_validation + build_calls.append((dict(mapping), tenant_id)) + return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")} + + monkeypatch.setattr("factories.file_factory.build_from_mapping", fake_build_from_mapping) + + owner = owner_cls(app_id="app-1") + owner.inputs = { + "files": [ + _build_local_file_mapping("upload-1"), + _build_local_file_mapping("upload-2"), + ] + } + + restored_inputs = owner.inputs + + assert restored_inputs["files"] == [ + {"tenant_id": "tenant-from-app", "upload_file_id": "upload-1"}, + {"tenant_id": "tenant-from-app", "upload_file_id": "upload-2"}, + ] + assert build_calls == [ + ( + { + **_build_local_file_mapping("upload-1"), + "upload_file_id": "upload-1", + }, + "tenant-from-app", + ), + ( + { + **_build_local_file_mapping("upload-2"), + "upload_file_id": "upload-2", + }, + "tenant-from-app", + ), + ] + + +@pytest.mark.parametrize("owner_cls", [Conversation, Message]) +def test_inputs_prefer_serialized_tenant_id_when_present( + monkeypatch: pytest.MonkeyPatch, + owner_cls: type[Conversation] | type[Message], +): + model_module = importlib.import_module("models.model") + + def fail_if_called(_): + raise AssertionError("App tenant lookup should not run when tenant_id exists in the file mapping") + + monkeypatch.setattr(model_module.db.session, "scalar", fail_if_called) + + def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False): + _ = config, strict_type_validation + return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")} + + monkeypatch.setattr("factories.file_factory.build_from_mapping", fake_build_from_mapping) + + owner = owner_cls(app_id="app-1") + owner.inputs = {"file": _build_local_file_mapping("upload-1", tenant_id="tenant-from-payload")} + + restored_inputs = owner.inputs + + assert restored_inputs["file"] == { + "tenant_id": "tenant-from-payload", + "upload_file_id": "upload-1", + }