From 0d500e696508d40e51557414807a0e4e14cd7394 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 May 2026 21:09:25 +0800 Subject: [PATCH] fix(api): allow LLM nodes to access retrieved knowledge files (#36175) --- api/core/app/file_access/__init__.py | 12 +- api/core/app/file_access/controller.py | 16 +- api/core/app/file_access/scope.py | 60 ++++++- api/core/rag/datasource/retrieval_service.py | 5 + api/core/rag/retrieval/dataset_retrieval.py | 5 + api/core/workflow/node_factory.py | 39 ++++- api/core/workflow/node_runtime.py | 22 ++- .../test_retrieval_attachment_access.py | 164 ++++++++++++++++++ .../core/workflow/test_node_factory.py | 126 +++++++++++++- .../core/workflow/test_node_runtime.py | 110 ++++++++++++ 10 files changed, 546 insertions(+), 13 deletions(-) create mode 100644 api/tests/unit_tests/core/rag/datasource/test_retrieval_attachment_access.py diff --git a/api/core/app/file_access/__init__.py b/api/core/app/file_access/__init__.py index a75ab9781b..e02aba102b 100644 --- a/api/core/app/file_access/__init__.py +++ b/api/core/app/file_access/__init__.py @@ -1,6 +1,13 @@ from .controller import DatabaseFileAccessController from .protocols import FileAccessControllerProtocol -from .scope import FileAccessScope, bind_file_access_scope, get_current_file_access_scope +from .scope import ( + FileAccessScope, + bind_file_access_scope, + get_current_file_access_scope, + grant_retriever_segment_access, + grant_upload_file_access, + is_retriever_segment_access_granted, +) __all__ = [ "DatabaseFileAccessController", @@ -8,4 +15,7 @@ __all__ = [ "FileAccessScope", "bind_file_access_scope", "get_current_file_access_scope", + "grant_retriever_segment_access", + "grant_upload_file_access", + "is_retriever_segment_access_granted", ] diff --git a/api/core/app/file_access/controller.py b/api/core/app/file_access/controller.py index 300c187083..a6c6e74f06 100644 --- a/api/core/app/file_access/controller.py +++ b/api/core/app/file_access/controller.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections.abc import Callable -from sqlalchemy import select +from sqlalchemy import and_, or_, select from sqlalchemy.orm import Session from sqlalchemy.sql import Select @@ -18,7 +18,8 @@ class DatabaseFileAccessController(FileAccessControllerProtocol): Tenant scoping remains mandatory. When the current execution belongs to an end user, the lookup is additionally constrained to that end user's file - ownership markers. + ownership markers, plus upload files explicitly granted by the current + execution context. """ _scope_getter: Callable[[], FileAccessScope | None] @@ -47,10 +48,19 @@ class DatabaseFileAccessController(FileAccessControllerProtocol): if not resolved_scope.requires_user_ownership: return scoped_stmt - return scoped_stmt.where( + user_owned_filter = and_( UploadFile.created_by_role == CreatorUserRole.END_USER, UploadFile.created_by == resolved_scope.user_id, ) + if not resolved_scope.granted_upload_file_ids: + return scoped_stmt.where(user_owned_filter) + + return scoped_stmt.where( + or_( + user_owned_filter, + UploadFile.id.in_(resolved_scope.granted_upload_file_ids), + ) + ) def apply_tool_file_filters( self, diff --git a/api/core/app/file_access/scope.py b/api/core/app/file_access/scope.py index a583301f9b..12fe7b3840 100644 --- a/api/core/app/file_access/scope.py +++ b/api/core/app/file_access/scope.py @@ -1,9 +1,9 @@ from __future__ import annotations -from collections.abc import Generator # Changed from Iterator +from collections.abc import Generator, Iterable from contextlib import contextmanager from contextvars import ContextVar -from dataclasses import dataclass +from dataclasses import dataclass, field, replace from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom @@ -15,12 +15,23 @@ _current_file_access_scope: ContextVar[FileAccessScope | None] = ContextVar( @dataclass(frozen=True, slots=True) class FileAccessScope: - """Request-scoped ownership context used by workflow-layer file lookups.""" + """Request-scoped ownership context used by workflow-layer file lookups. + + ``granted_upload_file_ids`` is execution-local: callers may add upload files + that were returned by trusted retrieval paths without changing persistent + ownership markers. + + ``granted_retriever_segment_ids`` gates lazy attachment loading by segment + ID, so user-provided context cannot make a later LLM node load arbitrary + same-tenant knowledge attachments. + """ tenant_id: str user_id: str user_from: UserFrom invoke_from: InvokeFrom + granted_upload_file_ids: frozenset[str] = field(default_factory=frozenset) + granted_retriever_segment_ids: frozenset[str] = field(default_factory=frozenset) @property def requires_user_ownership(self) -> bool: @@ -31,8 +42,49 @@ def get_current_file_access_scope() -> FileAccessScope | None: return _current_file_access_scope.get() +def grant_upload_file_access(upload_file_ids: Iterable[str]) -> None: + scope = _current_file_access_scope.get() + if scope is None: + return + + granted_upload_file_ids = frozenset(str(file_id) for file_id in upload_file_ids if file_id) + if not granted_upload_file_ids: + return + + _current_file_access_scope.set( + replace( + scope, + granted_upload_file_ids=scope.granted_upload_file_ids | granted_upload_file_ids, + ) + ) + + +def grant_retriever_segment_access(segment_ids: Iterable[str]) -> None: + scope = _current_file_access_scope.get() + if scope is None: + return + + granted_segment_ids = frozenset(str(segment_id) for segment_id in segment_ids if segment_id) + if not granted_segment_ids: + return + + _current_file_access_scope.set( + replace( + scope, + granted_retriever_segment_ids=scope.granted_retriever_segment_ids | granted_segment_ids, + ) + ) + + +def is_retriever_segment_access_granted(segment_id: str) -> bool: + scope = _current_file_access_scope.get() + if scope is None or not scope.requires_user_ownership: + return True + return str(segment_id) in scope.granted_retriever_segment_ids + + @contextmanager -def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]: # Changed from Iterator[None] +def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]: token = _current_file_access_scope.set(scope) try: yield diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 8cc2be8feb..904d5c843f 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -9,6 +9,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session, load_only from configs import dify_config +from core.app.file_access import grant_upload_file_access from core.db.session_factory import session_factory from core.model_manager import ModelManager from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict @@ -890,6 +891,7 @@ class RetrievalService: .limit(1) ) if attachment_binding: + grant_upload_file_access([str(upload_file.id)]) attachment_info: AttachmentInfoDict = { "id": upload_file.id, "name": upload_file.name, @@ -906,6 +908,7 @@ class RetrievalService: cls, attachment_ids: list[str], session: Session ) -> list[SegmentAttachmentInfoResult]: attachment_infos: list[SegmentAttachmentInfoResult] = [] + granted_upload_file_ids: list[str] = [] upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(attachment_ids))).all() if upload_files: upload_file_ids = [upload_file.id for upload_file in upload_files] @@ -926,6 +929,7 @@ class RetrievalService: "size": upload_file.size, } if attachment_binding: + granted_upload_file_ids.append(str(upload_file.id)) attachment_infos.append( { "attachment_id": attachment_binding.attachment_id, @@ -933,4 +937,5 @@ class RetrievalService: "segment_id": attachment_binding.segment_id, } ) + grant_upload_file_access(granted_upload_file_ids) return attachment_infos diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 010566d203..039a266f44 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -19,6 +19,7 @@ from core.app.app_config.entities import ( ModelConfig, ) from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.app.file_access import grant_retriever_segment_access, grant_upload_file_access from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.db.session_factory import session_factory from core.entities.agent_entities import PlanningStrategy @@ -326,6 +327,7 @@ class DatasetRetrieval: if record.summary: source.summary = record.summary + grant_retriever_segment_access([str(segment.id)]) retrieval_resource_list.append(source) if retrieval_resource_list: @@ -515,6 +517,9 @@ class DatasetRetrieval: ) ).all() if attachments_with_bindings: + grant_upload_file_access( + str(upload_file.id) for _, upload_file in attachments_with_bindings + ) for _, upload_file in attachments_with_bindings: attachment_info = File( file_id=upload_file.id, diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index a306b1c9ac..3fcf250c6e 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -1,6 +1,6 @@ import importlib import pkgutil -from collections.abc import Callable, Iterator, Mapping, MutableMapping +from collections.abc import Callable, Iterator, Mapping, MutableMapping, Sequence from dataclasses import dataclass from functools import lru_cache from typing import TYPE_CHECKING, Any, cast, final, override @@ -56,6 +56,7 @@ from graphon.nodes.http_request import build_http_request_config from graphon.nodes.llm.entities import LLMNodeData from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.variables.segments import ArrayObjectSegment from models.model import Conversation if TYPE_CHECKING: @@ -496,13 +497,47 @@ class DifyNodeFactory(NodeFactory): if include_prompt_message_serializer: node_init_kwargs["prompt_message_serializer"] = self._prompt_message_serializer if include_retriever_attachment_loader: - node_init_kwargs["retriever_attachment_loader"] = self._retriever_attachment_loader + node_init_kwargs["retriever_attachment_loader"] = self._build_retriever_attachment_loader( + cast(LLMNodeData, validated_node_data) + ) if include_jinja2_template_renderer: node_init_kwargs["jinja2_template_renderer"] = self._jinja2_template_renderer if validated_node_data.type == BuiltinNodeTypes.LLM: node_init_kwargs["default_query_selector"] = system_variable_selector(SystemVariableKey.QUERY) return node_init_kwargs + def _build_retriever_attachment_loader(self, node_data: LLMNodeData) -> DifyRetrieverAttachmentLoader: + return DifyRetrieverAttachmentLoader( + file_reference_factory=self._file_reference_factory, + segment_access_checker=self._build_retriever_segment_access_checker( + node_data.context.variable_selector if node_data.context.enabled else None + ), + ) + + def _build_retriever_segment_access_checker( + self, + context_variable_selector: Sequence[str] | None, + ) -> Callable[[str], bool]: + def checker(segment_id: str) -> bool: + if not context_variable_selector: + return False + + context_value = self.graph_runtime_state.variable_pool.get(context_variable_selector) + if not isinstance(context_value, ArrayObjectSegment): + return False + + for item in context_value.value: + if not isinstance(item, Mapping): + continue + metadata = item.get("metadata") + if not isinstance(metadata, Mapping): + continue + if metadata.get("_source") == "knowledge" and str(metadata.get("segment_id")) == str(segment_id): + return True + return False + + return checker + def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance: node_data_model = node_data.model model_instance, _ = fetch_model_config( diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py index db7d78bf45..d5fa3a28a6 100644 --- a/api/core/workflow/node_runtime.py +++ b/api/core/workflow/node_runtime.py @@ -8,7 +8,11 @@ from sqlalchemy import select from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.app.file_access import DatabaseFileAccessController +from core.app.file_access import ( + DatabaseFileAccessController, + grant_upload_file_access, + is_retriever_segment_access_granted, +) from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.helper.trace_id_helper import ParentTraceContext from core.llm_generator.output_parser.errors import OutputParserError @@ -275,10 +279,23 @@ class DifyPromptMessageSerializer(PromptMessageSerializerProtocol): class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol): """Resolve retriever attachments through Dify persistence and return graph file references.""" - def __init__(self, *, file_reference_factory: FileReferenceFactoryProtocol) -> None: + _segment_access_checker: Callable[[str], bool] | None + + def __init__( + self, + *, + file_reference_factory: FileReferenceFactoryProtocol, + segment_access_checker: Callable[[str], bool] | None = None, + ) -> None: self._file_reference_factory = file_reference_factory + self._segment_access_checker = segment_access_checker def load(self, *, segment_id: str) -> Sequence[File]: + if not is_retriever_segment_access_granted(segment_id): + return [] + if self._segment_access_checker is not None and not self._segment_access_checker(segment_id): + return [] + with Session(db.engine, expire_on_commit=False) as session: attachments_with_bindings = session.execute( select(SegmentAttachmentBinding, UploadFile) @@ -286,6 +303,7 @@ class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol): .where(SegmentAttachmentBinding.segment_id == segment_id) ).all() + grant_upload_file_access(str(upload_file.id) for _, upload_file in attachments_with_bindings) return [ self._file_reference_factory.build_from_mapping( mapping={ diff --git a/api/tests/unit_tests/core/rag/datasource/test_retrieval_attachment_access.py b/api/tests/unit_tests/core/rag/datasource/test_retrieval_attachment_access.py new file mode 100644 index 0000000000..426ffc498b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/test_retrieval_attachment_access.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +from types import SimpleNamespace +from uuid import uuid4 + +from sqlalchemy import select + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.file_access import ( + DatabaseFileAccessController, + FileAccessScope, + bind_file_access_scope, + get_current_file_access_scope, + grant_retriever_segment_access, + grant_upload_file_access, + is_retriever_segment_access_granted, +) +from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.retrieval import dataset_retrieval as dataset_retrieval_module +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest +from models import UploadFile + + +class _ScalarResult: + def __init__(self, values): + self._values = values + + def all(self): + return self._values + + +class _AttachmentSession: + def __init__(self, upload_file, binding): + self._results = [ + _ScalarResult([upload_file]), + _ScalarResult([binding]), + ] + + def scalars(self, _stmt): + return self._results.pop(0) + + +class _DatasetRetrievalSession: + def __init__(self, datasets, documents): + self._results = [ + _ScalarResult(datasets), + _ScalarResult(documents), + ] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, _stmt): + return self._results.pop(0) + + +def test_file_access_grants_ignore_empty_inputs_and_missing_scope() -> None: + grant_upload_file_access(["upload-file-id"]) + grant_retriever_segment_access(["segment-id"]) + assert is_retriever_segment_access_granted("segment-id") is True + + scope = FileAccessScope( + tenant_id=str(uuid4()), + user_id=str(uuid4()), + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + + with bind_file_access_scope(scope): + grant_upload_file_access([""]) + grant_retriever_segment_access([""]) + current_scope = get_current_file_access_scope() + + assert current_scope is not None + assert current_scope.granted_upload_file_ids == frozenset() + assert current_scope.granted_retriever_segment_ids == frozenset() + + +def test_segment_attachment_lookup_grants_returned_upload_files_to_current_scope() -> None: + tenant_id = str(uuid4()) + upload_file_id = str(uuid4()) + segment_id = str(uuid4()) + upload_file = SimpleNamespace( + id=upload_file_id, + name="chart.png", + extension="png", + mime_type="image/png", + size=1024, + ) + binding = SimpleNamespace(attachment_id=upload_file_id, segment_id=segment_id) + session = _AttachmentSession(upload_file, binding) + scope = FileAccessScope( + tenant_id=tenant_id, + user_id=str(uuid4()), + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + + with bind_file_access_scope(scope): + result = RetrievalService.get_segment_attachment_infos([upload_file_id], session) # type: ignore[arg-type] + scoped_stmt = DatabaseFileAccessController().apply_upload_file_filters( + select(UploadFile).where(UploadFile.id == upload_file_id) + ) + + assert result[0]["attachment_id"] == upload_file_id + whereclause = str(scoped_stmt.whereclause) + assert "upload_files.created_by_role" in whereclause + assert "upload_files.id IN" in whereclause + + +def test_knowledge_retrieval_grants_returned_segments_to_current_scope(monkeypatch) -> None: + tenant_id = str(uuid4()) + dataset_id = str(uuid4()) + document_id = str(uuid4()) + segment_id = str(uuid4()) + segment = SimpleNamespace( + id=segment_id, + dataset_id=dataset_id, + document_id=document_id, + hit_count=1, + word_count=10, + position=1, + index_node_hash="hash", + answer=None, + get_sign_content=lambda: "segment content", + ) + record = SimpleNamespace(segment=segment, score=0.8, child_chunks=None, files=None, summary=None) + dataset = SimpleNamespace(id=dataset_id, name="Dataset") + document = SimpleNamespace(id=document_id, name="Document", data_source_type="upload_file", doc_metadata={}) + retrieval = DatasetRetrieval() + monkeypatch.setattr(retrieval, "_check_knowledge_rate_limit", lambda tenant_id: None) + monkeypatch.setattr(retrieval, "_get_available_datasets", lambda tenant_id, dataset_ids: [dataset]) + monkeypatch.setattr(retrieval, "multiple_retrieve", lambda **kwargs: [SimpleNamespace(provider="dify")]) + monkeypatch.setattr(RetrievalService, "format_retrieval_documents", lambda documents: [record]) + session = _DatasetRetrievalSession([dataset], [document]) + monkeypatch.setattr(dataset_retrieval_module.session_factory, "create_session", lambda: session) + scope = FileAccessScope( + tenant_id=tenant_id, + user_id=str(uuid4()), + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + + with bind_file_access_scope(scope): + results = retrieval.knowledge_retrieval( + KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=str(uuid4()), + app_id=str(uuid4()), + user_from=UserFrom.END_USER.value, + dataset_ids=[dataset_id], + query="desktop picture", + retrieval_mode="multiple", + ) + ) + current_scope = get_current_file_access_scope() + + assert results[0].metadata.segment_id == segment_id + assert current_scope is not None + assert segment_id in current_scope.granted_retriever_segment_ids diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py index d6159e84d4..9a7fcf166f 100644 --- a/api/tests/unit_tests/core/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/workflow/test_node_factory.py @@ -13,7 +13,8 @@ from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.code.entities import CodeLanguage from graphon.nodes.llm.entities import LLMNodeData from graphon.nodes.llm.node import LLMNode -from graphon.variables.segments import StringSegment +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.variables.segments import ArrayObjectSegment, StringSegment def _assert_constructor_node_data(data, *, node_id: str, node_type: NodeType, version: str = "1") -> None: @@ -430,6 +431,7 @@ class TestDifyNodeFactoryCreateNode: factory._http_request_config = sentinel.http_request_config factory._llm_credentials_provider = sentinel.credentials_provider factory._llm_model_factory = sentinel.model_factory + factory._build_retriever_attachment_loader = MagicMock(return_value=sentinel.retriever_attachment_loader) return factory def test_rejects_unknown_node_type(self, factory): @@ -770,6 +772,128 @@ class TestDifyNodeFactoryCreateNode: for key, value in expected_extra_kwargs.items(): assert constructor_kwargs[key] is value + def test_parameter_extractor_init_does_not_require_retriever_context(self, factory): + node_data = ParameterExtractorNodeData.model_validate( + { + "type": BuiltinNodeTypes.PARAMETER_EXTRACTOR, + "title": "Parameter Extractor", + "model": {"provider": "provider", "name": "model", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [ + { + "name": "topic", + "type": "string", + "description": "Topic", + "required": True, + } + ], + "reasoning_mode": "prompt", + } + ) + factory._build_model_instance_for_llm_node = MagicMock(return_value=sentinel.model_instance) + factory._build_memory_for_llm_node = MagicMock(return_value=sentinel.memory) + factory._build_retriever_attachment_loader = MagicMock(side_effect=AssertionError("unexpected loader build")) + + kwargs = factory._build_llm_compatible_node_init_kwargs( + node_class=sentinel.node_class, + node_data=node_data, + wrap_model_instance=True, + include_http_client=False, + include_llm_file_saver=False, + include_prompt_message_serializer=True, + include_retriever_attachment_loader=False, + include_jinja2_template_renderer=False, + ) + + assert "retriever_attachment_loader" not in kwargs + assert kwargs["prompt_message_serializer"] is sentinel.prompt_message_serializer + factory._build_retriever_attachment_loader.assert_not_called() + + +class TestDifyNodeFactoryRetrieverAttachmentAccess: + @pytest.fixture + def factory(self): + factory = object.__new__(node_factory.DifyNodeFactory) + factory.graph_runtime_state = SimpleNamespace(variable_pool=MagicMock()) + return factory + + def test_retriever_attachment_loader_is_typed_for_llm_node_data_only(self): + annotations = node_factory.DifyNodeFactory._build_retriever_attachment_loader.__annotations__ + + assert annotations["node_data"] is LLMNodeData + + def test_build_retriever_attachment_loader_uses_llm_context_selector(self, factory): + factory._file_reference_factory = sentinel.file_reference_factory + factory.graph_runtime_state.variable_pool.get.return_value = ArrayObjectSegment( + value=[ + { + "metadata": { + "_source": "knowledge", + "segment_id": "allowed-segment", + } + } + ] + ) + node_data = LLMNodeData.model_validate( + { + "type": BuiltinNodeTypes.LLM, + "title": "LLM", + "model": {"provider": "provider", "name": "model", "mode": "chat", "completion_params": {}}, + "prompt_template": [{"role": "system", "text": "x"}], + "context": {"enabled": True, "variable_selector": ["knowledge-node", "result"]}, + "vision": {"enabled": False}, + } + ) + + loader = factory._build_retriever_attachment_loader(node_data) + + assert loader._segment_access_checker is not None + assert loader._segment_access_checker("allowed-segment") is True + factory.graph_runtime_state.variable_pool.get.assert_called_once_with(["knowledge-node", "result"]) + + def test_checker_rejects_missing_context_selector_without_reading_variable_pool(self, factory): + checker = factory._build_retriever_segment_access_checker(None) + + assert checker("segment-id") is False + factory.graph_runtime_state.variable_pool.get.assert_not_called() + + def test_checker_rejects_non_knowledge_context_items(self, factory): + factory.graph_runtime_state.variable_pool.get.return_value = ArrayObjectSegment.model_construct( + value=[ + "plain-text", + {"metadata": "not-a-mapping"}, + ] + ) + + checker = factory._build_retriever_segment_access_checker(["knowledge-node", "result"]) + + assert checker("segment-id") is False + + def test_checker_rejects_non_array_context_value(self, factory): + factory.graph_runtime_state.variable_pool.get.return_value = StringSegment(value="not knowledge context") + + checker = factory._build_retriever_segment_access_checker(["knowledge-node", "result"]) + + assert checker("segment-id") is False + + def test_checker_allows_only_segments_from_selected_knowledge_context(self, factory): + factory.graph_runtime_state.variable_pool.get.return_value = ArrayObjectSegment( + value=[ + { + "metadata": { + "_source": "knowledge", + "segment_id": "allowed-segment", + } + } + ] + ) + + checker = factory._build_retriever_segment_access_checker(["knowledge-node", "result"]) + + assert checker("allowed-segment") is True + assert checker("other-segment") is False + factory.graph_runtime_state.variable_pool.get.assert_any_call(["knowledge-node", "result"]) + class TestDifyNodeFactoryModelInstance: @pytest.fixture diff --git a/api/tests/unit_tests/core/workflow/test_node_runtime.py b/api/tests/unit_tests/core/workflow/test_node_runtime.py index d2925fd1a8..5e83863dc2 100644 --- a/api/tests/unit_tests/core/workflow/test_node_runtime.py +++ b/api/tests/unit_tests/core/workflow/test_node_runtime.py @@ -1,9 +1,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, sentinel +from uuid import uuid4 import pytest from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext, InvokeFrom, UserFrom +from core.app.file_access import FileAccessScope, bind_file_access_scope, grant_retriever_segment_access from core.llm_generator.output_parser.errors import OutputParserError from core.workflow import node_runtime from core.workflow.file_reference import parse_file_reference @@ -268,6 +270,114 @@ def test_dify_retriever_attachment_loader_builds_graph_files(monkeypatch: pytest assert parse_file_reference(mapping["reference"]).storage_key is None +def test_dify_retriever_attachment_loader_grants_upload_files_for_allowed_segment( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from factories.file_factory import builders as file_builders + + upload_file_id = str(uuid4()) + segment_id = str(uuid4()) + upload_file = SimpleNamespace( + id=upload_file_id, + tenant_id="tenant-id", + name="diagram.png", + extension="png", + mime_type="image/png", + source_url="https://example.com/diagram.png", + key="storage-key", + size=128, + ) + attachment_session = MagicMock() + attachment_session.execute.return_value.all.return_value = [(None, upload_file)] + + class _AttachmentSessionContext: + def __enter__(self): + return attachment_session + + def __exit__(self, exc_type, exc, tb): + return False + + upload_session = MagicMock() + upload_session.__enter__.return_value = upload_session + upload_session.__exit__.return_value = False + upload_session.scalar.return_value = upload_file + + monkeypatch.setattr(node_runtime, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr(node_runtime, "Session", MagicMock(return_value=_AttachmentSessionContext())) + monkeypatch.setattr(file_builders, "session_factory", SimpleNamespace(create_session=lambda: upload_session)) + + loader = DifyRetrieverAttachmentLoader(file_reference_factory=DifyFileReferenceFactory(_build_run_context())) + scope = FileAccessScope( + tenant_id="tenant-id", + user_id="end-user-id", + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + + with bind_file_access_scope(scope): + grant_retriever_segment_access([segment_id]) + files = loader.load(segment_id=segment_id) + + assert files[0].related_id == upload_file_id + stmt = upload_session.scalar.call_args.args[0] + whereclause = str(stmt.whereclause) + assert "upload_files.tenant_id" in whereclause + assert "upload_files.id IN" in whereclause + + +def test_dify_retriever_attachment_loader_skips_ungranted_segment_for_end_user( + monkeypatch: pytest.MonkeyPatch, +) -> None: + build_from_mapping = MagicMock() + session_factory = MagicMock() + monkeypatch.setattr(node_runtime, "Session", session_factory) + loader = DifyRetrieverAttachmentLoader( + file_reference_factory=SimpleNamespace(build_from_mapping=build_from_mapping) + ) + scope = FileAccessScope( + tenant_id="tenant-id", + user_id="end-user-id", + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + + with bind_file_access_scope(scope): + files = loader.load(segment_id=str(uuid4())) + + assert files == [] + session_factory.assert_not_called() + build_from_mapping.assert_not_called() + + +def test_dify_retriever_attachment_loader_skips_segment_rejected_by_checker( + monkeypatch: pytest.MonkeyPatch, +) -> None: + segment_id = str(uuid4()) + build_from_mapping = MagicMock() + session_factory = MagicMock() + segment_access_checker = MagicMock(return_value=False) + monkeypatch.setattr(node_runtime, "Session", session_factory) + loader = DifyRetrieverAttachmentLoader( + file_reference_factory=SimpleNamespace(build_from_mapping=build_from_mapping), + segment_access_checker=segment_access_checker, + ) + scope = FileAccessScope( + tenant_id="tenant-id", + user_id="end-user-id", + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + + with bind_file_access_scope(scope): + grant_retriever_segment_access([segment_id]) + files = loader.load(segment_id=segment_id) + + assert files == [] + segment_access_checker.assert_called_once_with(segment_id) + session_factory.assert_not_called() + build_from_mapping.assert_not_called() + + def test_dify_tool_file_manager_resolves_conversation_id_for_tool_files(monkeypatch: pytest.MonkeyPatch) -> None: create_file_by_raw = MagicMock(return_value=SimpleNamespace(id="tool-file-id")) manager_instance = SimpleNamespace(create_file_by_raw=create_file_by_raw)