class SandboxSession:
    """Context manager that prepares a workflow's sandbox for agent tool calls.

    Entering the context looks up the sandbox registered for the workflow
    execution, creates an inner-API session, initializes the Dify CLI inside
    the sandbox with the configured tools, and exposes a ``SandboxBashTool``
    bound to that sandbox. Leaving the context deletes the inner-API session.
    Exceptions raised inside the ``with`` body are never suppressed.
    """

    def __init__(
        self,
        *,
        workflow_execution_id: str,
        tenant_id: str,
        user_id: str,
        tools: list[Tool],
    ) -> None:
        """Store the identifiers and tools; no sandbox work happens here.

        Args:
            workflow_execution_id: Key used to look up the sandbox.
            tenant_id: Tenant the inner-API session and bash tool belong to.
            user_id: User the inner-API session is created for.
            tools: Tools whose configuration is written into the sandbox.
        """
        self._workflow_execution_id = workflow_execution_id
        self._tenant_id = tenant_id
        self._user_id = user_id
        self._tools = tools

        # Populated by __enter__; ``None`` means "not entered yet".
        self._sandbox: VirtualEnvironment | None = None
        self._bash_tool: SandboxBashTool | None = None
        self._session_id: str | None = None

    def __enter__(self) -> SandboxSession:
        """Create the session, initialize the CLI, and build the bash tool.

        Raises:
            RuntimeError: If no sandbox is registered for the workflow
                execution, or if ``dify init`` fails inside the sandbox.
        """
        # Imported lazily, as in the rest of this module, to keep these
        # dependencies off the module import path.
        from core.session.inner_api import InnerApiSessionManager
        from core.tools.builtin_tool.providers.sandbox.bash_tool import SandboxBashTool
        from core.virtual_environment.sandbox_manager import SandboxManager

        sandbox = SandboxManager.get(self._workflow_execution_id)
        if sandbox is None:
            raise RuntimeError(f"Sandbox not found for workflow_execution_id={self._workflow_execution_id}")

        session = InnerApiSessionManager().create(tenant_id=self._tenant_id, user_id=self._user_id)

        try:
            _upload_and_init_dify_cli(sandbox, self._tools, session.id)
            # Built inside the guarded region: __exit__ is not called when
            # __enter__ raises, so any failure from here on must roll back
            # the session itself.
            bash_tool = SandboxBashTool(sandbox=sandbox, tenant_id=self._tenant_id)
        except Exception:
            try:
                InnerApiSessionManager().delete(session.id)
            except Exception:
                # Keep the original setup error as the one that propagates.
                logger.exception("Failed to delete inner API session after SandboxSession setup error")
            raise

        self._sandbox = sandbox
        self._session_id = session.id
        self._bash_tool = bash_tool
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> bool:
        """Delete the session; cleanup errors are logged instead of raised.

        Returns:
            Always ``False`` so exceptions from the ``with`` body propagate.
        """
        try:
            self.cleanup()
        except Exception:
            logger.exception("Failed to cleanup SandboxSession")
        return False

    @property
    def bash_tool(self) -> SandboxBashTool:
        """The bash tool created by ``__enter__``.

        Raises:
            RuntimeError: If the session has not been entered.
        """
        if self._bash_tool is None:
            raise RuntimeError("SandboxSession is not initialized")
        return self._bash_tool

    def cleanup(self) -> None:
        """Delete the inner-API session if one is active.

        Calling this when no session is active is a no-op. The stored session
        id is cleared only after a successful delete, so a failed delete can
        be retried.
        """
        session_id = self._session_id
        if session_id is None:
            # Nothing to release; skip the import as well.
            return

        from core.session.inner_api import InnerApiSessionManager

        InnerApiSessionManager().delete(session_id)
        logger.debug("Cleaned up SandboxSession session_id=%s", session_id)
        self._session_id = None


def _upload_and_init_dify_cli(
    sandbox: VirtualEnvironment,
    tools: list[Tool],
    session_id: str,
    *,
    timeout: float = 30,
) -> None:
    """Upload the CLI config to the sandbox and run ``dify init`` on it.

    The config is a JSON document with an ``env`` section (files URL, inner
    API URL, inner-API session id) and a ``tools`` section produced by
    ``_serialize_tools``. It is uploaded to
    ``/tmp/dify-init-<session_id>.json``.

    Args:
        sandbox: Environment to upload to and run the command in.
        tools: Tools to serialize into the config.
        session_id: Inner-API session id embedded in the config and file name.
        timeout: Seconds to wait for the init command's result.

    Raises:
        RuntimeError: If the init command exits with a non-zero code.
    """
    from configs import dify_config

    config = {
        "env": {
            "files_url": dify_config.FILES_URL,
            "inner_api_url": dify_config.CONSOLE_API_URL,
            "inner_api_session_id": session_id,
        },
        "tools": _serialize_tools(tools),
    }

    config_json = json.dumps(config, ensure_ascii=False)
    config_path = f"/tmp/dify-init-{session_id}.json"

    sandbox.upload_file(config_path, BytesIO(config_json.encode("utf-8")))

    connection_handle = sandbox.establish_connection()
    try:
        future = sandbox.run_command(connection_handle, ["dify", "init", config_path])
        result = future.result(timeout=timeout)
        if result.exit_code != 0:
            stderr = result.stderr.decode("utf-8", errors="replace") if result.stderr else ""
            raise RuntimeError(f"Failed to initialize Dify CLI in sandbox: {stderr}")
    finally:
        # Release the connection on success, failure, and timeout alike.
        sandbox.release_connection(connection_handle)


def _serialize_tools(tools: list[Tool]) -> list[dict[str, Any]]:
    """Convert tools into JSON-ready dicts for the Dify CLI config.

    Each dict is the tool entity's dump plus ``provider_type``,
    ``credential_type`` and ``credential_id``. Tools without a runtime get
    ``"default"`` as credential type and the entity's provider as credential
    id.

    NOTE(review): with a runtime, ``credential_id`` is filled from
    ``runtime.tool_id`` -- confirm that this is the intended value.
    """
    serialized: list[dict[str, Any]] = []

    for tool in tools:
        runtime = tool.runtime
        # JSON mode: the result is passed to json.dumps by the caller, so
        # every value must already be JSON-compatible.
        tool_config = tool.entity.model_dump(mode="json")
        tool_config["provider_type"] = tool.tool_provider_type().value
        tool_config["credential_type"] = runtime.credential_type.value if runtime else "default"
        tool_config["credential_id"] = runtime.tool_id if runtime else tool.entity.identity.provider
        serialized.append(tool_config)

    return serialized
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import File, FileTransferMethod, FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage @@ -264,9 +265,8 @@ class LLMNode(Node[LLMNodeData]): if self.tool_call_enabled: workflow_execution_id = variable_pool.system_variables.workflow_execution_id - is_sandbox_runtime = ( - workflow_execution_id is not None - and SandboxManager.is_sandbox_runtime(workflow_execution_id) + is_sandbox_runtime = workflow_execution_id is not None and SandboxManager.is_sandbox_runtime( + workflow_execution_id ) if is_sandbox_runtime: @@ -274,10 +274,7 @@ class LLMNode(Node[LLMNodeData]): model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, - files=files, variable_pool=variable_pool, - node_inputs=node_inputs, - process_data=process_data, ) else: generator = self._invoke_llm_with_tools( @@ -1589,10 +1586,7 @@ class LLMNode(Node[LLMNodeData]): model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], stop: Sequence[str] | None, - files: Sequence[File], variable_pool: VariablePool, - node_inputs: dict[str, Any], - process_data: dict[str, Any], ) -> Generator[NodeEventBase, None, LLMGenerationData]: from core.agent.entities import AgentEntity @@ -1602,32 +1596,38 @@ class LLMNode(Node[LLMNodeData]): configured_tools = self._prepare_tool_instances(variable_pool) - bash_tool = SandboxManager.get_bash_tool( + result: LLMGenerationData | None = None + + with SandboxSession( workflow_execution_id=workflow_execution_id, tenant_id=self.tenant_id, - configured_tools=configured_tools, - ) + user_id=self.user_id, + tools=configured_tools, + ) as sandbox_session: + prompt_files = self._extract_prompt_files(variable_pool) - prompt_files = self._extract_prompt_files(variable_pool) + strategy = StrategyFactory.create_strategy( + model_features=[], + model_instance=model_instance, + tools=[sandbox_session.bash_tool], + files=prompt_files, + 
max_iterations=self._node_data.max_iterations or 10, + context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id), + agent_strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT, + ) - strategy = StrategyFactory.create_strategy( - model_features=[], - model_instance=model_instance, - tools=[bash_tool], - files=prompt_files, - max_iterations=self._node_data.max_iterations or 10, - context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id), - agent_strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT, - ) + outputs = strategy.run( + prompt_messages=list(prompt_messages), + model_parameters=self._node_data.model.completion_params, + stop=list(stop or []), + stream=True, + ) - outputs = strategy.run( - prompt_messages=list(prompt_messages), - model_parameters=self._node_data.model.completion_params, - stop=list(stop or []), - stream=True, - ) + result = yield from self._process_tool_outputs(outputs) + + if result is None: + raise LLMNodeError("SandboxSession exited unexpectedly") - result = yield from self._process_tool_outputs(outputs) return result def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]: diff --git a/api/models/model.py b/api/models/model.py index 2845b22c72..aab13bb60e 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -2197,4 +2197,3 @@ class TenantCreditPool(Base): def has_sufficient_credits(self, required_credits: int) -> bool: return self.remaining_credits >= required_credits -