refactor(workflow-file): move core.file to core.workflow.file (#32252)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
99
2026-02-16 22:38:19 +08:00
committed by GitHub
parent 6824eda1c6
commit 7656d514b9
120 changed files with 364 additions and 252 deletions

View File

@@ -0,0 +1,19 @@
from .constants import FILE_MODEL_IDENTITY
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
from .models import (
File,
FileUploadConfig,
ImageConfig,
)
__all__ = [
"FILE_MODEL_IDENTITY",
"ArrayFileAttribute",
"File",
"FileAttribute",
"FileBelongsTo",
"FileTransferMethod",
"FileType",
"FileUploadConfig",
"ImageConfig",
]

View File

@@ -0,0 +1,11 @@
from typing import Any
# TODO(QuantumGhost): Refactor variable type identification. Instead of directly
# comparing `dify_model_identity` with constants throughout the codebase, extract
# this logic into a dedicated function. This would encapsulate the implementation
# details of how different variable types are identified.
FILE_MODEL_IDENTITY = "__dify__file__"
def maybe_file_object(o: Any) -> bool:
return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY

View File

@@ -0,0 +1,57 @@
from enum import StrEnum
class FileType(StrEnum):
IMAGE = "image"
DOCUMENT = "document"
AUDIO = "audio"
VIDEO = "video"
CUSTOM = "custom"
@staticmethod
def value_of(value):
for member in FileType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileTransferMethod(StrEnum):
REMOTE_URL = "remote_url"
LOCAL_FILE = "local_file"
TOOL_FILE = "tool_file"
DATASOURCE_FILE = "datasource_file"
@staticmethod
def value_of(value):
for member in FileTransferMethod:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileBelongsTo(StrEnum):
USER = "user"
ASSISTANT = "assistant"
@staticmethod
def value_of(value):
for member in FileBelongsTo:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileAttribute(StrEnum):
TYPE = "type"
SIZE = "size"
NAME = "name"
MIME_TYPE = "mime_type"
TRANSFER_METHOD = "transfer_method"
URL = "url"
EXTENSION = "extension"
RELATED_ID = "related_id"
class ArrayFileAttribute(StrEnum):
LENGTH = "length"

View File

@@ -0,0 +1,143 @@
from __future__ import annotations
import base64
from collections.abc import Mapping
from core.model_runtime.entities import (
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
TextPromptMessageContent,
VideoPromptMessageContent,
)
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from . import helpers
from .enums import FileAttribute
from .models import File, FileTransferMethod, FileType
from .runtime import get_workflow_file_runtime
def get_attr(*, file: File, attr: FileAttribute):
match attr:
case FileAttribute.TYPE:
return file.type.value
case FileAttribute.SIZE:
return file.size
case FileAttribute.NAME:
return file.filename
case FileAttribute.MIME_TYPE:
return file.mime_type
case FileAttribute.TRANSFER_METHOD:
return file.transfer_method.value
case FileAttribute.URL:
return _to_url(file)
case FileAttribute.EXTENSION:
return file.extension
case FileAttribute.RELATED_ID:
return file.related_id
def to_prompt_message_content(
f: File,
/,
*,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> PromptMessageContentUnionTypes:
"""Convert a file to prompt message content."""
if f.extension is None:
raise ValueError("Missing file extension")
if f.mime_type is None:
raise ValueError("Missing file mime_type")
prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
FileType.IMAGE: ImagePromptMessageContent,
FileType.AUDIO: AudioPromptMessageContent,
FileType.VIDEO: VideoPromptMessageContent,
FileType.DOCUMENT: DocumentPromptMessageContent,
}
if f.type not in prompt_class_map:
return TextPromptMessageContent(data=f"[Unsupported file type: {f.filename} ({f.type.value})]")
send_format = get_workflow_file_runtime().multimodal_send_format
params = {
"base64_data": _get_encoded_string(f) if send_format == "base64" else "",
"url": _to_url(f) if send_format == "url" else "",
"format": f.extension.removeprefix("."),
"mime_type": f.mime_type,
"filename": f.filename or "",
}
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
return prompt_class_map[f.type].model_validate(params)
def download(f: File, /) -> bytes:
if f.transfer_method in (
FileTransferMethod.TOOL_FILE,
FileTransferMethod.LOCAL_FILE,
FileTransferMethod.DATASOURCE_FILE,
):
return _download_file_content(f.storage_key)
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None:
raise ValueError("Missing file remote_url")
response = get_workflow_file_runtime().http_get(f.remote_url, follow_redirects=True)
response.raise_for_status()
return response.content
raise ValueError(f"unsupported transfer method: {f.transfer_method}")
def _download_file_content(path: str, /) -> bytes:
"""Download and return a file from storage as bytes."""
data = get_workflow_file_runtime().storage_load(path, stream=False)
if not isinstance(data, bytes):
raise ValueError(f"file {path} is not a bytes object")
return data
def _get_encoded_string(f: File, /) -> str:
match f.transfer_method:
case FileTransferMethod.REMOTE_URL:
if f.remote_url is None:
raise ValueError("Missing file remote_url")
response = get_workflow_file_runtime().http_get(f.remote_url, follow_redirects=True)
response.raise_for_status()
data = response.content
case FileTransferMethod.LOCAL_FILE:
data = _download_file_content(f.storage_key)
case FileTransferMethod.TOOL_FILE:
data = _download_file_content(f.storage_key)
case FileTransferMethod.DATASOURCE_FILE:
data = _download_file_content(f.storage_key)
return base64.b64encode(data).decode("utf-8")
def _to_url(f: File, /):
if f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None:
raise ValueError("Missing file remote_url")
return f.remote_url
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
if f.related_id is None:
raise ValueError("Missing file related_id")
return f.remote_url or helpers.get_signed_file_url(upload_file_id=f.related_id)
elif f.transfer_method == FileTransferMethod.TOOL_FILE:
if f.related_id is None or f.extension is None:
raise ValueError("Missing file related_id or extension")
return helpers.get_signed_tool_file_url(tool_file_id=f.related_id, extension=f.extension)
else:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
class FileManager:
"""Adapter exposing file manager helpers behind FileManagerProtocol."""
def download(self, f: File, /) -> bytes:
return download(f)
file_manager = FileManager()

View File

@@ -0,0 +1,92 @@
from __future__ import annotations
import base64
import hashlib
import hmac
import os
import time
import urllib.parse
from .runtime import get_workflow_file_runtime
def get_signed_file_url(upload_file_id: str, as_attachment: bool = False, for_external: bool = True) -> str:
runtime = get_workflow_file_runtime()
base_url = runtime.files_url if for_external else (runtime.internal_files_url or runtime.files_url)
url = f"{base_url}/files/{upload_file_id}/file-preview"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
key = runtime.secret_key.encode()
msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
query: dict[str, str] = {"timestamp": timestamp, "nonce": nonce, "sign": encoded_sign}
if as_attachment:
query["as_attachment"] = "true"
query_string = urllib.parse.urlencode(query)
return f"{url}?{query_string}"
def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str:
runtime = get_workflow_file_runtime()
# Plugin access should use internal URL for Docker network communication.
base_url = runtime.internal_files_url or runtime.files_url
url = f"{base_url}/files/upload/for-plugin"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
key = runtime.secret_key.encode()
msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}&tenant_id={tenant_id}"
def get_signed_tool_file_url(tool_file_id: str, extension: str, for_external: bool = True) -> str:
runtime = get_workflow_file_runtime()
return runtime.sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external)
def verify_plugin_file_signature(
*, filename: str, mimetype: str, tenant_id: str, user_id: str, timestamp: str, nonce: str, sign: str
) -> bool:
runtime = get_workflow_file_runtime()
data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
secret_key = runtime.secret_key.encode()
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= runtime.files_access_timeout
def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
runtime = get_workflow_file_runtime()
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = runtime.secret_key.encode()
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= runtime.files_access_timeout
def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
runtime = get_workflow_file_runtime()
data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = runtime.secret_key.encode()
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= runtime.files_access_timeout

View File

@@ -0,0 +1,178 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any
from pydantic import BaseModel, Field, model_validator
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from . import helpers
from .constants import FILE_MODEL_IDENTITY
from .enums import FileTransferMethod, FileType
def sign_tool_file(*, tool_file_id: str, extension: str, for_external: bool = True) -> str:
"""Compatibility shim for tests and legacy callers patching ``models.sign_tool_file``."""
return helpers.get_signed_tool_file_url(
tool_file_id=tool_file_id,
extension=extension,
for_external=for_external,
)
class ImageConfig(BaseModel):
"""
NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
"""
number_limits: int = 0
transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
detail: ImagePromptMessageContent.DETAIL | None = None
class FileUploadConfig(BaseModel):
"""
File Upload Entity.
"""
image_config: ImageConfig | None = None
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
number_limits: int = 0
class File(BaseModel):
# NOTE: dify_model_identity is a special identifier used to distinguish between
# new and old data formats during serialization and deserialization.
dify_model_identity: str = FILE_MODEL_IDENTITY
id: str | None = None # message file id
tenant_id: str
type: FileType
transfer_method: FileTransferMethod
# If `transfer_method` is `FileTransferMethod.remote_url`, the
# `remote_url` attribute must not be `None`.
remote_url: str | None = None # remote url
# If `transfer_method` is `FileTransferMethod.local_file` or
# `FileTransferMethod.tool_file`, the `related_id` attribute must not be `None`.
#
# It should be set to `ToolFile.id` when `transfer_method` is `tool_file`.
related_id: str | None = None
filename: str | None = None
extension: str | None = Field(default=None, description="File extension, should contain dot")
mime_type: str | None = None
size: int = -1
# Those properties are private, should not be exposed to the outside.
_storage_key: str
def __init__(
self,
*,
id: str | None = None,
tenant_id: str,
type: FileType,
transfer_method: FileTransferMethod,
remote_url: str | None = None,
related_id: str | None = None,
filename: str | None = None,
extension: str | None = None,
mime_type: str | None = None,
size: int = -1,
storage_key: str | None = None,
dify_model_identity: str | None = FILE_MODEL_IDENTITY,
url: str | None = None,
# Legacy compatibility fields - explicitly handle known extra fields
tool_file_id: str | None = None,
upload_file_id: str | None = None,
datasource_file_id: str | None = None,
):
super().__init__(
id=id,
tenant_id=tenant_id,
type=type,
transfer_method=transfer_method,
remote_url=remote_url,
related_id=related_id,
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
dify_model_identity=dify_model_identity,
url=url,
)
self._storage_key = str(storage_key)
def to_dict(self) -> Mapping[str, str | int | None]:
data = self.model_dump(mode="json")
return {
**data,
"url": self.generate_url(),
}
@property
def markdown(self) -> str:
url = self.generate_url()
if self.type == FileType.IMAGE:
text = f"![{self.filename or ''}]({url})"
else:
text = f"[{self.filename or url}]({url})"
return text
def generate_url(self, for_external: bool = True) -> str | None:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.remote_url
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
if self.related_id is None:
raise ValueError("Missing file related_id")
return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external)
elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]:
assert self.related_id is not None
assert self.extension is not None
return sign_tool_file(
tool_file_id=self.related_id,
extension=self.extension,
for_external=for_external,
)
return None
def to_plugin_parameter(self) -> dict[str, Any]:
return {
"dify_model_identity": FILE_MODEL_IDENTITY,
"mime_type": self.mime_type,
"filename": self.filename,
"extension": self.extension,
"size": self.size,
"type": self.type,
"url": self.generate_url(for_external=False),
}
@model_validator(mode="after")
def validate_after(self) -> File:
match self.transfer_method:
case FileTransferMethod.REMOTE_URL:
if not self.remote_url:
raise ValueError("Missing file url")
if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"):
raise ValueError("Invalid file url")
case FileTransferMethod.LOCAL_FILE:
if not self.related_id:
raise ValueError("Missing file related_id")
case FileTransferMethod.TOOL_FILE:
if not self.related_id:
raise ValueError("Missing file related_id")
case FileTransferMethod.DATASOURCE_FILE:
if not self.related_id:
raise ValueError("Missing file related_id")
return self
@property
def storage_key(self) -> str:
return self._storage_key
@storage_key.setter
def storage_key(self, value: str) -> None:
self._storage_key = value

View File

@@ -0,0 +1,43 @@
from __future__ import annotations
from collections.abc import Generator
from typing import Protocol
class HttpResponseProtocol(Protocol):
"""Subset of response behavior needed by workflow file helpers."""
@property
def content(self) -> bytes: ...
def raise_for_status(self) -> object: ...
class WorkflowFileRuntimeProtocol(Protocol):
"""Runtime dependencies required by ``core.workflow.file``.
Implementations are expected to be provided by integration layers (for example,
``core.app.workflow.file_runtime``) so the workflow package avoids importing
application infrastructure modules directly.
"""
@property
def files_url(self) -> str: ...
@property
def internal_files_url(self) -> str | None: ...
@property
def secret_key(self) -> str: ...
@property
def files_access_timeout(self) -> int: ...
@property
def multimodal_send_format(self) -> str: ...
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: ...
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: ...
def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: ...

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
from collections.abc import Generator
from typing import NoReturn
from .protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
class WorkflowFileRuntimeNotConfiguredError(RuntimeError):
"""Raised when workflow file runtime dependencies were not configured."""
class _UnconfiguredWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
def _raise(self) -> NoReturn:
raise WorkflowFileRuntimeNotConfiguredError(
"workflow file runtime is not configured, call set_workflow_file_runtime(...) first"
)
@property
def files_url(self) -> str:
self._raise()
@property
def internal_files_url(self) -> str | None:
self._raise()
@property
def secret_key(self) -> str:
self._raise()
@property
def files_access_timeout(self) -> int:
self._raise()
@property
def multimodal_send_format(self) -> str:
self._raise()
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
self._raise()
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
self._raise()
def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
self._raise()
_runtime: WorkflowFileRuntimeProtocol = _UnconfiguredWorkflowFileRuntime()
def set_workflow_file_runtime(runtime: WorkflowFileRuntimeProtocol) -> None:
global _runtime
_runtime = runtime
def get_workflow_file_runtime() -> WorkflowFileRuntimeProtocol:
return _runtime

View File

@@ -0,0 +1,9 @@
from collections.abc import Callable
from typing import Any
_tool_file_manager_factory: Callable[[], Any] | None = None
def set_tool_file_manager_factory(factory: Callable[[], Any]):
global _tool_file_manager_factory
_tool_file_manager_factory = factory