diff --git a/api/controllers/console/snippets/payloads.py b/api/controllers/console/snippets/payloads.py index ec657f2cf1..73492db8ae 100644 --- a/api/controllers/console/snippets/payloads.py +++ b/api/controllers/console/snippets/payloads.py @@ -101,3 +101,20 @@ class PublishWorkflowPayload(BaseModel): """Payload for publishing snippet workflow.""" knowledge_base_setting: dict[str, Any] | None = None + + +class SnippetImportPayload(BaseModel): + """Payload for importing snippet from DSL.""" + + mode: str = Field(..., description="Import mode: yaml-content or yaml-url") + yaml_content: str | None = Field(default=None, description="YAML content (required for yaml-content mode)") + yaml_url: str | None = Field(default=None, description="YAML URL (required for yaml-url mode)") + name: str | None = Field(default=None, description="Override snippet name") + description: str | None = Field(default=None, description="Override snippet description") + snippet_id: str | None = Field(default=None, description="Snippet ID to update (optional)") + + +class IncludeSecretQuery(BaseModel): + """Query parameter for including secret variables in export.""" + + include_secret: str = Field(default="false", description="Whether to include secret variables") diff --git a/api/controllers/console/workspace/snippets.py b/api/controllers/console/workspace/snippets.py index 41aaa5c3b0..51b66d2e55 100644 --- a/api/controllers/console/workspace/snippets.py +++ b/api/controllers/console/workspace/snippets.py @@ -9,6 +9,8 @@ from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.snippets.payloads import ( CreateSnippetPayload, + IncludeSecretQuery, + SnippetImportPayload, SnippetListQuery, UpdateSnippetPayload, ) @@ -21,6 +23,8 @@ from extensions.ext_database import db from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields from libs.login import current_account_with_tenant, login_required from models.snippet import SnippetType +from services.app_dsl_service import ImportStatus +from services.snippet_dsl_service import SnippetDslService from services.snippet_service import SnippetService logger = logging.getLogger(__name__) @@ -31,6 +35,8 @@ register_schema_models( SnippetListQuery, CreateSnippetPayload, UpdateSnippetPayload, + SnippetImportPayload, + IncludeSecretQuery, ) # Create namespace models for marshaling @@ -201,3 +207,132 @@ class CustomizedSnippetDetailApi(Resource): session.commit() return "", 204 + + +@console_ns.route("/workspaces/current/customized-snippets//export") +class CustomizedSnippetExportApi(Resource): + @console_ns.doc("export_customized_snippet") + @console_ns.doc(description="Export snippet configuration as DSL") + @console_ns.doc(params={"snippet_id": "Snippet ID to export"}) + @console_ns.response(200, "Snippet exported successfully") + @console_ns.response(404, "Snippet not found") + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + def get(self, snippet_id: str): + """Export snippet as DSL.""" + _, current_tenant_id = current_account_with_tenant() + + snippet = SnippetService.get_snippet_by_id( + snippet_id=str(snippet_id), + tenant_id=current_tenant_id, + ) + + if not snippet: + raise NotFound("Snippet not found") + + # Get include_secret parameter + query = IncludeSecretQuery.model_validate(request.args.to_dict()) + + with Session(db.engine) as session: + export_service = SnippetDslService(session) + result = export_service.export_snippet_dsl( + snippet=snippet, include_secret=query.include_secret == "true" + ) + + return {"data": result}, 200 + + +@console_ns.route("/workspaces/current/customized-snippets/imports") +class CustomizedSnippetImportApi(Resource): + @console_ns.doc("import_customized_snippet") + @console_ns.doc(description="Import snippet from DSL") + @console_ns.expect(console_ns.models.get(SnippetImportPayload.__name__)) + @console_ns.response(200, "Snippet imported successfully") + @console_ns.response(202, "Import pending confirmation") + @console_ns.response(400, "Import failed") + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + def post(self): + """Import snippet from DSL.""" + current_user, _ = current_account_with_tenant() + payload = SnippetImportPayload.model_validate(console_ns.payload or {}) + + with Session(db.engine) as session: + import_service = SnippetDslService(session) + result = import_service.import_snippet( + account=current_user, + import_mode=payload.mode, + yaml_content=payload.yaml_content, + yaml_url=payload.yaml_url, + snippet_id=payload.snippet_id, + name=payload.name, + description=payload.description, + ) + session.commit() + + # Return appropriate status code based on result + status = result.status + if status == ImportStatus.FAILED: + return result.model_dump(mode="json"), 400 + elif status == ImportStatus.PENDING: + return result.model_dump(mode="json"), 202 + return result.model_dump(mode="json"), 200 + + +@console_ns.route("/workspaces/current/customized-snippets/imports//confirm") +class CustomizedSnippetImportConfirmApi(Resource): + @console_ns.doc("confirm_snippet_import") + @console_ns.doc(description="Confirm a pending snippet import") + @console_ns.doc(params={"import_id": "Import ID to confirm"}) + @console_ns.response(200, "Import confirmed successfully") + @console_ns.response(400, "Import failed") + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + def post(self, import_id: str): + """Confirm a pending snippet import.""" + current_user, _ = current_account_with_tenant() + + with Session(db.engine) as session: + import_service = SnippetDslService(session) + result = import_service.confirm_import(import_id=import_id, account=current_user) + session.commit() + + if result.status == ImportStatus.FAILED: + return result.model_dump(mode="json"), 400 + return result.model_dump(mode="json"), 200 + + +@console_ns.route("/workspaces/current/customized-snippets//check-dependencies") +class CustomizedSnippetCheckDependenciesApi(Resource): + @console_ns.doc("check_snippet_dependencies") + @console_ns.doc(description="Check dependencies for a snippet") + @console_ns.doc(params={"snippet_id": "Snippet ID"}) + @console_ns.response(200, "Dependencies checked successfully") + @console_ns.response(404, "Snippet not found") + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + def get(self, snippet_id: str): + """Check dependencies for a snippet.""" + _, current_tenant_id = current_account_with_tenant() + + snippet = SnippetService.get_snippet_by_id( + snippet_id=str(snippet_id), + tenant_id=current_tenant_id, + ) + + if not snippet: + raise NotFound("Snippet not found") + + with Session(db.engine) as session: + import_service = SnippetDslService(session) + result = import_service.check_dependencies(snippet=snippet) + + return result.model_dump(mode="json"), 200 diff --git a/api/services/snippet_dsl_service.py b/api/services/snippet_dsl_service.py new file mode 100644 index 0000000000..fb2a75c5d3 --- /dev/null +++ b/api/services/snippet_dsl_service.py @@ -0,0 +1,541 @@ +import json +import logging +import uuid +from collections.abc import Mapping +from datetime import UTC, datetime +from enum import StrEnum +from urllib.parse import urlparse + +import yaml # type: ignore +from packaging import version +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.helper import ssrf_proxy +from core.model_runtime.utils.encoders import jsonable_encoder +from core.plugin.entities.plugin import PluginDependency +from core.workflow.enums import NodeType +from extensions.ext_redis import redis_client +from factories import variable_factory +from models import Account +from models.snippet import CustomizedSnippet, SnippetType +from models.workflow import Workflow +from services.plugin.dependencies_analysis import DependenciesAnalysisService +from services.snippet_service import SnippetService + +logger = logging.getLogger(__name__) + +IMPORT_INFO_REDIS_KEY_PREFIX = "snippet_import_info:" +CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "snippet_check_dependencies:" +IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes +DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB +CURRENT_DSL_VERSION = "0.1.0" + + +class ImportMode(StrEnum): + YAML_CONTENT = "yaml-content" + YAML_URL = "yaml-url" + + +class ImportStatus(StrEnum): + COMPLETED = "completed" + COMPLETED_WITH_WARNINGS = "completed-with-warnings" + PENDING = "pending" + FAILED = "failed" + + +class SnippetImportInfo(BaseModel): + id: str + status: ImportStatus + snippet_id: str | None = None + current_dsl_version: str = CURRENT_DSL_VERSION + imported_dsl_version: str = "" + error: str = "" + + +class CheckDependenciesResult(BaseModel): + leaked_dependencies: list[PluginDependency] = Field(default_factory=list) + + +def _check_version_compatibility(imported_version: str) -> ImportStatus: + """Determine import status based on version comparison""" + try: + current_ver = version.parse(CURRENT_DSL_VERSION) + imported_ver = version.parse(imported_version) + except version.InvalidVersion: + return ImportStatus.FAILED + + # If imported version is newer than current, always return PENDING + if imported_ver > current_ver: + return ImportStatus.PENDING + + # If imported version is older than current's major, return PENDING + if imported_ver.major < current_ver.major: + return ImportStatus.PENDING + + # If imported version is older than current's minor, return COMPLETED_WITH_WARNINGS + if imported_ver.minor < current_ver.minor: + return ImportStatus.COMPLETED_WITH_WARNINGS + + # If imported version equals or is older than current's micro, return COMPLETED + return ImportStatus.COMPLETED + + +class SnippetPendingData(BaseModel): + import_mode: str + yaml_content: str + snippet_id: str | None + + +class CheckDependenciesPendingData(BaseModel): + dependencies: list[PluginDependency] + snippet_id: str | None + + +class SnippetDslService: + def __init__(self, session: Session): + self._session = session + + def import_snippet( + self, + *, + account: Account, + import_mode: str, + yaml_content: str | None = None, + yaml_url: str | None = None, + snippet_id: str | None = None, + name: str | None = None, + description: str | None = None, + ) -> SnippetImportInfo: + """Import a snippet from YAML content or URL.""" + import_id = str(uuid.uuid4()) + + # Validate import mode + try: + mode = ImportMode(import_mode) + except ValueError: + raise ValueError(f"Invalid import_mode: {import_mode}") + + # Get YAML content + content: str = "" + if mode == ImportMode.YAML_URL: + if not yaml_url: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="yaml_url is required when import_mode is yaml-url", + ) + try: + parsed_url = urlparse(yaml_url) + if parsed_url.scheme not in ["http", "https"]: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Invalid URL scheme, only http and https are allowed", + ) + response = ssrf_proxy.get(yaml_url, timeout=(10, 30)) + if response.status_code != 200: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=f"Failed to fetch YAML from URL: {response.status_code}", + ) + content = response.text + if len(content) > DSL_MAX_SIZE: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=f"YAML content size exceeds maximum limit of {DSL_MAX_SIZE} bytes", + ) + except Exception as e: + logger.exception("Failed to fetch YAML from URL") + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=f"Failed to fetch YAML from URL: {str(e)}", + ) + elif mode == ImportMode.YAML_CONTENT: + if not yaml_content: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="yaml_content is required when import_mode is yaml-content", + ) + content = yaml_content + if len(content) > DSL_MAX_SIZE: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=f"YAML content size exceeds maximum limit of {DSL_MAX_SIZE} bytes", + ) + + try: + # Parse YAML + data = yaml.safe_load(content) + if not isinstance(data, dict): + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Invalid YAML format: expected a dictionary", + ) + + # Validate and fix DSL version + if not data.get("version"): + data["version"] = "0.1.0" + + # Strictly validate kind field + kind = data.get("kind") + if not kind: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Missing 'kind' field in DSL. Expected 'kind: snippet'.", + ) + if kind != "snippet": + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=f"Invalid DSL kind: expected 'snippet', got '{kind}'. This DSL is for {kind}, not snippet.", + ) + + imported_version = data.get("version", "0.1.0") + if not isinstance(imported_version, str): + raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}") + status = _check_version_compatibility(imported_version) + + # Extract snippet data + snippet_data = data.get("snippet") + if not snippet_data: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Missing snippet data in YAML content", + ) + + # If snippet_id is provided, check if it exists + snippet = None + if snippet_id: + stmt = select(CustomizedSnippet).where( + CustomizedSnippet.id == snippet_id, + CustomizedSnippet.tenant_id == account.current_tenant_id, + ) + snippet = self._session.scalar(stmt) + + if not snippet: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Snippet not found", + ) + + # If major version mismatch, store import info in Redis + if status == ImportStatus.PENDING: + pending_data = SnippetPendingData( + import_mode=import_mode, + yaml_content=content, + snippet_id=snippet_id, + ) + redis_client.setex( + f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}", + IMPORT_INFO_REDIS_EXPIRY, + pending_data.model_dump_json(), + ) + + return SnippetImportInfo( + id=import_id, + status=status, + snippet_id=snippet_id, + imported_dsl_version=imported_version, + ) + + # Extract dependencies + dependencies = data.get("dependencies", []) + check_dependencies_pending_data = None + if dependencies: + check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies] + + # Create or update snippet + snippet = self._create_or_update_snippet( + snippet=snippet, + data=data, + account=account, + name=name, + description=description, + dependencies=check_dependencies_pending_data, + ) + + return SnippetImportInfo( + id=import_id, + status=status, + snippet_id=snippet.id, + imported_dsl_version=imported_version, + ) + + except yaml.YAMLError as e: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=f"Invalid YAML format: {str(e)}", + ) + + except Exception as e: + logger.exception("Failed to import snippet") + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=str(e), + ) + + def confirm_import(self, *, import_id: str, account: Account) -> SnippetImportInfo: + """ + Confirm an import that requires confirmation + """ + redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" + pending_data = redis_client.get(redis_key) + + if not pending_data: + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Import information expired or does not exist", + ) + + try: + if not isinstance(pending_data, str | bytes): + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error="Invalid import information", + ) + + pending_data_str = pending_data.decode("utf-8") if isinstance(pending_data, bytes) else pending_data + pending = SnippetPendingData.model_validate_json(pending_data_str) + + # Re-import with the pending data + return self.import_snippet( + account=account, + import_mode=pending.import_mode, + yaml_content=pending.yaml_content, + snippet_id=pending.snippet_id, + ) + + except Exception as e: + logger.exception("Failed to confirm import") + return SnippetImportInfo( + id=import_id, + status=ImportStatus.FAILED, + error=str(e), + ) + + def check_dependencies(self, snippet: CustomizedSnippet) -> CheckDependenciesResult: + """ + Check dependencies for a snippet + """ + snippet_service = SnippetService() + workflow = snippet_service.get_draft_workflow(snippet=snippet) + if not workflow: + return CheckDependenciesResult(leaked_dependencies=[]) + + dependencies = self._extract_dependencies_from_workflow(workflow) + leaked_dependencies = DependenciesAnalysisService.generate_dependencies( + tenant_id=snippet.tenant_id, dependencies=dependencies + ) + + return CheckDependenciesResult(leaked_dependencies=leaked_dependencies) + + def _create_or_update_snippet( + self, + *, + snippet: CustomizedSnippet | None, + data: dict, + account: Account, + name: str | None = None, + description: str | None = None, + dependencies: list[PluginDependency] | None = None, + ) -> CustomizedSnippet: + """ + Create or update snippet from DSL data + """ + snippet_data = data.get("snippet", {}) + workflow_data = data.get("workflow", {}) + + # Extract snippet info + snippet_name = name or snippet_data.get("name") or "Untitled Snippet" + snippet_description = description or snippet_data.get("description") or "" + snippet_type_str = snippet_data.get("type", "node") + try: + snippet_type = SnippetType(snippet_type_str) + except ValueError: + snippet_type = SnippetType.NODE + + icon_info = snippet_data.get("icon_info", {}) + input_fields = snippet_data.get("input_fields", []) + + # Create or update snippet + if snippet: + # Update existing snippet + snippet.name = snippet_name + snippet.description = snippet_description + snippet.type = snippet_type.value + snippet.icon_info = icon_info or None + snippet.input_fields = json.dumps(input_fields) if input_fields else None + snippet.updated_by = account.id + snippet.updated_at = datetime.now(UTC).replace(tzinfo=None) + else: + # Create new snippet + snippet = CustomizedSnippet( + tenant_id=account.current_tenant_id, + name=snippet_name, + description=snippet_description, + type=snippet_type.value, + icon_info=icon_info or None, + input_fields=json.dumps(input_fields) if input_fields else None, + created_by=account.id, + ) + self._session.add(snippet) + self._session.flush() + + # Create or update draft workflow + if workflow_data: + graph = workflow_data.get("graph", {}) + environment_variables_list = workflow_data.get("environment_variables", []) + conversation_variables_list = workflow_data.get("conversation_variables", []) + + environment_variables = [ + variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list + ] + conversation_variables = [ + variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list + ] + + snippet_service = SnippetService() + # Get existing workflow hash if exists + existing_workflow = snippet_service.get_draft_workflow(snippet=snippet) + unique_hash = existing_workflow.unique_hash if existing_workflow else None + + snippet_service.sync_draft_workflow( + snippet=snippet, + graph=graph, + unique_hash=unique_hash, + account=account, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + input_variables=input_fields, + ) + + self._session.commit() + return snippet + + def export_snippet_dsl(self, snippet: CustomizedSnippet, include_secret: bool = False) -> str: + """ + Export snippet as DSL + :param snippet: CustomizedSnippet instance + :param include_secret: Whether include secret variable + :return: YAML string + """ + snippet_service = SnippetService() + workflow = snippet_service.get_draft_workflow(snippet=snippet) + if not workflow: + raise ValueError("Missing draft workflow configuration, please check.") + + icon_info = snippet.icon_info or {} + export_data = { + "version": CURRENT_DSL_VERSION, + "kind": "snippet", + "snippet": { + "name": snippet.name, + "description": snippet.description or "", + "type": snippet.type, + "icon_info": icon_info, + "input_fields": snippet.input_fields_list, + }, + } + + self._append_workflow_export_data(export_data=export_data, snippet=snippet, workflow=workflow, + include_secret=include_secret) + + return yaml.dump(export_data, allow_unicode=True) # type: ignore + + def _append_workflow_export_data( + self, *, export_data: dict, snippet: CustomizedSnippet, workflow: Workflow, include_secret: bool + ) -> None: + """ + Append workflow export data + """ + workflow_dict = workflow.to_dict(include_secret=include_secret) + # Filter workspace related data from nodes + for node in workflow_dict.get("graph", {}).get("nodes", []): + node_data = node.get("data", {}) + if not node_data: + continue + data_type = node_data.get("type", "") + if data_type == NodeType.KNOWLEDGE_RETRIEVAL: + dataset_ids = node_data.get("dataset_ids", []) + node["data"]["dataset_ids"] = [ + self._encrypt_dataset_id(dataset_id=dataset_id, tenant_id=snippet.tenant_id) + for dataset_id in dataset_ids + ] + # filter credential id from tool node + if not include_secret and data_type == NodeType.TOOL: + node_data.pop("credential_id", None) + # filter credential id from agent node + if not include_secret and data_type == NodeType.AGENT: + for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []): + tool.pop("credential_id", None) + + export_data["workflow"] = workflow_dict + dependencies = self._extract_dependencies_from_workflow(workflow) + export_data["dependencies"] = [ + jsonable_encoder(d.model_dump()) + for d in DependenciesAnalysisService.generate_dependencies( + tenant_id=snippet.tenant_id, dependencies=dependencies + ) + ] + + def _encrypt_dataset_id(self, *, dataset_id: str, tenant_id: str) -> str: + """ + Encrypt dataset ID for export + """ + # For now, just return the dataset_id as-is + # In the future, we might want to encrypt it + return dataset_id + + def _extract_dependencies_from_workflow(self, workflow: Workflow) -> list[str]: + """ + Extract dependencies from workflow + :param workflow: Workflow instance + :return: dependencies list format like ["langgenius/google"] + """ + graph = workflow.graph_dict + dependencies = self._extract_dependencies_from_workflow_graph(graph) + return dependencies + + def _extract_dependencies_from_workflow_graph(self, graph: Mapping) -> list[str]: + """ + Extract dependencies from workflow graph + :param graph: Workflow graph + :return: dependencies list format like ["langgenius/google"] + """ + dependencies = [] + for node in graph.get("nodes", []): + node_data = node.get("data", {}) + if not node_data: + continue + data_type = node_data.get("type", "") + if data_type == NodeType.TOOL: + tool_config = node_data.get("tool_configurations", {}) + provider_type = tool_config.get("provider_type") + provider_name = tool_config.get("provider") + if provider_type and provider_name: + dependencies.append(f"{provider_name}/{provider_name}") + elif data_type == NodeType.AGENT: + agent_parameters = node_data.get("agent_parameters", {}) + tools = agent_parameters.get("tools", {}).get("value", []) + for tool in tools: + provider_type = tool.get("provider_type") + provider_name = tool.get("provider") + if provider_type and provider_name: + dependencies.append(f"{provider_name}/{provider_name}") + + return dependencies