mirror of
https://github.com/langgenius/dify.git
synced 2026-05-26 13:00:51 -04:00
fix(api): allow LLM nodes to access retrieved knowledge files (#36175)
This commit is contained in:
@@ -1,6 +1,13 @@
|
||||
from .controller import DatabaseFileAccessController
|
||||
from .protocols import FileAccessControllerProtocol
|
||||
from .scope import FileAccessScope, bind_file_access_scope, get_current_file_access_scope
|
||||
from .scope import (
|
||||
FileAccessScope,
|
||||
bind_file_access_scope,
|
||||
get_current_file_access_scope,
|
||||
grant_retriever_segment_access,
|
||||
grant_upload_file_access,
|
||||
is_retriever_segment_access_granted,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DatabaseFileAccessController",
|
||||
@@ -8,4 +15,7 @@ __all__ = [
|
||||
"FileAccessScope",
|
||||
"bind_file_access_scope",
|
||||
"get_current_file_access_scope",
|
||||
"grant_retriever_segment_access",
|
||||
"grant_upload_file_access",
|
||||
"is_retriever_segment_access_granted",
|
||||
]
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import Select
|
||||
|
||||
@@ -18,7 +18,8 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
|
||||
|
||||
Tenant scoping remains mandatory. When the current execution belongs to an
|
||||
end user, the lookup is additionally constrained to that end user's file
|
||||
ownership markers.
|
||||
ownership markers, plus upload files explicitly granted by the current
|
||||
execution context.
|
||||
"""
|
||||
|
||||
_scope_getter: Callable[[], FileAccessScope | None]
|
||||
@@ -47,10 +48,19 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
|
||||
if not resolved_scope.requires_user_ownership:
|
||||
return scoped_stmt
|
||||
|
||||
return scoped_stmt.where(
|
||||
user_owned_filter = and_(
|
||||
UploadFile.created_by_role == CreatorUserRole.END_USER,
|
||||
UploadFile.created_by == resolved_scope.user_id,
|
||||
)
|
||||
if not resolved_scope.granted_upload_file_ids:
|
||||
return scoped_stmt.where(user_owned_filter)
|
||||
|
||||
return scoped_stmt.where(
|
||||
or_(
|
||||
user_owned_filter,
|
||||
UploadFile.id.in_(resolved_scope.granted_upload_file_ids),
|
||||
)
|
||||
)
|
||||
|
||||
def apply_tool_file_filters(
|
||||
self,
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator # Changed from Iterator
|
||||
from collections.abc import Generator, Iterable
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field, replace
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
|
||||
@@ -15,12 +15,23 @@ _current_file_access_scope: ContextVar[FileAccessScope | None] = ContextVar(
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class FileAccessScope:
|
||||
"""Request-scoped ownership context used by workflow-layer file lookups."""
|
||||
"""Request-scoped ownership context used by workflow-layer file lookups.
|
||||
|
||||
``granted_upload_file_ids`` is execution-local: callers may add upload files
|
||||
that were returned by trusted retrieval paths without changing persistent
|
||||
ownership markers.
|
||||
|
||||
``granted_retriever_segment_ids`` gates lazy attachment loading by segment
|
||||
ID, so user-provided context cannot make a later LLM node load arbitrary
|
||||
same-tenant knowledge attachments.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
user_from: UserFrom
|
||||
invoke_from: InvokeFrom
|
||||
granted_upload_file_ids: frozenset[str] = field(default_factory=frozenset)
|
||||
granted_retriever_segment_ids: frozenset[str] = field(default_factory=frozenset)
|
||||
|
||||
@property
|
||||
def requires_user_ownership(self) -> bool:
|
||||
@@ -31,8 +42,49 @@ def get_current_file_access_scope() -> FileAccessScope | None:
|
||||
return _current_file_access_scope.get()
|
||||
|
||||
|
||||
def grant_upload_file_access(upload_file_ids: Iterable[str]) -> None:
    """Add upload file IDs to the grants carried by the currently bound scope.

    Nothing happens when no scope is bound, or when no IDs remain after
    dropping falsy entries. Otherwise the context variable is rebound to a
    copy of the scope whose ``granted_upload_file_ids`` is the union of the
    previous grants and the new IDs, each coerced with ``str``.
    """
    current = _current_file_access_scope.get()
    if current is None:
        return

    additions = frozenset(str(file_id) for file_id in upload_file_ids if file_id)
    if additions:
        merged = current.granted_upload_file_ids | additions
        _current_file_access_scope.set(replace(current, granted_upload_file_ids=merged))
|
||||
|
||||
|
||||
def grant_retriever_segment_access(segment_ids: Iterable[str]) -> None:
    """Add retriever segment IDs to the grants carried by the currently bound scope.

    Nothing happens when no scope is bound, or when no IDs remain after
    dropping falsy entries. Otherwise the context variable is rebound to a
    copy of the scope whose ``granted_retriever_segment_ids`` is the union of
    the previous grants and the new IDs, each coerced with ``str``.
    """
    current = _current_file_access_scope.get()
    if current is None:
        return

    additions = frozenset(str(segment_id) for segment_id in segment_ids if segment_id)
    if additions:
        merged = current.granted_retriever_segment_ids | additions
        _current_file_access_scope.set(replace(current, granted_retriever_segment_ids=merged))
|
||||
|
||||
|
||||
def is_retriever_segment_access_granted(segment_id: str) -> bool:
    """Report whether the current scope permits loading the given segment.

    Returns ``True`` when no scope is bound or when the bound scope does not
    require user ownership. Otherwise the segment ID (coerced with ``str``)
    must be present in the scope's ``granted_retriever_segment_ids``.
    """
    current = _current_file_access_scope.get()
    if current is not None and current.requires_user_ownership:
        return str(segment_id) in current.granted_retriever_segment_ids
    return True
|
||||
|
||||
|
||||
@contextmanager
|
||||
def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]: # Changed from Iterator[None]
|
||||
def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]:
|
||||
token = _current_file_access_scope.set(scope)
|
||||
try:
|
||||
yield
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, load_only
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.file_access import grant_upload_file_access
|
||||
from core.db.session_factory import session_factory
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
|
||||
@@ -890,6 +891,7 @@ class RetrievalService:
|
||||
.limit(1)
|
||||
)
|
||||
if attachment_binding:
|
||||
grant_upload_file_access([str(upload_file.id)])
|
||||
attachment_info: AttachmentInfoDict = {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
@@ -906,6 +908,7 @@ class RetrievalService:
|
||||
cls, attachment_ids: list[str], session: Session
|
||||
) -> list[SegmentAttachmentInfoResult]:
|
||||
attachment_infos: list[SegmentAttachmentInfoResult] = []
|
||||
granted_upload_file_ids: list[str] = []
|
||||
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(attachment_ids))).all()
|
||||
if upload_files:
|
||||
upload_file_ids = [upload_file.id for upload_file in upload_files]
|
||||
@@ -926,6 +929,7 @@ class RetrievalService:
|
||||
"size": upload_file.size,
|
||||
}
|
||||
if attachment_binding:
|
||||
granted_upload_file_ids.append(str(upload_file.id))
|
||||
attachment_infos.append(
|
||||
{
|
||||
"attachment_id": attachment_binding.attachment_id,
|
||||
@@ -933,4 +937,5 @@ class RetrievalService:
|
||||
"segment_id": attachment_binding.segment_id,
|
||||
}
|
||||
)
|
||||
grant_upload_file_access(granted_upload_file_ids)
|
||||
return attachment_infos
|
||||
|
||||
@@ -19,6 +19,7 @@ from core.app.app_config.entities import (
|
||||
ModelConfig,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
||||
from core.app.file_access import grant_retriever_segment_access, grant_upload_file_access
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.db.session_factory import session_factory
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
@@ -326,6 +327,7 @@ class DatasetRetrieval:
|
||||
if record.summary:
|
||||
source.summary = record.summary
|
||||
|
||||
grant_retriever_segment_access([str(segment.id)])
|
||||
retrieval_resource_list.append(source)
|
||||
|
||||
if retrieval_resource_list:
|
||||
@@ -515,6 +517,9 @@ class DatasetRetrieval:
|
||||
)
|
||||
).all()
|
||||
if attachments_with_bindings:
|
||||
grant_upload_file_access(
|
||||
str(upload_file.id) for _, upload_file in attachments_with_bindings
|
||||
)
|
||||
for _, upload_file in attachments_with_bindings:
|
||||
attachment_info = File(
|
||||
file_id=upload_file.id,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import importlib
|
||||
import pkgutil
|
||||
from collections.abc import Callable, Iterator, Mapping, MutableMapping
|
||||
from collections.abc import Callable, Iterator, Mapping, MutableMapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, cast, final, override
|
||||
@@ -56,6 +56,7 @@ from graphon.nodes.http_request import build_http_request_config
|
||||
from graphon.nodes.llm.entities import LLMNodeData
|
||||
from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||
from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
from graphon.variables.segments import ArrayObjectSegment
|
||||
from models.model import Conversation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -496,13 +497,47 @@ class DifyNodeFactory(NodeFactory):
|
||||
if include_prompt_message_serializer:
|
||||
node_init_kwargs["prompt_message_serializer"] = self._prompt_message_serializer
|
||||
if include_retriever_attachment_loader:
|
||||
node_init_kwargs["retriever_attachment_loader"] = self._retriever_attachment_loader
|
||||
node_init_kwargs["retriever_attachment_loader"] = self._build_retriever_attachment_loader(
|
||||
cast(LLMNodeData, validated_node_data)
|
||||
)
|
||||
if include_jinja2_template_renderer:
|
||||
node_init_kwargs["jinja2_template_renderer"] = self._jinja2_template_renderer
|
||||
if validated_node_data.type == BuiltinNodeTypes.LLM:
|
||||
node_init_kwargs["default_query_selector"] = system_variable_selector(SystemVariableKey.QUERY)
|
||||
return node_init_kwargs
|
||||
|
||||
def _build_retriever_attachment_loader(self, node_data: LLMNodeData) -> DifyRetrieverAttachmentLoader:
    """Create the attachment loader for an LLM node.

    The loader receives this factory's file reference factory and a segment
    access checker bound to the node's context variable selector. When the
    node's context is disabled, the checker is built with ``None`` instead.
    """
    file_reference_factory = self._file_reference_factory
    context_config = node_data.context
    selector = context_config.variable_selector if context_config.enabled else None
    return DifyRetrieverAttachmentLoader(
        file_reference_factory=file_reference_factory,
        segment_access_checker=self._build_retriever_segment_access_checker(selector),
    )
|
||||
|
||||
def _build_retriever_segment_access_checker(
|
||||
self,
|
||||
context_variable_selector: Sequence[str] | None,
|
||||
) -> Callable[[str], bool]:
|
||||
def checker(segment_id: str) -> bool:
|
||||
if not context_variable_selector:
|
||||
return False
|
||||
|
||||
context_value = self.graph_runtime_state.variable_pool.get(context_variable_selector)
|
||||
if not isinstance(context_value, ArrayObjectSegment):
|
||||
return False
|
||||
|
||||
for item in context_value.value:
|
||||
if not isinstance(item, Mapping):
|
||||
continue
|
||||
metadata = item.get("metadata")
|
||||
if not isinstance(metadata, Mapping):
|
||||
continue
|
||||
if metadata.get("_source") == "knowledge" and str(metadata.get("segment_id")) == str(segment_id):
|
||||
return True
|
||||
return False
|
||||
|
||||
return checker
|
||||
|
||||
def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance:
|
||||
node_data_model = node_data.model
|
||||
model_instance, _ = fetch_model_config(
|
||||
|
||||
@@ -8,7 +8,11 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.app.file_access import (
|
||||
DatabaseFileAccessController,
|
||||
grant_upload_file_access,
|
||||
is_retriever_segment_access_granted,
|
||||
)
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.helper.trace_id_helper import ParentTraceContext
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
@@ -275,10 +279,23 @@ class DifyPromptMessageSerializer(PromptMessageSerializerProtocol):
|
||||
class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol):
|
||||
"""Resolve retriever attachments through Dify persistence and return graph file references."""
|
||||
|
||||
def __init__(self, *, file_reference_factory: FileReferenceFactoryProtocol) -> None:
|
||||
_segment_access_checker: Callable[[str], bool] | None
|
||||
|
||||
def __init__(
    self,
    *,
    file_reference_factory: FileReferenceFactoryProtocol,
    segment_access_checker: Callable[[str], bool] | None = None,
) -> None:
    """Store the collaborators on the instance.

    Args:
        file_reference_factory: Factory kept as ``_file_reference_factory``.
        segment_access_checker: Optional per-segment predicate kept as
            ``_segment_access_checker``; defaults to ``None``.
    """
    self._segment_access_checker = segment_access_checker
    self._file_reference_factory = file_reference_factory
|
||||
|
||||
def load(self, *, segment_id: str) -> Sequence[File]:
|
||||
if not is_retriever_segment_access_granted(segment_id):
|
||||
return []
|
||||
if self._segment_access_checker is not None and not self._segment_access_checker(segment_id):
|
||||
return []
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
attachments_with_bindings = session.execute(
|
||||
select(SegmentAttachmentBinding, UploadFile)
|
||||
@@ -286,6 +303,7 @@ class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol):
|
||||
.where(SegmentAttachmentBinding.segment_id == segment_id)
|
||||
).all()
|
||||
|
||||
grant_upload_file_access(str(upload_file.id) for _, upload_file in attachments_with_bindings)
|
||||
return [
|
||||
self._file_reference_factory.build_from_mapping(
|
||||
mapping={
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.app.file_access import (
|
||||
DatabaseFileAccessController,
|
||||
FileAccessScope,
|
||||
bind_file_access_scope,
|
||||
get_current_file_access_scope,
|
||||
grant_retriever_segment_access,
|
||||
grant_upload_file_access,
|
||||
is_retriever_segment_access_granted,
|
||||
)
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.retrieval import dataset_retrieval as dataset_retrieval_module
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
|
||||
from models import UploadFile
|
||||
|
||||
|
||||
class _ScalarResult:
    """Test double exposing only the ``all()`` accessor of a query result."""

    def __init__(self, values):
        # Kept by reference; ``all()`` hands back this same object.
        self._values = values

    def all(self):
        """Return the values supplied at construction, unchanged."""
        return self._values
|
||||
|
||||
|
||||
class _AttachmentSession:
    """Session double that serves two queued results through ``scalars()``.

    The first call yields a result wrapping ``[upload_file]`` and the second
    a result wrapping ``[binding]``. The statement argument is ignored.
    """

    def __init__(self, upload_file, binding):
        self._results = [_ScalarResult([row]) for row in (upload_file, binding)]

    def scalars(self, _stmt):
        # Results are consumed first-in, first-out: one per call.
        return self._results.pop(0)
|
||||
|
||||
|
||||
class _DatasetRetrievalSession:
    """Context-manager session double serving dataset rows, then document rows.

    ``scalars()`` ignores its statement and returns the queued results in
    order. Entering the context yields the session itself, and exiting never
    suppresses exceptions.
    """

    def __init__(self, datasets, documents):
        self._results = [_ScalarResult(rows) for rows in (datasets, documents)]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Returning False lets any exception raised in the block propagate.
        return False

    def scalars(self, _stmt):
        # Results are consumed first-in, first-out: one per call.
        return self._results.pop(0)
|
||||
|
||||
|
||||
def test_file_access_grants_ignore_empty_inputs_and_missing_scope() -> None:
    """Grant helpers tolerate a missing scope and skip empty IDs."""
    # The grant calls run before any scope is bound by this test.
    grant_upload_file_access(["upload-file-id"])
    grant_retriever_segment_access(["segment-id"])
    assert is_retriever_segment_access_granted("segment-id") is True

    end_user_scope = FileAccessScope(
        tenant_id=str(uuid4()),
        user_id=str(uuid4()),
        user_from=UserFrom.END_USER,
        invoke_from=InvokeFrom.WEB_APP,
    )

    # Empty-string IDs must leave both grant sets untouched.
    with bind_file_access_scope(end_user_scope):
        grant_upload_file_access([""])
        grant_retriever_segment_access([""])
        scope_after_grants = get_current_file_access_scope()

    assert scope_after_grants is not None
    assert scope_after_grants.granted_upload_file_ids == frozenset()
    assert scope_after_grants.granted_retriever_segment_ids == frozenset()
|
||||
|
||||
|
||||
def test_segment_attachment_lookup_grants_returned_upload_files_to_current_scope() -> None:
    """Looking up segment attachments widens the scoped upload-file filter with an ID grant."""
    upload_file_id = str(uuid4())
    attachment = SimpleNamespace(
        id=upload_file_id,
        name="chart.png",
        extension="png",
        mime_type="image/png",
        size=1024,
    )
    attachment_binding = SimpleNamespace(attachment_id=upload_file_id, segment_id=str(uuid4()))
    fake_session = _AttachmentSession(attachment, attachment_binding)
    end_user_scope = FileAccessScope(
        tenant_id=str(uuid4()),
        user_id=str(uuid4()),
        user_from=UserFrom.END_USER,
        invoke_from=InvokeFrom.WEB_APP,
    )

    with bind_file_access_scope(end_user_scope):
        infos = RetrievalService.get_segment_attachment_infos([upload_file_id], fake_session)  # type: ignore[arg-type]
        lookup = select(UploadFile).where(UploadFile.id == upload_file_id)
        scoped_stmt = DatabaseFileAccessController().apply_upload_file_filters(lookup)

    assert infos[0]["attachment_id"] == upload_file_id
    # The rendered filter must contain both the ownership clause and the granted-ID clause.
    rendered_filter = str(scoped_stmt.whereclause)
    for fragment in ("upload_files.created_by_role", "upload_files.id IN"):
        assert fragment in rendered_filter
|
||||
|
||||
|
||||
def test_knowledge_retrieval_grants_returned_segments_to_current_scope(monkeypatch) -> None:
    """Knowledge retrieval under an end-user scope records the returned segment as granted."""
    tenant_id, dataset_id, document_id, segment_id = (str(uuid4()) for _ in range(4))

    # Lightweight stand-ins carrying only the attributes set here.
    dataset = SimpleNamespace(id=dataset_id, name="Dataset")
    document = SimpleNamespace(id=document_id, name="Document", data_source_type="upload_file", doc_metadata={})
    segment = SimpleNamespace(
        id=segment_id,
        dataset_id=dataset_id,
        document_id=document_id,
        hit_count=1,
        word_count=10,
        position=1,
        index_node_hash="hash",
        answer=None,
        get_sign_content=lambda: "segment content",
    )
    record = SimpleNamespace(segment=segment, score=0.8, child_chunks=None, files=None, summary=None)

    # Replace the collaborators so the call never reaches real services.
    retrieval = DatasetRetrieval()
    monkeypatch.setattr(retrieval, "_check_knowledge_rate_limit", lambda tenant_id: None)
    monkeypatch.setattr(retrieval, "_get_available_datasets", lambda tenant_id, dataset_ids: [dataset])
    monkeypatch.setattr(retrieval, "multiple_retrieve", lambda **kwargs: [SimpleNamespace(provider="dify")])
    monkeypatch.setattr(RetrievalService, "format_retrieval_documents", lambda documents: [record])
    fake_session = _DatasetRetrievalSession([dataset], [document])
    monkeypatch.setattr(dataset_retrieval_module.session_factory, "create_session", lambda: fake_session)

    end_user_scope = FileAccessScope(
        tenant_id=tenant_id,
        user_id=str(uuid4()),
        user_from=UserFrom.END_USER,
        invoke_from=InvokeFrom.WEB_APP,
    )
    request = KnowledgeRetrievalRequest(
        tenant_id=tenant_id,
        user_id=str(uuid4()),
        app_id=str(uuid4()),
        user_from=UserFrom.END_USER.value,
        dataset_ids=[dataset_id],
        query="desktop picture",
        retrieval_mode="multiple",
    )

    with bind_file_access_scope(end_user_scope):
        results = retrieval.knowledge_retrieval(request)
        scope_after_retrieval = get_current_file_access_scope()

    assert results[0].metadata.segment_id == segment_id
    assert scope_after_retrieval is not None
    assert segment_id in scope_after_retrieval.granted_retriever_segment_ids
|
||||
@@ -13,7 +13,8 @@ from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from graphon.nodes.code.entities import CodeLanguage
|
||||
from graphon.nodes.llm.entities import LLMNodeData
|
||||
from graphon.nodes.llm.node import LLMNode
|
||||
from graphon.variables.segments import StringSegment
|
||||
from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||
from graphon.variables.segments import ArrayObjectSegment, StringSegment
|
||||
|
||||
|
||||
def _assert_constructor_node_data(data, *, node_id: str, node_type: NodeType, version: str = "1") -> None:
|
||||
@@ -430,6 +431,7 @@ class TestDifyNodeFactoryCreateNode:
|
||||
factory._http_request_config = sentinel.http_request_config
|
||||
factory._llm_credentials_provider = sentinel.credentials_provider
|
||||
factory._llm_model_factory = sentinel.model_factory
|
||||
factory._build_retriever_attachment_loader = MagicMock(return_value=sentinel.retriever_attachment_loader)
|
||||
return factory
|
||||
|
||||
def test_rejects_unknown_node_type(self, factory):
|
||||
@@ -770,6 +772,128 @@ class TestDifyNodeFactoryCreateNode:
|
||||
for key, value in expected_extra_kwargs.items():
|
||||
assert constructor_kwargs[key] is value
|
||||
|
||||
def test_parameter_extractor_init_does_not_require_retriever_context(self, factory):
    """Init kwargs built without the loader flag never invoke the attachment-loader builder."""
    extractor_config = {
        "type": BuiltinNodeTypes.PARAMETER_EXTRACTOR,
        "title": "Parameter Extractor",
        "model": {"provider": "provider", "name": "model", "mode": "chat", "completion_params": {}},
        "query": ["sys", "query"],
        "parameters": [
            {
                "name": "topic",
                "type": "string",
                "description": "Topic",
                "required": True,
            }
        ],
        "reasoning_mode": "prompt",
    }
    node_data = ParameterExtractorNodeData.model_validate(extractor_config)

    # Any attempt to build the loader raises, so a stray call fails the test loudly.
    loader_builder = MagicMock(side_effect=AssertionError("unexpected loader build"))
    factory._build_model_instance_for_llm_node = MagicMock(return_value=sentinel.model_instance)
    factory._build_memory_for_llm_node = MagicMock(return_value=sentinel.memory)
    factory._build_retriever_attachment_loader = loader_builder

    kwargs = factory._build_llm_compatible_node_init_kwargs(
        node_class=sentinel.node_class,
        node_data=node_data,
        wrap_model_instance=True,
        include_http_client=False,
        include_llm_file_saver=False,
        include_prompt_message_serializer=True,
        include_retriever_attachment_loader=False,
        include_jinja2_template_renderer=False,
    )

    assert "retriever_attachment_loader" not in kwargs
    assert kwargs["prompt_message_serializer"] is sentinel.prompt_message_serializer
    factory._build_retriever_attachment_loader.assert_not_called()
|
||||
|
||||
|
||||
class TestDifyNodeFactoryRetrieverAttachmentAccess:
    """Tests for the retriever attachment loader and segment access checker built by the factory."""

    @pytest.fixture
    def factory(self):
        # Bypass __init__ and attach only a mocked variable pool.
        bare_factory = object.__new__(node_factory.DifyNodeFactory)
        bare_factory.graph_runtime_state = SimpleNamespace(variable_pool=MagicMock())
        return bare_factory

    @staticmethod
    def _allowed_segment_context():
        """Return a knowledge context holding the single segment ``allowed-segment``."""
        knowledge_metadata = {
            "_source": "knowledge",
            "segment_id": "allowed-segment",
        }
        return ArrayObjectSegment(value=[{"metadata": knowledge_metadata}])

    def test_retriever_attachment_loader_is_typed_for_llm_node_data_only(self):
        builder = node_factory.DifyNodeFactory._build_retriever_attachment_loader

        assert builder.__annotations__["node_data"] is LLMNodeData

    def test_build_retriever_attachment_loader_uses_llm_context_selector(self, factory):
        factory._file_reference_factory = sentinel.file_reference_factory
        variable_pool = factory.graph_runtime_state.variable_pool
        variable_pool.get.return_value = self._allowed_segment_context()
        llm_config = {
            "type": BuiltinNodeTypes.LLM,
            "title": "LLM",
            "model": {"provider": "provider", "name": "model", "mode": "chat", "completion_params": {}},
            "prompt_template": [{"role": "system", "text": "x"}],
            "context": {"enabled": True, "variable_selector": ["knowledge-node", "result"]},
            "vision": {"enabled": False},
        }
        node_data = LLMNodeData.model_validate(llm_config)

        loader = factory._build_retriever_attachment_loader(node_data)

        assert loader._segment_access_checker is not None
        assert loader._segment_access_checker("allowed-segment") is True
        factory.graph_runtime_state.variable_pool.get.assert_called_once_with(["knowledge-node", "result"])

    def test_checker_rejects_missing_context_selector_without_reading_variable_pool(self, factory):
        checker = factory._build_retriever_segment_access_checker(None)

        assert checker("segment-id") is False
        factory.graph_runtime_state.variable_pool.get.assert_not_called()

    def test_checker_rejects_non_knowledge_context_items(self, factory):
        # model_construct is used so these malformed items reach the checker as given.
        malformed_items = [
            "plain-text",
            {"metadata": "not-a-mapping"},
        ]
        factory.graph_runtime_state.variable_pool.get.return_value = ArrayObjectSegment.model_construct(
            value=malformed_items
        )

        checker = factory._build_retriever_segment_access_checker(["knowledge-node", "result"])

        assert checker("segment-id") is False

    def test_checker_rejects_non_array_context_value(self, factory):
        factory.graph_runtime_state.variable_pool.get.return_value = StringSegment(value="not knowledge context")

        checker = factory._build_retriever_segment_access_checker(["knowledge-node", "result"])

        assert checker("segment-id") is False

    def test_checker_allows_only_segments_from_selected_knowledge_context(self, factory):
        variable_pool = factory.graph_runtime_state.variable_pool
        variable_pool.get.return_value = self._allowed_segment_context()

        checker = factory._build_retriever_segment_access_checker(["knowledge-node", "result"])

        assert checker("allowed-segment") is True
        assert checker("other-segment") is False
        factory.graph_runtime_state.variable_pool.get.assert_any_call(["knowledge-node", "result"])
|
||||
|
||||
|
||||
class TestDifyNodeFactoryModelInstance:
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, Mock, sentinel
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext, InvokeFrom, UserFrom
|
||||
from core.app.file_access import FileAccessScope, bind_file_access_scope, grant_retriever_segment_access
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.workflow import node_runtime
|
||||
from core.workflow.file_reference import parse_file_reference
|
||||
@@ -268,6 +270,114 @@ def test_dify_retriever_attachment_loader_builds_graph_files(monkeypatch: pytest
|
||||
assert parse_file_reference(mapping["reference"]).storage_key is None
|
||||
|
||||
|
||||
def test_dify_retriever_attachment_loader_grants_upload_files_for_allowed_segment(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Loading a granted segment returns its file and scopes the upload-file lookup by tenant and ID."""
    from factories.file_factory import builders as file_builders

    segment_id = str(uuid4())
    upload_file_id = str(uuid4())
    upload_file = SimpleNamespace(
        id=upload_file_id,
        tenant_id="tenant-id",
        name="diagram.png",
        extension="png",
        mime_type="image/png",
        source_url="https://example.com/diagram.png",
        key="storage-key",
        size=128,
    )

    # Session returned by the patched ``Session``: its query yields one (binding, upload file) row.
    binding_session = MagicMock()
    binding_session.execute.return_value.all.return_value = [(None, upload_file)]

    class _AttachmentSessionContext:
        def __enter__(self):
            return binding_session

        def __exit__(self, exc_type, exc, tb):
            return False

    # Session handed out by the patched file-builder session factory.
    upload_session = MagicMock()
    upload_session.__enter__.return_value = upload_session
    upload_session.__exit__.return_value = False
    upload_session.scalar.return_value = upload_file

    monkeypatch.setattr(node_runtime, "db", SimpleNamespace(engine=object()))
    monkeypatch.setattr(node_runtime, "Session", MagicMock(return_value=_AttachmentSessionContext()))
    monkeypatch.setattr(file_builders, "session_factory", SimpleNamespace(create_session=lambda: upload_session))

    loader = DifyRetrieverAttachmentLoader(file_reference_factory=DifyFileReferenceFactory(_build_run_context()))
    end_user_scope = FileAccessScope(
        tenant_id="tenant-id",
        user_id="end-user-id",
        user_from=UserFrom.END_USER,
        invoke_from=InvokeFrom.WEB_APP,
    )

    with bind_file_access_scope(end_user_scope):
        grant_retriever_segment_access([segment_id])
        files = loader.load(segment_id=segment_id)

    assert files[0].related_id == upload_file_id
    # Inspect the statement passed to ``scalar`` on the upload session.
    executed_stmt = upload_session.scalar.call_args.args[0]
    rendered_filter = str(executed_stmt.whereclause)
    assert "upload_files.tenant_id" in rendered_filter
    assert "upload_files.id IN" in rendered_filter
|
||||
|
||||
|
||||
def test_dify_retriever_attachment_loader_skips_ungranted_segment_for_end_user(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An end-user scope with no segment grant yields no files and opens no session."""
    session_cls = MagicMock()
    monkeypatch.setattr(node_runtime, "Session", session_cls)
    reference_builder = MagicMock()
    loader = DifyRetrieverAttachmentLoader(
        file_reference_factory=SimpleNamespace(build_from_mapping=reference_builder)
    )
    end_user_scope = FileAccessScope(
        tenant_id="tenant-id",
        user_id="end-user-id",
        user_from=UserFrom.END_USER,
        invoke_from=InvokeFrom.WEB_APP,
    )

    # No grant is recorded in this scope before the load.
    with bind_file_access_scope(end_user_scope):
        loaded_files = loader.load(segment_id=str(uuid4()))

    assert loaded_files == []
    session_cls.assert_not_called()
    reference_builder.assert_not_called()
|
||||
|
||||
|
||||
def test_dify_retriever_attachment_loader_skips_segment_rejected_by_checker(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A scope-granted segment is still skipped when the injected checker rejects it."""
    segment_id = str(uuid4())
    session_cls = MagicMock()
    monkeypatch.setattr(node_runtime, "Session", session_cls)
    reference_builder = MagicMock()
    rejecting_checker = MagicMock(return_value=False)
    loader = DifyRetrieverAttachmentLoader(
        file_reference_factory=SimpleNamespace(build_from_mapping=reference_builder),
        segment_access_checker=rejecting_checker,
    )
    end_user_scope = FileAccessScope(
        tenant_id="tenant-id",
        user_id="end-user-id",
        user_from=UserFrom.END_USER,
        invoke_from=InvokeFrom.WEB_APP,
    )

    with bind_file_access_scope(end_user_scope):
        # The scope-level grant is present, so only the checker can veto the load.
        grant_retriever_segment_access([segment_id])
        loaded_files = loader.load(segment_id=segment_id)

    assert loaded_files == []
    rejecting_checker.assert_called_once_with(segment_id)
    session_cls.assert_not_called()
    reference_builder.assert_not_called()
|
||||
|
||||
|
||||
def test_dify_tool_file_manager_resolves_conversation_id_for_tool_files(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
create_file_by_raw = MagicMock(return_value=SimpleNamespace(id="tool-file-id"))
|
||||
manager_instance = SimpleNamespace(create_file_by_raw=create_file_by_raw)
|
||||
|
||||
Reference in New Issue
Block a user