fix(api): allow LLM nodes to access retrieved knowledge files (#36175)

This commit is contained in:
-LAN-
2026-05-14 21:09:25 +08:00
committed by GitHub
parent 5798610f27
commit 0d500e6965
10 changed files with 546 additions and 13 deletions

View File

@@ -1,6 +1,13 @@
from .controller import DatabaseFileAccessController
from .protocols import FileAccessControllerProtocol
from .scope import FileAccessScope, bind_file_access_scope, get_current_file_access_scope
from .scope import (
FileAccessScope,
bind_file_access_scope,
get_current_file_access_scope,
grant_retriever_segment_access,
grant_upload_file_access,
is_retriever_segment_access_granted,
)
__all__ = [
"DatabaseFileAccessController",
@@ -8,4 +15,7 @@ __all__ = [
"FileAccessScope",
"bind_file_access_scope",
"get_current_file_access_scope",
"grant_retriever_segment_access",
"grant_upload_file_access",
"is_retriever_segment_access_granted",
]

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from collections.abc import Callable
from sqlalchemy import select
from sqlalchemy import and_, or_, select
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
@@ -18,7 +18,8 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
Tenant scoping remains mandatory. When the current execution belongs to an
end user, the lookup is additionally constrained to that end user's file
ownership markers.
ownership markers, plus upload files explicitly granted by the current
execution context.
"""
_scope_getter: Callable[[], FileAccessScope | None]
@@ -47,10 +48,19 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
if not resolved_scope.requires_user_ownership:
return scoped_stmt
return scoped_stmt.where(
user_owned_filter = and_(
UploadFile.created_by_role == CreatorUserRole.END_USER,
UploadFile.created_by == resolved_scope.user_id,
)
if not resolved_scope.granted_upload_file_ids:
return scoped_stmt.where(user_owned_filter)
return scoped_stmt.where(
or_(
user_owned_filter,
UploadFile.id.in_(resolved_scope.granted_upload_file_ids),
)
)
def apply_tool_file_filters(
self,

View File

@@ -1,9 +1,9 @@
from __future__ import annotations
from collections.abc import Generator # Changed from Iterator
from collections.abc import Generator, Iterable
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from dataclasses import dataclass, field, replace
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
@@ -15,12 +15,23 @@ _current_file_access_scope: ContextVar[FileAccessScope | None] = ContextVar(
@dataclass(frozen=True, slots=True)
class FileAccessScope:
"""Request-scoped ownership context used by workflow-layer file lookups."""
"""Request-scoped ownership context used by workflow-layer file lookups.
``granted_upload_file_ids`` is execution-local: callers may add upload files
that were returned by trusted retrieval paths without changing persistent
ownership markers.
``granted_retriever_segment_ids`` gates lazy attachment loading by segment
ID, so user-provided context cannot make a later LLM node load arbitrary
same-tenant knowledge attachments.
"""
tenant_id: str
user_id: str
user_from: UserFrom
invoke_from: InvokeFrom
granted_upload_file_ids: frozenset[str] = field(default_factory=frozenset)
granted_retriever_segment_ids: frozenset[str] = field(default_factory=frozenset)
@property
def requires_user_ownership(self) -> bool:
@@ -31,8 +42,49 @@ def get_current_file_access_scope() -> FileAccessScope | None:
return _current_file_access_scope.get()
def grant_upload_file_access(upload_file_ids: Iterable[str]) -> None:
scope = _current_file_access_scope.get()
if scope is None:
return
granted_upload_file_ids = frozenset(str(file_id) for file_id in upload_file_ids if file_id)
if not granted_upload_file_ids:
return
_current_file_access_scope.set(
replace(
scope,
granted_upload_file_ids=scope.granted_upload_file_ids | granted_upload_file_ids,
)
)
def grant_retriever_segment_access(segment_ids: Iterable[str]) -> None:
scope = _current_file_access_scope.get()
if scope is None:
return
granted_segment_ids = frozenset(str(segment_id) for segment_id in segment_ids if segment_id)
if not granted_segment_ids:
return
_current_file_access_scope.set(
replace(
scope,
granted_retriever_segment_ids=scope.granted_retriever_segment_ids | granted_segment_ids,
)
)
def is_retriever_segment_access_granted(segment_id: str) -> bool:
scope = _current_file_access_scope.get()
if scope is None or not scope.requires_user_ownership:
return True
return str(segment_id) in scope.granted_retriever_segment_ids
@contextmanager
def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]: # Changed from Iterator[None]
def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]:
token = _current_file_access_scope.set(scope)
try:
yield

View File

@@ -9,6 +9,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from configs import dify_config
from core.app.file_access import grant_upload_file_access
from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
@@ -890,6 +891,7 @@ class RetrievalService:
.limit(1)
)
if attachment_binding:
grant_upload_file_access([str(upload_file.id)])
attachment_info: AttachmentInfoDict = {
"id": upload_file.id,
"name": upload_file.name,
@@ -906,6 +908,7 @@ class RetrievalService:
cls, attachment_ids: list[str], session: Session
) -> list[SegmentAttachmentInfoResult]:
attachment_infos: list[SegmentAttachmentInfoResult] = []
granted_upload_file_ids: list[str] = []
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(attachment_ids))).all()
if upload_files:
upload_file_ids = [upload_file.id for upload_file in upload_files]
@@ -926,6 +929,7 @@ class RetrievalService:
"size": upload_file.size,
}
if attachment_binding:
granted_upload_file_ids.append(str(upload_file.id))
attachment_infos.append(
{
"attachment_id": attachment_binding.attachment_id,
@@ -933,4 +937,5 @@ class RetrievalService:
"segment_id": attachment_binding.segment_id,
}
)
grant_upload_file_access(granted_upload_file_ids)
return attachment_infos

View File

@@ -19,6 +19,7 @@ from core.app.app_config.entities import (
ModelConfig,
)
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.app.file_access import grant_retriever_segment_access, grant_upload_file_access
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.db.session_factory import session_factory
from core.entities.agent_entities import PlanningStrategy
@@ -326,6 +327,7 @@ class DatasetRetrieval:
if record.summary:
source.summary = record.summary
grant_retriever_segment_access([str(segment.id)])
retrieval_resource_list.append(source)
if retrieval_resource_list:
@@ -515,6 +517,9 @@ class DatasetRetrieval:
)
).all()
if attachments_with_bindings:
grant_upload_file_access(
str(upload_file.id) for _, upload_file in attachments_with_bindings
)
for _, upload_file in attachments_with_bindings:
attachment_info = File(
file_id=upload_file.id,

View File

@@ -1,6 +1,6 @@
import importlib
import pkgutil
from collections.abc import Callable, Iterator, Mapping, MutableMapping
from collections.abc import Callable, Iterator, Mapping, MutableMapping, Sequence
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Any, cast, final, override
@@ -56,6 +56,7 @@ from graphon.nodes.http_request import build_http_request_config
from graphon.nodes.llm.entities import LLMNodeData
from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
from graphon.variables.segments import ArrayObjectSegment
from models.model import Conversation
if TYPE_CHECKING:
@@ -496,13 +497,47 @@ class DifyNodeFactory(NodeFactory):
if include_prompt_message_serializer:
node_init_kwargs["prompt_message_serializer"] = self._prompt_message_serializer
if include_retriever_attachment_loader:
node_init_kwargs["retriever_attachment_loader"] = self._retriever_attachment_loader
node_init_kwargs["retriever_attachment_loader"] = self._build_retriever_attachment_loader(
cast(LLMNodeData, validated_node_data)
)
if include_jinja2_template_renderer:
node_init_kwargs["jinja2_template_renderer"] = self._jinja2_template_renderer
if validated_node_data.type == BuiltinNodeTypes.LLM:
node_init_kwargs["default_query_selector"] = system_variable_selector(SystemVariableKey.QUERY)
return node_init_kwargs
def _build_retriever_attachment_loader(self, node_data: LLMNodeData) -> DifyRetrieverAttachmentLoader:
return DifyRetrieverAttachmentLoader(
file_reference_factory=self._file_reference_factory,
segment_access_checker=self._build_retriever_segment_access_checker(
node_data.context.variable_selector if node_data.context.enabled else None
),
)
def _build_retriever_segment_access_checker(
self,
context_variable_selector: Sequence[str] | None,
) -> Callable[[str], bool]:
def checker(segment_id: str) -> bool:
if not context_variable_selector:
return False
context_value = self.graph_runtime_state.variable_pool.get(context_variable_selector)
if not isinstance(context_value, ArrayObjectSegment):
return False
for item in context_value.value:
if not isinstance(item, Mapping):
continue
metadata = item.get("metadata")
if not isinstance(metadata, Mapping):
continue
if metadata.get("_source") == "knowledge" and str(metadata.get("segment_id")) == str(segment_id):
return True
return False
return checker
def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance:
node_data_model = node_data.model
model_instance, _ = fetch_model_config(

View File

@@ -8,7 +8,11 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.file_access import DatabaseFileAccessController
from core.app.file_access import (
DatabaseFileAccessController,
grant_upload_file_access,
is_retriever_segment_access_granted,
)
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.helper.trace_id_helper import ParentTraceContext
from core.llm_generator.output_parser.errors import OutputParserError
@@ -275,10 +279,23 @@ class DifyPromptMessageSerializer(PromptMessageSerializerProtocol):
class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol):
"""Resolve retriever attachments through Dify persistence and return graph file references."""
def __init__(self, *, file_reference_factory: FileReferenceFactoryProtocol) -> None:
_segment_access_checker: Callable[[str], bool] | None
def __init__(
self,
*,
file_reference_factory: FileReferenceFactoryProtocol,
segment_access_checker: Callable[[str], bool] | None = None,
) -> None:
self._file_reference_factory = file_reference_factory
self._segment_access_checker = segment_access_checker
def load(self, *, segment_id: str) -> Sequence[File]:
if not is_retriever_segment_access_granted(segment_id):
return []
if self._segment_access_checker is not None and not self._segment_access_checker(segment_id):
return []
with Session(db.engine, expire_on_commit=False) as session:
attachments_with_bindings = session.execute(
select(SegmentAttachmentBinding, UploadFile)
@@ -286,6 +303,7 @@ class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol):
.where(SegmentAttachmentBinding.segment_id == segment_id)
).all()
grant_upload_file_access(str(upload_file.id) for _, upload_file in attachments_with_bindings)
return [
self._file_reference_factory.build_from_mapping(
mapping={

View File

@@ -0,0 +1,164 @@
from __future__ import annotations
from types import SimpleNamespace
from uuid import uuid4
from sqlalchemy import select
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.app.file_access import (
DatabaseFileAccessController,
FileAccessScope,
bind_file_access_scope,
get_current_file_access_scope,
grant_retriever_segment_access,
grant_upload_file_access,
is_retriever_segment_access_granted,
)
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.retrieval import dataset_retrieval as dataset_retrieval_module
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
from models import UploadFile
class _ScalarResult:
def __init__(self, values):
self._values = values
def all(self):
return self._values
class _AttachmentSession:
def __init__(self, upload_file, binding):
self._results = [
_ScalarResult([upload_file]),
_ScalarResult([binding]),
]
def scalars(self, _stmt):
return self._results.pop(0)
class _DatasetRetrievalSession:
def __init__(self, datasets, documents):
self._results = [
_ScalarResult(datasets),
_ScalarResult(documents),
]
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def scalars(self, _stmt):
return self._results.pop(0)
def test_file_access_grants_ignore_empty_inputs_and_missing_scope() -> None:
grant_upload_file_access(["upload-file-id"])
grant_retriever_segment_access(["segment-id"])
assert is_retriever_segment_access_granted("segment-id") is True
scope = FileAccessScope(
tenant_id=str(uuid4()),
user_id=str(uuid4()),
user_from=UserFrom.END_USER,
invoke_from=InvokeFrom.WEB_APP,
)
with bind_file_access_scope(scope):
grant_upload_file_access([""])
grant_retriever_segment_access([""])
current_scope = get_current_file_access_scope()
assert current_scope is not None
assert current_scope.granted_upload_file_ids == frozenset()
assert current_scope.granted_retriever_segment_ids == frozenset()
def test_segment_attachment_lookup_grants_returned_upload_files_to_current_scope() -> None:
tenant_id = str(uuid4())
upload_file_id = str(uuid4())
segment_id = str(uuid4())
upload_file = SimpleNamespace(
id=upload_file_id,
name="chart.png",
extension="png",
mime_type="image/png",
size=1024,
)
binding = SimpleNamespace(attachment_id=upload_file_id, segment_id=segment_id)
session = _AttachmentSession(upload_file, binding)
scope = FileAccessScope(
tenant_id=tenant_id,
user_id=str(uuid4()),
user_from=UserFrom.END_USER,
invoke_from=InvokeFrom.WEB_APP,
)
with bind_file_access_scope(scope):
result = RetrievalService.get_segment_attachment_infos([upload_file_id], session) # type: ignore[arg-type]
scoped_stmt = DatabaseFileAccessController().apply_upload_file_filters(
select(UploadFile).where(UploadFile.id == upload_file_id)
)
assert result[0]["attachment_id"] == upload_file_id
whereclause = str(scoped_stmt.whereclause)
assert "upload_files.created_by_role" in whereclause
assert "upload_files.id IN" in whereclause
def test_knowledge_retrieval_grants_returned_segments_to_current_scope(monkeypatch) -> None:
tenant_id = str(uuid4())
dataset_id = str(uuid4())
document_id = str(uuid4())
segment_id = str(uuid4())
segment = SimpleNamespace(
id=segment_id,
dataset_id=dataset_id,
document_id=document_id,
hit_count=1,
word_count=10,
position=1,
index_node_hash="hash",
answer=None,
get_sign_content=lambda: "segment content",
)
record = SimpleNamespace(segment=segment, score=0.8, child_chunks=None, files=None, summary=None)
dataset = SimpleNamespace(id=dataset_id, name="Dataset")
document = SimpleNamespace(id=document_id, name="Document", data_source_type="upload_file", doc_metadata={})
retrieval = DatasetRetrieval()
monkeypatch.setattr(retrieval, "_check_knowledge_rate_limit", lambda tenant_id: None)
monkeypatch.setattr(retrieval, "_get_available_datasets", lambda tenant_id, dataset_ids: [dataset])
monkeypatch.setattr(retrieval, "multiple_retrieve", lambda **kwargs: [SimpleNamespace(provider="dify")])
monkeypatch.setattr(RetrievalService, "format_retrieval_documents", lambda documents: [record])
session = _DatasetRetrievalSession([dataset], [document])
monkeypatch.setattr(dataset_retrieval_module.session_factory, "create_session", lambda: session)
scope = FileAccessScope(
tenant_id=tenant_id,
user_id=str(uuid4()),
user_from=UserFrom.END_USER,
invoke_from=InvokeFrom.WEB_APP,
)
with bind_file_access_scope(scope):
results = retrieval.knowledge_retrieval(
KnowledgeRetrievalRequest(
tenant_id=tenant_id,
user_id=str(uuid4()),
app_id=str(uuid4()),
user_from=UserFrom.END_USER.value,
dataset_ids=[dataset_id],
query="desktop picture",
retrieval_mode="multiple",
)
)
current_scope = get_current_file_access_scope()
assert results[0].metadata.segment_id == segment_id
assert current_scope is not None
assert segment_id in current_scope.granted_retriever_segment_ids

View File

@@ -13,7 +13,8 @@ from graphon.enums import BuiltinNodeTypes, NodeType
from graphon.nodes.code.entities import CodeLanguage
from graphon.nodes.llm.entities import LLMNodeData
from graphon.nodes.llm.node import LLMNode
from graphon.variables.segments import StringSegment
from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from graphon.variables.segments import ArrayObjectSegment, StringSegment
def _assert_constructor_node_data(data, *, node_id: str, node_type: NodeType, version: str = "1") -> None:
@@ -430,6 +431,7 @@ class TestDifyNodeFactoryCreateNode:
factory._http_request_config = sentinel.http_request_config
factory._llm_credentials_provider = sentinel.credentials_provider
factory._llm_model_factory = sentinel.model_factory
factory._build_retriever_attachment_loader = MagicMock(return_value=sentinel.retriever_attachment_loader)
return factory
def test_rejects_unknown_node_type(self, factory):
@@ -770,6 +772,128 @@ class TestDifyNodeFactoryCreateNode:
for key, value in expected_extra_kwargs.items():
assert constructor_kwargs[key] is value
def test_parameter_extractor_init_does_not_require_retriever_context(self, factory):
node_data = ParameterExtractorNodeData.model_validate(
{
"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR,
"title": "Parameter Extractor",
"model": {"provider": "provider", "name": "model", "mode": "chat", "completion_params": {}},
"query": ["sys", "query"],
"parameters": [
{
"name": "topic",
"type": "string",
"description": "Topic",
"required": True,
}
],
"reasoning_mode": "prompt",
}
)
factory._build_model_instance_for_llm_node = MagicMock(return_value=sentinel.model_instance)
factory._build_memory_for_llm_node = MagicMock(return_value=sentinel.memory)
factory._build_retriever_attachment_loader = MagicMock(side_effect=AssertionError("unexpected loader build"))
kwargs = factory._build_llm_compatible_node_init_kwargs(
node_class=sentinel.node_class,
node_data=node_data,
wrap_model_instance=True,
include_http_client=False,
include_llm_file_saver=False,
include_prompt_message_serializer=True,
include_retriever_attachment_loader=False,
include_jinja2_template_renderer=False,
)
assert "retriever_attachment_loader" not in kwargs
assert kwargs["prompt_message_serializer"] is sentinel.prompt_message_serializer
factory._build_retriever_attachment_loader.assert_not_called()
class TestDifyNodeFactoryRetrieverAttachmentAccess:
@pytest.fixture
def factory(self):
factory = object.__new__(node_factory.DifyNodeFactory)
factory.graph_runtime_state = SimpleNamespace(variable_pool=MagicMock())
return factory
def test_retriever_attachment_loader_is_typed_for_llm_node_data_only(self):
annotations = node_factory.DifyNodeFactory._build_retriever_attachment_loader.__annotations__
assert annotations["node_data"] is LLMNodeData
def test_build_retriever_attachment_loader_uses_llm_context_selector(self, factory):
factory._file_reference_factory = sentinel.file_reference_factory
factory.graph_runtime_state.variable_pool.get.return_value = ArrayObjectSegment(
value=[
{
"metadata": {
"_source": "knowledge",
"segment_id": "allowed-segment",
}
}
]
)
node_data = LLMNodeData.model_validate(
{
"type": BuiltinNodeTypes.LLM,
"title": "LLM",
"model": {"provider": "provider", "name": "model", "mode": "chat", "completion_params": {}},
"prompt_template": [{"role": "system", "text": "x"}],
"context": {"enabled": True, "variable_selector": ["knowledge-node", "result"]},
"vision": {"enabled": False},
}
)
loader = factory._build_retriever_attachment_loader(node_data)
assert loader._segment_access_checker is not None
assert loader._segment_access_checker("allowed-segment") is True
factory.graph_runtime_state.variable_pool.get.assert_called_once_with(["knowledge-node", "result"])
def test_checker_rejects_missing_context_selector_without_reading_variable_pool(self, factory):
checker = factory._build_retriever_segment_access_checker(None)
assert checker("segment-id") is False
factory.graph_runtime_state.variable_pool.get.assert_not_called()
def test_checker_rejects_non_knowledge_context_items(self, factory):
factory.graph_runtime_state.variable_pool.get.return_value = ArrayObjectSegment.model_construct(
value=[
"plain-text",
{"metadata": "not-a-mapping"},
]
)
checker = factory._build_retriever_segment_access_checker(["knowledge-node", "result"])
assert checker("segment-id") is False
def test_checker_rejects_non_array_context_value(self, factory):
factory.graph_runtime_state.variable_pool.get.return_value = StringSegment(value="not knowledge context")
checker = factory._build_retriever_segment_access_checker(["knowledge-node", "result"])
assert checker("segment-id") is False
def test_checker_allows_only_segments_from_selected_knowledge_context(self, factory):
factory.graph_runtime_state.variable_pool.get.return_value = ArrayObjectSegment(
value=[
{
"metadata": {
"_source": "knowledge",
"segment_id": "allowed-segment",
}
}
]
)
checker = factory._build_retriever_segment_access_checker(["knowledge-node", "result"])
assert checker("allowed-segment") is True
assert checker("other-segment") is False
factory.graph_runtime_state.variable_pool.get.assert_any_call(["knowledge-node", "result"])
class TestDifyNodeFactoryModelInstance:
@pytest.fixture

View File

@@ -1,9 +1,11 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, sentinel
from uuid import uuid4
import pytest
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext, InvokeFrom, UserFrom
from core.app.file_access import FileAccessScope, bind_file_access_scope, grant_retriever_segment_access
from core.llm_generator.output_parser.errors import OutputParserError
from core.workflow import node_runtime
from core.workflow.file_reference import parse_file_reference
@@ -268,6 +270,114 @@ def test_dify_retriever_attachment_loader_builds_graph_files(monkeypatch: pytest
assert parse_file_reference(mapping["reference"]).storage_key is None
def test_dify_retriever_attachment_loader_grants_upload_files_for_allowed_segment(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from factories.file_factory import builders as file_builders
upload_file_id = str(uuid4())
segment_id = str(uuid4())
upload_file = SimpleNamespace(
id=upload_file_id,
tenant_id="tenant-id",
name="diagram.png",
extension="png",
mime_type="image/png",
source_url="https://example.com/diagram.png",
key="storage-key",
size=128,
)
attachment_session = MagicMock()
attachment_session.execute.return_value.all.return_value = [(None, upload_file)]
class _AttachmentSessionContext:
def __enter__(self):
return attachment_session
def __exit__(self, exc_type, exc, tb):
return False
upload_session = MagicMock()
upload_session.__enter__.return_value = upload_session
upload_session.__exit__.return_value = False
upload_session.scalar.return_value = upload_file
monkeypatch.setattr(node_runtime, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(node_runtime, "Session", MagicMock(return_value=_AttachmentSessionContext()))
monkeypatch.setattr(file_builders, "session_factory", SimpleNamespace(create_session=lambda: upload_session))
loader = DifyRetrieverAttachmentLoader(file_reference_factory=DifyFileReferenceFactory(_build_run_context()))
scope = FileAccessScope(
tenant_id="tenant-id",
user_id="end-user-id",
user_from=UserFrom.END_USER,
invoke_from=InvokeFrom.WEB_APP,
)
with bind_file_access_scope(scope):
grant_retriever_segment_access([segment_id])
files = loader.load(segment_id=segment_id)
assert files[0].related_id == upload_file_id
stmt = upload_session.scalar.call_args.args[0]
whereclause = str(stmt.whereclause)
assert "upload_files.tenant_id" in whereclause
assert "upload_files.id IN" in whereclause
def test_dify_retriever_attachment_loader_skips_ungranted_segment_for_end_user(
monkeypatch: pytest.MonkeyPatch,
) -> None:
build_from_mapping = MagicMock()
session_factory = MagicMock()
monkeypatch.setattr(node_runtime, "Session", session_factory)
loader = DifyRetrieverAttachmentLoader(
file_reference_factory=SimpleNamespace(build_from_mapping=build_from_mapping)
)
scope = FileAccessScope(
tenant_id="tenant-id",
user_id="end-user-id",
user_from=UserFrom.END_USER,
invoke_from=InvokeFrom.WEB_APP,
)
with bind_file_access_scope(scope):
files = loader.load(segment_id=str(uuid4()))
assert files == []
session_factory.assert_not_called()
build_from_mapping.assert_not_called()
def test_dify_retriever_attachment_loader_skips_segment_rejected_by_checker(
monkeypatch: pytest.MonkeyPatch,
) -> None:
segment_id = str(uuid4())
build_from_mapping = MagicMock()
session_factory = MagicMock()
segment_access_checker = MagicMock(return_value=False)
monkeypatch.setattr(node_runtime, "Session", session_factory)
loader = DifyRetrieverAttachmentLoader(
file_reference_factory=SimpleNamespace(build_from_mapping=build_from_mapping),
segment_access_checker=segment_access_checker,
)
scope = FileAccessScope(
tenant_id="tenant-id",
user_id="end-user-id",
user_from=UserFrom.END_USER,
invoke_from=InvokeFrom.WEB_APP,
)
with bind_file_access_scope(scope):
grant_retriever_segment_access([segment_id])
files = loader.load(segment_id=segment_id)
assert files == []
segment_access_checker.assert_called_once_with(segment_id)
session_factory.assert_not_called()
build_from_mapping.assert_not_called()
def test_dify_tool_file_manager_resolves_conversation_id_for_tool_files(monkeypatch: pytest.MonkeyPatch) -> None:
create_file_by_raw = MagicMock(return_value=SimpleNamespace(id="tool-file-id"))
manager_instance = SimpleNamespace(create_file_by_raw=create_file_by_raw)