import functools import logging import ssl from collections.abc import Callable from datetime import timedelta from typing import TYPE_CHECKING, Any, Union import redis from redis import RedisError from redis.backoff import ExponentialWithJitterBackoff # type: ignore from redis.cache import CacheConfig from redis.client import PubSub from redis.cluster import ClusterNode, RedisCluster from redis.connection import Connection, SSLConnection from redis.retry import Retry from redis.sentinel import Sentinel from typing_extensions import TypedDict from configs import dify_config from dify_app import DifyApp from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel if TYPE_CHECKING: from redis.lock import Lock logger = logging.getLogger(__name__) class RedisClientWrapper: """ A wrapper class for the Redis client that addresses the issue where the global `redis_client` variable cannot be updated when a new Redis instance is returned by Sentinel. This class allows for deferred initialization of the Redis client, enabling the client to be re-initialized with a new instance when necessary. This is particularly useful in scenarios where the Redis instance may change dynamically, such as during a failover in a Sentinel-managed Redis setup. Attributes: _client: The actual Redis client instance. It remains None until initialized with the `initialize` method. Methods: initialize(client): Initializes the Redis client if it hasn't been initialized already. __getattr__(item): Delegates attribute access to the Redis client, raising an error if the client is not initialized. """ _client: Union[redis.Redis, RedisCluster, None] def __init__(self) -> None: self._client = None def initialize(self, client: Union[redis.Redis, RedisCluster]) -> None: if self._client is None: self._client = client if TYPE_CHECKING: # Type hints for IDE support and static analysis # These are not executed at runtime but provide type information def get(self, name: str | bytes) -> Any: ... def set( self, name: str | bytes, value: Any, ex: int | None = None, px: int | None = None, nx: bool = False, xx: bool = False, keepttl: bool = False, get: bool = False, exat: int | None = None, pxat: int | None = None, ) -> Any: ... def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: ... def setnx(self, name: str | bytes, value: Any) -> Any: ... def delete(self, *names: str | bytes) -> Any: ... def incr(self, name: str | bytes, amount: int = 1) -> Any: ... def expire( self, name: str | bytes, time: int | timedelta, nx: bool = False, xx: bool = False, gt: bool = False, lt: bool = False, ) -> Any: ... def lock( self, name: str, timeout: float | None = None, sleep: float = 0.1, blocking: bool = True, blocking_timeout: float | None = None, thread_local: bool = True, ) -> Lock: ... def zadd( self, name: str | bytes, mapping: dict[str | bytes | int | float, float | int | str | bytes], nx: bool = False, xx: bool = False, ch: bool = False, incr: bool = False, gt: bool = False, lt: bool = False, ) -> Any: ... def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ... def zcard(self, name: str | bytes) -> Any: ... def getdel(self, name: str | bytes) -> Any: ... def pubsub(self) -> PubSub: ... def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: ... def __getattr__(self, item: str) -> Any: if self._client is None: raise RuntimeError("Redis client is not initialized. Call init_app first.") return getattr(self._client, item) redis_client: RedisClientWrapper = RedisClientWrapper() _pubsub_redis_client: redis.Redis | RedisCluster | None = None class RedisSSLParamsDict(TypedDict): ssl_cert_reqs: int ssl_ca_certs: str | None ssl_certfile: str | None ssl_keyfile: str | None class RedisHealthParamsDict(TypedDict): retry: Retry socket_timeout: float | None socket_connect_timeout: float | None health_check_interval: int | None class RedisClusterHealthParamsDict(TypedDict): retry: Retry socket_timeout: float | None socket_connect_timeout: float | None class RedisBaseParamsDict(TypedDict): username: str | None password: str | None db: int encoding: str encoding_errors: str decode_responses: bool protocol: int cache_config: CacheConfig | None retry: Retry socket_timeout: float | None socket_connect_timeout: float | None health_check_interval: int | None def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]: """Get SSL configuration for Redis connection.""" if not dify_config.REDIS_USE_SSL: return Connection, {} cert_reqs_map = { "CERT_NONE": ssl.CERT_NONE, "CERT_OPTIONAL": ssl.CERT_OPTIONAL, "CERT_REQUIRED": ssl.CERT_REQUIRED, } ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE) ssl_kwargs = { "ssl_cert_reqs": ssl_cert_reqs, "ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS, "ssl_certfile": dify_config.REDIS_SSL_CERTFILE, "ssl_keyfile": dify_config.REDIS_SSL_KEYFILE, } return SSLConnection, ssl_kwargs def _get_cache_configuration() -> CacheConfig | None: """Get client-side cache configuration if enabled.""" if not dify_config.REDIS_ENABLE_CLIENT_SIDE_CACHE: return None resp_protocol = dify_config.REDIS_SERIALIZATION_PROTOCOL if resp_protocol < 3: raise ValueError("Client side cache is only supported in RESP3") return CacheConfig() def _get_retry_policy() -> Retry: """Build the shared retry policy for Redis connections.""" return Retry( backoff=ExponentialWithJitterBackoff( base=dify_config.REDIS_RETRY_BACKOFF_BASE, cap=dify_config.REDIS_RETRY_BACKOFF_CAP, ), retries=dify_config.REDIS_RETRY_RETRIES, ) def _get_connection_health_params() -> RedisHealthParamsDict: """Get connection health and retry parameters for standalone and Sentinel Redis clients.""" return RedisHealthParamsDict( retry=_get_retry_policy(), socket_timeout=dify_config.REDIS_SOCKET_TIMEOUT, socket_connect_timeout=dify_config.REDIS_SOCKET_CONNECT_TIMEOUT, health_check_interval=dify_config.REDIS_HEALTH_CHECK_INTERVAL, ) def _get_cluster_connection_health_params() -> RedisClusterHealthParamsDict: """Get retry and timeout parameters for Redis Cluster clients. RedisCluster does not support ``health_check_interval`` as a constructor keyword (it is silently stripped by ``cleanup_kwargs``), so it is excluded here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout`` are passed through. """ health_params = _get_connection_health_params() result: RedisClusterHealthParamsDict = { "retry": health_params["retry"], "socket_timeout": health_params["socket_timeout"], "socket_connect_timeout": health_params["socket_connect_timeout"], } return result def _get_base_redis_params() -> RedisBaseParamsDict: """Get base Redis connection parameters including retry and health policy.""" return RedisBaseParamsDict( username=dify_config.REDIS_USERNAME, password=dify_config.REDIS_PASSWORD or None, db=dify_config.REDIS_DB, encoding="utf-8", encoding_errors="strict", decode_responses=False, protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL, cache_config=_get_cache_configuration(), **_get_connection_health_params(), ) def _create_sentinel_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]: """Create Redis client using Sentinel configuration.""" if not dify_config.REDIS_SENTINELS: raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True") if not dify_config.REDIS_SENTINEL_SERVICE_NAME: raise ValueError("REDIS_SENTINEL_SERVICE_NAME must be set when REDIS_USE_SENTINEL is True") sentinel_hosts = [(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")] sentinel_kwargs = { "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT, "username": dify_config.REDIS_SENTINEL_USERNAME, "password": dify_config.REDIS_SENTINEL_PASSWORD, } if dify_config.REDIS_MAX_CONNECTIONS: sentinel_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS sentinel = Sentinel( sentinel_hosts, sentinel_kwargs=sentinel_kwargs, ) params: dict[str, Any] = {**redis_params} master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **params) return master def _create_cluster_client() -> Union[redis.Redis, RedisCluster]: """Create Redis cluster client.""" if not dify_config.REDIS_CLUSTERS: raise ValueError("REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True") nodes = [ ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1])) for node in dify_config.REDIS_CLUSTERS.split(",") ] cluster_kwargs: dict[str, Any] = { "startup_nodes": nodes, "password": dify_config.REDIS_CLUSTERS_PASSWORD, "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL, "cache_config": _get_cache_configuration(), **_get_cluster_connection_health_params(), } if dify_config.REDIS_MAX_CONNECTIONS: cluster_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS cluster: RedisCluster = RedisCluster(**cluster_kwargs) return cluster def _create_standalone_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]: """Create standalone Redis client.""" connection_class, ssl_kwargs = _get_ssl_configuration() params: dict[str, Any] = { **redis_params, "host": dify_config.REDIS_HOST, "port": dify_config.REDIS_PORT, "connection_class": connection_class, } if dify_config.REDIS_MAX_CONNECTIONS: params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS if ssl_kwargs: params.update(ssl_kwargs) pool = redis.ConnectionPool(**params) client: redis.Redis = redis.Redis(connection_pool=pool) return client def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster: max_conns = dify_config.REDIS_MAX_CONNECTIONS if use_clusters: health_params = _get_cluster_connection_health_params() kwargs: dict[str, Any] = {**health_params} if max_conns: kwargs["max_connections"] = max_conns return RedisCluster.from_url(pubsub_url, **kwargs) standalone_health_params: dict[str, Any] = dict(_get_connection_health_params()) kwargs = {**standalone_health_params} if max_conns: kwargs["max_connections"] = max_conns return redis.Redis.from_url(pubsub_url, **kwargs) def init_app(app: DifyApp): """Initialize Redis client and attach it to the app.""" global redis_client # Determine Redis mode and create appropriate client if dify_config.REDIS_USE_SENTINEL: redis_params = _get_base_redis_params() client = _create_sentinel_client(redis_params) elif dify_config.REDIS_USE_CLUSTERS: client = _create_cluster_client() else: redis_params = _get_base_redis_params() client = _create_standalone_client(redis_params) # Initialize the wrapper and attach to app redis_client.initialize(client) app.extensions["redis"] = redis_client global _pubsub_redis_client _pubsub_redis_client = client if dify_config.normalized_pubsub_redis_url: _pubsub_redis_client = _create_pubsub_client( dify_config.normalized_pubsub_redis_url, dify_config.PUBSUB_REDIS_USE_CLUSTERS ) def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol: assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here." if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded": return ShardedRedisBroadcastChannel(_pubsub_redis_client) if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "streams": return StreamsBroadcastChannel( _pubsub_redis_client, retention_seconds=dify_config.PUBSUB_STREAMS_RETENTION_SECONDS, ) return RedisBroadcastChannel(_pubsub_redis_client) def redis_fallback[T](default_return: T | None = None): # type: ignore """ decorator to handle Redis operation exceptions and return a default value when Redis is unavailable. Args: default_return: The value to return when a Redis operation fails. Defaults to None. """ def decorator[**P, R](func: Callable[P, R]) -> Callable[P, R | T | None]: @functools.wraps(func) def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | T | None: try: return func(*args, **kwargs) except RedisError as e: func_name = getattr(func, "__name__", "Unknown") logger.warning("Redis operation failed in %s: %s", func_name, str(e), exc_info=True) return default_return return wrapper return decorator