Files
dify/api/core/sandbox/manager.py
Harry bb4dd85ae3 feat(sandbox): refactor sandbox file handling to include app_id
- Updated API routes to use app_id instead of sandbox_id for file operations, aligning with user-specific sandbox workspaces.
- Enhanced SandboxFileService and related classes to accommodate app_id in file listing and download functionalities.
- Refactored storage key generation for sandbox archives to include app_id, ensuring proper file organization.
- Adjusted frontend contracts and services to reflect the new app_id parameter in API calls.
2026-01-30 22:45:28 +08:00

104 lines
3.2 KiB
Python

from __future__ import annotations

import logging
import threading
from contextlib import ExitStack
from typing import TYPE_CHECKING, Final

if TYPE_CHECKING:
    from core.sandbox.sandbox import Sandbox
logger = logging.getLogger(__name__)


class SandboxManager:
    """Registry for active Sandbox instances, keyed by ``sandbox_id``.

    Stores complete Sandbox objects (not just VirtualEnvironment) to provide
    access to sandbox metadata like tenant_id, app_id, user_id, assets_id.

    Concurrency model:

    * Entries are spread over ``_NUM_SHARDS`` shards selected by
      ``hash(sandbox_id)``. Each shard has its own lock, so writers touching
      different shards do not contend.
    * Writers never mutate a published shard dict in place. They build a copy,
      modify the copy, and rebind ``_shards[i]`` to it (copy-on-write). Readers
      (``get``, ``has``, ``count``) therefore skip the lock and always see
      either the previous or the new dict for a shard, never one mid-update.
    * All state is class-level; the registry is shared by every caller in the
      process.
    """

    _NUM_SHARDS: Final[int] = 1024
    # The shard count is a power of two, so ``hash & mask`` equals ``hash % count``.
    _SHARD_MASK: Final[int] = _NUM_SHARDS - 1
    _shard_locks: Final[tuple[threading.Lock, ...]] = tuple(threading.Lock() for _ in range(_NUM_SHARDS))
    # Writers replace elements wholesale while holding the matching lock; they never edit them in place.
    _shards: list[dict[str, Sandbox]] = [{} for _ in range(_NUM_SHARDS)]

    @classmethod
    def _shard_index(cls, sandbox_id: str) -> int:
        """Return the shard index in ``[0, _NUM_SHARDS)`` that owns ``sandbox_id``."""
        # str hashes are salted per process (PYTHONHASHSEED), so the index is stable
        # only within one process -- sufficient for an in-memory registry. Masking a
        # negative hash still yields a non-negative index under Python int semantics.
        return hash(sandbox_id) & cls._SHARD_MASK

    @classmethod
    def register(cls, sandbox_id: str, sandbox: Sandbox) -> None:
        """Register ``sandbox`` under ``sandbox_id``.

        Raises:
            ValueError: if ``sandbox_id`` is empty.
            RuntimeError: if a sandbox is already registered under ``sandbox_id``.
        """
        if not sandbox_id:
            raise ValueError("sandbox_id cannot be empty")

        shard_index = cls._shard_index(sandbox_id)
        with cls._shard_locks[shard_index]:
            current = cls._shards[shard_index]
            if sandbox_id in current:
                raise RuntimeError(
                    f"Sandbox already registered for sandbox_id={sandbox_id}. "
                    "Call unregister() first if you need to replace it."
                )
            # Publish a fresh dict rather than mutating the shared one, so lock-free
            # readers never observe a partially updated mapping.
            cls._shards[shard_index] = {**current, sandbox_id: sandbox}

        # Log after releasing the lock: attribute lookups and handler I/O need not
        # block other writers hashing to the same shard.
        logger.debug(
            "Registered sandbox: sandbox_id=%s, vm_id=%s, app_id=%s",
            sandbox_id,
            sandbox.vm.metadata.id,
            sandbox.app_id,
        )

    @classmethod
    def get(cls, sandbox_id: str) -> Sandbox | None:
        """Return the sandbox registered under ``sandbox_id``, or ``None`` if absent."""
        # Lock-free read of a single published shard dict (see class docstring).
        return cls._shards[cls._shard_index(sandbox_id)].get(sandbox_id)

    @classmethod
    def unregister(cls, sandbox_id: str) -> Sandbox | None:
        """Remove and return the sandbox registered under ``sandbox_id``.

        Returns ``None`` (and changes nothing) when no sandbox is registered
        under that id.
        """
        shard_index = cls._shard_index(sandbox_id)
        with cls._shard_locks[shard_index]:
            current = cls._shards[shard_index]
            sandbox = current.get(sandbox_id)
            if sandbox is None:
                return None
            # Copy-on-write, mirroring register(): never mutate the published dict.
            remaining = dict(current)
            del remaining[sandbox_id]
            cls._shards[shard_index] = remaining

        logger.debug(
            "Unregistered sandbox: sandbox_id=%s, vm_id=%s",
            sandbox_id,
            sandbox.vm.metadata.id,
        )
        return sandbox

    @classmethod
    def has(cls, sandbox_id: str) -> bool:
        """Return whether a sandbox is registered under ``sandbox_id`` (lock-free)."""
        return sandbox_id in cls._shards[cls._shard_index(sandbox_id)]

    @classmethod
    def is_sandbox_runtime(cls, sandbox_id: str) -> bool:
        """Return whether ``sandbox_id`` belongs to a registered sandbox; alias of :meth:`has`."""
        return cls.has(sandbox_id)

    @classmethod
    def clear(cls) -> None:
        """Remove every registered sandbox.

        Holds all shard locks at once, so no register/unregister can interleave
        with the reset.
        """
        with ExitStack() as stack:
            # Acquire in a fixed (index) order. ExitStack releases exactly the locks
            # acquired so far, in reverse order, even if acquisition is interrupted
            # partway; the previous acquire-then-try pattern leaked those locks.
            for lock in cls._shard_locks:
                stack.enter_context(lock)
            for i in range(cls._NUM_SHARDS):
                cls._shards[i] = {}
            logger.debug("Cleared all registered sandboxes")

    @classmethod
    def count(cls) -> int:
        """Return the number of registered sandboxes.

        Reads shards without locking, so under concurrent writes the total is a
        best-effort snapshot that may already be stale.
        """
        return sum(len(shard) for shard in cls._shards)