diff --git a/api/controllers/cli_api/plugin/plugin.py b/api/controllers/cli_api/plugin/plugin.py index 0fd66230ba..e1ac654fdd 100644 --- a/api/controllers/cli_api/plugin/plugin.py +++ b/api/controllers/cli_api/plugin/plugin.py @@ -1,3 +1,4 @@ +from flask import abort from flask_restx import Resource from controllers.cli_api import cli_api_ns @@ -15,6 +16,8 @@ from core.plugin.entities.request import ( RequestInvokeTool, RequestRequestUploadFile, ) +from core.session.cli_api import CliContext +from core.skill.entities import ToolInvocationRequest from core.tools.entities.tool_entities import ToolProviderType from libs.helper import length_prefixed_response from models import Account, Tenant @@ -23,9 +26,9 @@ from models.model import EndUser @cli_api_ns.route("/invoke/llm") class CliInvokeLLMApi(Resource): + @cli_api_only @get_cli_user_tenant @setup_required - @cli_api_only @plugin_data(payload_type=RequestInvokeLLM) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLM): def generator(): @@ -37,17 +40,34 @@ class CliInvokeLLMApi(Resource): @cli_api_ns.route("/invoke/tool") class CliInvokeToolApi(Resource): + @cli_api_only @get_cli_user_tenant @setup_required - @cli_api_only @plugin_data(payload_type=RequestInvokeTool) - def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTool): + def post( + self, + user_model: Account | EndUser, + tenant_model: Tenant, + payload: RequestInvokeTool, + cli_context: CliContext, + ): + tool_type = ToolProviderType.value_of(payload.tool_type) + + request = ToolInvocationRequest( + tool_type=tool_type, + provider=payload.provider, + tool_name=payload.tool, + credential_id=payload.credential_id, + ) + if cli_context.tool_access and not cli_context.tool_access.is_allowed(request): + abort(403) + def generator(): return PluginToolBackwardsInvocation.convert_to_event_stream( PluginToolBackwardsInvocation.invoke_tool( tenant_id=tenant_model.id, user_id=user_model.id, - tool_type=ToolProviderType.value_of(payload.tool_type), + tool_type=tool_type, provider=payload.provider, tool_name=payload.tool, tool_parameters=payload.tool_parameters, @@ -60,9 +80,9 @@ class CliInvokeToolApi(Resource): @cli_api_ns.route("/invoke/app") class CliInvokeAppApi(Resource): + @cli_api_only @get_cli_user_tenant @setup_required - @cli_api_only @plugin_data(payload_type=RequestInvokeApp) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeApp): response = PluginAppBackwardsInvocation.invoke_app( @@ -81,9 +101,9 @@ class CliInvokeAppApi(Resource): @cli_api_ns.route("/upload/file/request") class CliUploadFileRequestApi(Resource): + @cli_api_only @get_cli_user_tenant @setup_required - @cli_api_only @plugin_data(payload_type=RequestRequestUploadFile) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile): # generate signed url @@ -98,9 +118,9 @@ class CliUploadFileRequestApi(Resource): @cli_api_ns.route("/fetch/tools/list") class CliFetchToolsListApi(Resource): + @cli_api_only @get_cli_user_tenant @setup_required - @cli_api_only def post(self, user_model: Account | EndUser, tenant_model: Tenant): from sqlalchemy.orm import Session diff --git a/api/controllers/cli_api/plugin/wraps.py b/api/controllers/cli_api/plugin/wraps.py index 1599637395..4b9f542e59 100644 --- a/api/controllers/cli_api/plugin/wraps.py +++ b/api/controllers/cli_api/plugin/wraps.py @@ -2,12 +2,12 @@ from collections.abc import Callable from functools import wraps from typing import ParamSpec, TypeVar -from flask import current_app, request +from flask import current_app, g, request from flask_login import user_logged_in from pydantic import BaseModel from sqlalchemy.orm import Session -from core.session.cli_api import CliApiSession, CliApiSessionManager +from core.session.cli_api import CliApiSession, CliContext from extensions.ext_database import db from libs.login import current_user from models.account import Tenant @@ -75,22 +75,13 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: def get_cli_user_tenant(view_func: Callable[P, R]): @wraps(view_func) def decorated_view(*args: P.args, **kwargs: P.kwargs): - session_id = request.headers.get("X-Cli-Api-Session-Id") + session: CliApiSession | None = getattr(g, "cli_api_session", None) + if session is None: + raise ValueError("session not found") - if session_id: - session: CliApiSession | None = CliApiSessionManager().get(session_id) - if not session: - raise ValueError("session not found") - user_id = session.user_id - tenant_id = session.tenant_id - - else: - payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {}) - user_id = payload.user_id - tenant_id = payload.tenant_id - - if not tenant_id: - raise ValueError("tenant_id is required") + user_id = session.user_id + tenant_id = session.tenant_id + cli_context = CliContext.model_validate(session.context) if not user_id: user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID @@ -110,11 +101,10 @@ def get_cli_user_tenant(view_func: Callable[P, R]): raise ValueError("tenant not found") kwargs["tenant_model"] = tenant_model + kwargs["user_model"] = get_user(tenant_id, user_id) + kwargs["cli_context"] = cli_context - user = get_user(tenant_id, user_id) - kwargs["user_model"] = user - - current_app.login_manager._update_request_context_with_user(user) # type: ignore + current_app.login_manager._update_request_context_with_user(kwargs["user_model"]) # type: ignore user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore return view_func(*args, **kwargs) diff --git a/api/controllers/cli_api/wraps.py b/api/controllers/cli_api/wraps.py index c75022fa2c..d4f1cdb522 100644 --- a/api/controllers/cli_api/wraps.py +++ b/api/controllers/cli_api/wraps.py @@ -5,7 +5,7 @@ from collections.abc import Callable from functools import wraps from typing import ParamSpec, TypeVar -from flask import abort, request +from flask import abort, g, request from core.session.cli_api import CliApiSessionManager @@ -49,6 +49,8 @@ def cli_api_only(view: Callable[P, R]): if not _verify_signature(session.secret, timestamp, body, signature): abort(401) + g.cli_api_session = session + return view(*args, **kwargs) return decorated diff --git a/api/core/sandbox/bash/TODO.md b/api/core/sandbox/bash/TODO.md deleted file mode 100644 index 95594f13da..0000000000 --- a/api/core/sandbox/bash/TODO.md +++ /dev/null @@ -1 +0,0 @@ -# refactor the package import paths \ No newline at end of file diff --git a/api/core/sandbox/bash/session.py b/api/core/sandbox/bash/session.py index 463fc97e94..f29c8f2df4 100644 --- a/api/core/sandbox/bash/session.py +++ b/api/core/sandbox/bash/session.py @@ -6,7 +6,8 @@ from io import BytesIO from types import TracebackType from core.sandbox.sandbox import Sandbox -from core.session.cli_api import CliApiSession, CliApiSessionManager +from core.session.cli_api import CliApiSession, CliApiSessionManager, CliContext +from core.skill.entities import ToolAccessPolicy from core.skill.entities.tool_dependencies import ToolDependencies from core.virtual_environment.__base.helpers import pipeline @@ -37,6 +38,7 @@ class SandboxBashSession: self._cli_api_session = CliApiSessionManager().create( tenant_id=self._tenant_id, user_id=self._user_id, + context=CliContext(tool_access=ToolAccessPolicy.from_dependencies(self._tools)), ) if self._tools is not None and not self._tools.is_empty(): tools_path = self._setup_node_tools_directory(self._node_id, self._tools, self._cli_api_session) @@ -55,7 +57,7 @@ class SandboxBashSession: node_id: str, tools: ToolDependencies, cli_api_session: CliApiSession, - ) -> str | None: + ) -> str: node_tools_path = f"{DifyCli.TOOLS_ROOT}/{node_id}" vm = self._sandbox.vm diff --git a/api/core/sandbox/initializer/dify_cli_initializer.py b/api/core/sandbox/initializer/dify_cli_initializer.py index 994654ab6a..49f013eca5 100644 --- a/api/core/sandbox/initializer/dify_cli_initializer.py +++ b/api/core/sandbox/initializer/dify_cli_initializer.py @@ -6,7 +6,8 @@ from io import BytesIO from pathlib import Path from core.sandbox.sandbox import Sandbox -from core.session.cli_api import CliApiSessionManager +from core.session.cli_api import CliApiSessionManager, CliContext +from core.skill.entities import ToolAccessPolicy from core.skill.skill_manager import SkillManager from core.virtual_environment.__base.helpers import pipeline @@ -63,7 +64,11 @@ class DifyCliInitializer(AsyncSandboxInitializer): logger.info("No tools found in bundle for assets_id=%s", self._assets_id) return - self._cli_api_session = CliApiSessionManager().create(tenant_id=self._tenant_id, user_id=self._user_id) + self._cli_api_session = CliApiSessionManager().create( + tenant_id=self._tenant_id, + user_id=self._user_id, + context=CliContext(tool_access=ToolAccessPolicy.from_dependencies(bundle.get_tool_dependencies())), + ) pipeline(vm).add( ["mkdir", "-p", DifyCli.GLOBAL_TOOLS_PATH], error_message="Failed to create global tools dir" diff --git a/api/core/session/cli_api.py b/api/core/session/cli_api.py index c9fdcb451f..44d13c55ae 100644 --- a/api/core/session/cli_api.py +++ b/api/core/session/cli_api.py @@ -1,7 +1,8 @@ import secrets -from typing import Any -from pydantic import Field +from pydantic import BaseModel, Field + +from core.skill.entities import ToolAccessPolicy from .session import BaseSession, SessionManager @@ -10,11 +11,15 @@ class CliApiSession(BaseSession): secret: str = Field(default_factory=lambda: secrets.token_urlsafe(32)) +class CliContext(BaseModel): + tool_access: ToolAccessPolicy | None = Field(default=None, description="Tool access policy") + + class CliApiSessionManager(SessionManager[CliApiSession]): def __init__(self, ttl: int | None = None): super().__init__(key_prefix="cli_api_session", session_class=CliApiSession, ttl=ttl) - def create(self, tenant_id: str, user_id: str, context: dict[str, Any] | None = None) -> CliApiSession: - session = CliApiSession(tenant_id=tenant_id, user_id=user_id, context=context or {}) + def create(self, tenant_id: str, user_id: str, context: CliContext) -> CliApiSession: + session = CliApiSession(tenant_id=tenant_id, user_id=user_id, context=context.model_dump(mode="json")) self.save(session) return session diff --git a/api/core/skill/entities/__init__.py b/api/core/skill/entities/__init__.py index e953a97546..4d3613469c 100644 --- a/api/core/skill/entities/__init__.py +++ b/api/core/skill/entities/__init__.py @@ -9,6 +9,7 @@ from .skill_metadata import ( ToolFieldConfig, ToolReference, ) +from .tool_access_policy import ToolAccessPolicy, ToolInvocationRequest, ToolKey from .tool_dependencies import ToolDependencies, ToolDependency __all__ = [ @@ -19,9 +20,12 @@ __all__ = [ "SkillDocument", "SkillMetadata", "SourceInfo", + "ToolAccessPolicy", "ToolConfiguration", "ToolDependencies", "ToolDependency", "ToolFieldConfig", + "ToolInvocationRequest", + "ToolKey", "ToolReference", ] diff --git a/api/core/skill/entities/tool_access_policy.py b/api/core/skill/entities/tool_access_policy.py new file mode 100644 index 0000000000..dbe1df7c82 --- /dev/null +++ b/api/core/skill/entities/tool_access_policy.py @@ -0,0 +1,87 @@ +from pydantic import BaseModel, ConfigDict, Field + +from core.skill.entities.tool_dependencies import ToolDependencies +from core.tools.entities.tool_entities import ToolProviderType + + +class ToolKey(BaseModel): + """Immutable identifier for a tool (type + provider + name).""" + + model_config = ConfigDict(frozen=True) + + tool_type: ToolProviderType + provider: str + tool_name: str + + +class ToolInvocationRequest(BaseModel): + """A request to invoke a specific tool with optional credential.""" + + model_config = ConfigDict(frozen=True) + + tool_type: ToolProviderType + provider: str + tool_name: str + credential_id: str | None = None + + @property + def key(self) -> ToolKey: + return ToolKey(tool_type=self.tool_type, provider=self.provider, tool_name=self.tool_name) + + +class ToolAccessPolicy(BaseModel): + """ + Determines whether a tool invocation is allowed based on ToolDependencies. + + Rules: + 1. Tool must be declared in dependencies or references. + 2. If references exist for the tool, credential_id must match one of them. + 3. If no references exist for the tool, credential_id must be None. + """ + + model_config = ConfigDict(frozen=True) + + allowed_tools: frozenset[ToolKey] = Field(default_factory=frozenset) + credential_ids_by_tool: dict[ToolKey, frozenset[str | None]] = Field(default_factory=dict) + + @classmethod + def from_dependencies(cls, deps: ToolDependencies | None) -> "ToolAccessPolicy": + if deps is None or deps.is_empty(): + return cls() + + def to_key(t: ToolProviderType, p: str, n: str) -> ToolKey: + return ToolKey(tool_type=t, provider=p, tool_name=n) + + tools: set[ToolKey] = set() + tools.update(to_key(dep.type, dep.provider, dep.tool_name) for dep in deps.dependencies) + tools.update(to_key(ref.type, ref.provider, ref.tool_name) for ref in deps.references) + + creds: dict[ToolKey, set[str | None]] = {} + for ref in deps.references: + key = to_key(ref.type, ref.provider, ref.tool_name) + creds.setdefault(key, set()).add(ref.credential_id) + + return cls( + allowed_tools=frozenset(tools), + credential_ids_by_tool={k: frozenset(v) for k, v in creds.items()}, + ) + + def is_empty(self) -> bool: + return len(self.allowed_tools) == 0 + + def is_allowed(self, request: ToolInvocationRequest) -> bool: + """Check if the tool invocation request is allowed.""" + + # If the policy is empty, allow any invocation. + if self.is_empty(): + return True + + if request.key not in self.allowed_tools: + return False + + allowed_credentials = self.credential_ids_by_tool.get(request.key) + if not allowed_credentials: + # No references for this tool: only allow invocation without credential. + return request.credential_id is None + + return request.credential_id in allowed_credentials