Mirror of https://github.com/langgenius/dify.git, synced 2026-02-12 22:01:20 -05:00
Added support for an optional `download_filename` parameter in the `get_download_url` and `get_download_urls` methods across various storage classes. This allows users to specify a custom filename for downloads, improving user experience by enabling better file naming during downloads. Updated related methods and tests to accommodate this new functionality.
210 lines
7.8 KiB
Python
"""Storage wrapper that caches presigned download URLs."""
|
|
|
|
import hashlib
|
|
import logging
|
|
from itertools import starmap
|
|
|
|
from extensions.ext_redis import redis_client
|
|
from extensions.storage.base_storage import BaseStorage
|
|
from extensions.storage.storage_wrapper import StorageWrapper
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CachedPresignStorage(StorageWrapper):
    """Storage wrapper that caches presigned download URLs.

    Wraps a storage with presign capability and caches the generated URLs
    in Redis to reduce repeated presign API calls.

    Example:
        cached_storage = CachedPresignStorage(
            storage=FilePresignStorage(base_storage),
            cache_key_prefix="app_asset:draft_download",
        )
        url = cached_storage.get_download_url("path/to/file.txt", expires_in=3600)
    """

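    # TTL_BUFFER_SECONDS keeps cache entries slightly shorter-lived than the
    # presigned URL itself, so a cached URL is refreshed before it expires;
    # MIN_TTL_SECONDS is the lower bound applied for very short expirations.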
    TTL_BUFFER_SECONDS = 60
    MIN_TTL_SECONDS = 60

    def __init__(
        self,
        storage: BaseStorage,
        cache_key_prefix: str = "presign_cache",
    ):
        super().__init__(storage)
        self._redis = redis_client
        self._cache_key_prefix = cache_key_prefix

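    # Deleting an object also invalidates its cached presigned URL (the base
    # cache key without a download_filename suffix).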
    def delete(self, filename: str):
        super().delete(filename)
        self.invalidate([filename])

    def get_download_url(
        self,
        filename: str,
        expires_in: int = 3600,
        *,
        download_filename: str | None = None,
    ) -> str:
        """Get a presigned download URL, using the cache when available.

        Args:
            filename: The file path/key in storage
            expires_in: URL validity duration in seconds (default: 1 hour)
            download_filename: If provided, the browser will use this as the downloaded
                file name. Cache keys include this value to avoid conflicts.

        Returns:
            Presigned URL string
        """
        cache_key = self._cache_key(filename, download_filename)

        cached = self._get_cached(cache_key)
        if cached:
            return cached

        url = self._storage.get_download_url(filename, expires_in, download_filename=download_filename)
        self._set_cached(cache_key, url, expires_in)

        return url

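    # Illustrative batch call (paths and names below are placeholders, mirroring
    # the single-URL example in the class docstring):
    #     urls = cached_storage.get_download_urls(
    #         ["assets/a.png", "assets/b.png"],
    #         expires_in=3600,
    #         download_filenames=["logo-a.png", "logo-b.png"],
    #     )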
    def get_download_urls(
        self,
        filenames: list[str],
        expires_in: int = 3600,
        *,
        download_filenames: list[str] | None = None,
    ) -> list[str]:
        """Batch get download URLs with cache.

        Args:
            filenames: List of file paths/keys in storage
            expires_in: URL validity duration in seconds (default: 1 hour)
            download_filenames: If provided, must match len(filenames). Each element
                specifies the download filename for the corresponding file.

        Returns:
            List of presigned URLs in the same order as filenames
        """
        if not filenames:
            return []

        # Build cache keys including download_filename for uniqueness
        if download_filenames is None:
            cache_keys = [self._cache_key(f, None) for f in filenames]
        else:
            cache_keys = list(starmap(self._cache_key, zip(filenames, download_filenames, strict=True)))

        cached_values = self._get_cached_batch(cache_keys)

        # Build results list, tracking which indices need fetching
        results: list[str | None] = list(cached_values)
        uncached_indices: list[int] = []
        uncached_filenames: list[str] = []
        uncached_download_filenames: list[str | None] = []

        for i, (filename, cached) in enumerate(zip(filenames, cached_values)):
            if not cached:
                uncached_indices.append(i)
                uncached_filenames.append(filename)
                uncached_download_filenames.append(download_filenames[i] if download_filenames else None)

        # Batch fetch uncached URLs from storage
        if uncached_filenames:
            uncached_urls = [
                self._storage.get_download_url(f, expires_in, download_filename=df)
                for f, df in zip(uncached_filenames, uncached_download_filenames, strict=True)
            ]

            # Fill results at correct positions
            for idx, url in zip(uncached_indices, uncached_urls):
                results[idx] = url

            # Batch set cache
            uncached_cache_keys = [cache_keys[i] for i in uncached_indices]
            self._set_cached_batch(uncached_cache_keys, uncached_urls, expires_in)

        return results  # type: ignore[return-value]

    def invalidate(self, filenames: list[str]) -> None:
        """Invalidate cached URLs for given filenames.

        Args:
            filenames: List of file paths/keys to invalidate
        """
        if not filenames:
            return

        cache_keys = [self._cache_key(f) for f in filenames]
        try:
            self._redis.delete(*cache_keys)
        except Exception:
            logger.warning("Failed to invalidate presign cache", exc_info=True)

    def _cache_key(self, filename: str, download_filename: str | None = None) -> str:
        """Generate the cache key for a filename.

        When download_filename is provided, its hash is appended to the key to ensure
        different download names for the same storage key get separate cache entries.
        A hash (truncated MD5) is used instead of the raw string because:
        - download_filename may contain special characters unsafe for Redis keys
        - at 64 bits, collisions are practically negligible; at worst a collision
          would reuse a cached URL generated under a different download name
        """
        if download_filename:
            # Use first 16 chars of MD5 hex digest (64 bits) - sufficient for cache key uniqueness
            name_hash = hashlib.md5(download_filename.encode("utf-8")).hexdigest()[:16]
            return f"{self._cache_key_prefix}:{filename}::{name_hash}"
        return f"{self._cache_key_prefix}:{filename}"

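    # Illustrative key shapes with the default "presign_cache" prefix:
    #   no download_filename:   presign_cache:path/to/file.txt
    #   with download_filename: presign_cache:path/to/file.txt::<16-hex md5 prefix>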
    def _compute_ttl(self, expires_in: int) -> int:
        """Compute cache TTL from presign expiration.

        Returns TTL slightly shorter than presign expiry to ensure
        cached URLs are refreshed before they expire.
        """
        return max(expires_in - self.TTL_BUFFER_SECONDS, self.MIN_TTL_SECONDS)

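    # Illustrative arithmetic: expires_in=3600 -> TTL 3540; expires_in=90 -> max(30, 60) = 60.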
    def _get_cached(self, cache_key: str) -> str | None:
        """Get a single cached URL."""
        try:
            values = self._redis.mget([cache_key])
            cached = values[0] if values else None
            if cached:
                return cached.decode("utf-8") if isinstance(cached, (bytes, bytearray)) else cached
            return None
        except Exception:
            logger.warning("Failed to read presign cache", exc_info=True)
            return None

    def _get_cached_batch(self, cache_keys: list[str]) -> list[str | None]:
        """Get multiple cached URLs."""
        try:
            cached_values = self._redis.mget(cache_keys)
            return [v.decode("utf-8") if isinstance(v, (bytes, bytearray)) else v for v in cached_values]
        except Exception:
            logger.warning("Failed to read presign cache batch", exc_info=True)
            return [None] * len(cache_keys)

    def _set_cached(self, cache_key: str, url: str, expires_in: int) -> None:
        """Store a URL in cache with computed TTL."""
        ttl = self._compute_ttl(expires_in)
        try:
            self._redis.setex(cache_key, ttl, url)
        except Exception:
            logger.warning("Failed to write presign cache", exc_info=True)

    def _set_cached_batch(self, cache_keys: list[str], urls: list[str], expires_in: int) -> None:
        """Store multiple URLs in cache with computed TTL using pipeline."""
        if not cache_keys:
            return
        ttl = self._compute_ttl(expires_in)
        try:
            pipe = self._redis.pipeline()
            for cache_key, url in zip(cache_keys, urls):
                pipe.setex(cache_key, ttl, url)
            pipe.execute()
        except Exception:
            logger.warning("Failed to write presign cache batch", exc_info=True)