diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index fdc9aabc83..a244cbc299 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -115,6 +115,12 @@ from .explore import ( trial, ) +# Import evaluation controllers +from .evaluation import evaluation + +# Import snippet controllers +from .snippets import snippet_workflow + # Import tag controllers from .tag import tags @@ -128,6 +134,7 @@ from .workspace import ( model_providers, models, plugin, + snippets, tool_providers, trigger_providers, workspace, @@ -165,6 +172,7 @@ __all__ = [ "datasource_content_preview", "email_register", "endpoint", + "evaluation", "extension", "external", "feature", @@ -197,6 +205,8 @@ __all__ = [ "saved_message", "setup", "site", + "snippet_workflow", + "snippets", "spec", "statistic", "tags", diff --git a/api/controllers/console/evaluation/__init__.py b/api/controllers/console/evaluation/__init__.py new file mode 100644 index 0000000000..65c6eacd84 --- /dev/null +++ b/api/controllers/console/evaluation/__init__.py @@ -0,0 +1 @@ +# Evaluation controller module diff --git a/api/controllers/console/evaluation/evaluation.py b/api/controllers/console/evaluation/evaluation.py new file mode 100644 index 0000000000..d882afaf7c --- /dev/null +++ b/api/controllers/console/evaluation/evaluation.py @@ -0,0 +1,288 @@ +import logging +from collections.abc import Callable +from functools import wraps +from typing import ParamSpec, TypeVar, Union + +from flask import request +from flask_restx import Resource, fields +from pydantic import BaseModel, Field +from werkzeug.exceptions import NotFound + +from controllers.common.schema import register_schema_models +from controllers.console import console_ns +from controllers.console.wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, +) +from extensions.ext_database import db +from libs.helper import TimestampField +from libs.login import current_account_with_tenant, login_required +from models import App +from models.snippet import CustomizedSnippet + +logger = logging.getLogger(__name__) + +P = ParamSpec("P") +R = TypeVar("R") + +# Valid evaluation target types +EVALUATE_TARGET_TYPES = {"app", "snippets"} + + +class VersionQuery(BaseModel): + """Query parameters for version endpoint.""" + + version: str + + +register_schema_models( + console_ns, + VersionQuery, +) + + +# Response field definitions +file_info_fields = { + "id": fields.String, + "name": fields.String, +} + +evaluation_log_fields = { + "created_at": TimestampField, + "created_by": fields.String, + "test_file": fields.Nested( + console_ns.model( + "EvaluationTestFile", + file_info_fields, + ) + ), + "result_file": fields.Nested( + console_ns.model( + "EvaluationResultFile", + file_info_fields, + ), + allow_null=True, + ), + "version": fields.String, +} + +evaluation_log_list_model = console_ns.model( + "EvaluationLogList", + { + "data": fields.List(fields.Nested(console_ns.model("EvaluationLog", evaluation_log_fields))), + }, +) + +customized_matrix_fields = { + "evaluation_workflow_id": fields.String, + "input_fields": fields.Raw, + "output_fields": fields.Raw, +} + +condition_fields = { + "name": fields.List(fields.String), + "comparison_operator": fields.String, + "value": fields.String, +} + +judgement_conditions_fields = { + "logical_operator": fields.String, + "conditions": fields.List(fields.Nested(console_ns.model("EvaluationCondition", condition_fields))), +} + +evaluation_detail_fields = { + "evaluation_model": fields.String, + "evaluation_model_provider": fields.String, + "customized_matrix": fields.Nested( + console_ns.model("EvaluationCustomizedMatrix", customized_matrix_fields), + allow_null=True, + ), + "judgement_conditions": fields.Nested( + console_ns.model("EvaluationJudgementConditions", judgement_conditions_fields), + allow_null=True, + ), +} + +evaluation_detail_model = console_ns.model("EvaluationDetail", evaluation_detail_fields) + + +def get_evaluation_target(view_func: Callable[P, R]): + """ + Decorator to resolve polymorphic evaluation target (app or snippet). + + Validates the target_type parameter and fetches the corresponding + model (App or CustomizedSnippet) with tenant isolation. + """ + + @wraps(view_func) + def decorated_view(*args: P.args, **kwargs: P.kwargs): + target_type = kwargs.get("evaluate_target_type") + target_id = kwargs.get("evaluate_target_id") + + if target_type not in EVALUATE_TARGET_TYPES: + raise NotFound(f"Invalid evaluation target type: {target_type}") + + _, current_tenant_id = current_account_with_tenant() + + target_id = str(target_id) + + # Remove path parameters + del kwargs["evaluate_target_type"] + del kwargs["evaluate_target_id"] + + target: Union[App, CustomizedSnippet] | None = None + + if target_type == "app": + target = ( + db.session.query(App).where(App.id == target_id, App.tenant_id == current_tenant_id).first() + ) + elif target_type == "snippets": + target = ( + db.session.query(CustomizedSnippet) + .where(CustomizedSnippet.id == target_id, CustomizedSnippet.tenant_id == current_tenant_id) + .first() + ) + + if not target: + raise NotFound(f"{str(target_type)} not found") + + kwargs["target"] = target + kwargs["target_type"] = target_type + + return view_func(*args, **kwargs) + + return decorated_view + + +@console_ns.route("///dataset-template/download") +class EvaluationDatasetTemplateDownloadApi(Resource): + @console_ns.doc("download_evaluation_dataset_template") + @console_ns.response(200, "Template download URL generated successfully") + @console_ns.response(404, "Target not found") + @setup_required + @login_required + @account_initialization_required + @get_evaluation_target + @edit_permission_required + def post(self, target: Union[App, CustomizedSnippet], target_type: str): + """ + Download evaluation dataset template. + + Generates a download URL for the evaluation dataset template + based on the target type (app or snippets). + """ + # TODO: Implement actual template generation logic + # This is a placeholder implementation + return { + "download_url": f"/api/evaluation/{target_type}/{target.id}/template.csv", + } + + +@console_ns.route("///evaluation") +class EvaluationDetailApi(Resource): + @console_ns.doc("get_evaluation_detail") + @console_ns.response(200, "Evaluation details retrieved successfully", evaluation_detail_model) + @console_ns.response(404, "Target not found") + @setup_required + @login_required + @account_initialization_required + @get_evaluation_target + def get(self, target: Union[App, CustomizedSnippet], target_type: str): + """ + Get evaluation details for the target. + + Returns evaluation configuration including model settings, + customized matrix, and judgement conditions. + """ + # TODO: Implement actual evaluation detail retrieval + # This is a placeholder implementation + return { + "evaluation_model": None, + "evaluation_model_provider": None, + "customized_matrix": None, + "judgement_conditions": None, + } + + +@console_ns.route("///evaluation/logs") +class EvaluationLogsApi(Resource): + @console_ns.doc("get_evaluation_logs") + @console_ns.response(200, "Evaluation logs retrieved successfully", evaluation_log_list_model) + @console_ns.response(404, "Target not found") + @setup_required + @login_required + @account_initialization_required + @get_evaluation_target + def get(self, target: Union[App, CustomizedSnippet], target_type: str): + """ + Get offline evaluation logs for the target. + + Returns a list of evaluation runs with test files, + result files, and version information. + """ + # TODO: Implement actual evaluation logs retrieval + # This is a placeholder implementation + return { + "data": [], + } + + +@console_ns.route("///evaluation/files/") +class EvaluationFileDownloadApi(Resource): + @console_ns.doc("download_evaluation_file") + @console_ns.response(200, "File download URL generated successfully") + @console_ns.response(404, "Target or file not found") + @setup_required + @login_required + @account_initialization_required + @get_evaluation_target + def get(self, target: Union[App, CustomizedSnippet], target_type: str, file_id: str): + """ + Download evaluation test file or result file. + + Returns file information and download URL for the specified file. + """ + file_id = str(file_id) + + # TODO: Implement actual file download logic + # This is a placeholder implementation + return { + "created_at": None, + "created_by": None, + "test_file": None, + "result_file": None, + "version": None, + } + + +@console_ns.route("///evaluation/version") +class EvaluationVersionApi(Resource): + @console_ns.doc("get_evaluation_version_detail") + @console_ns.expect(console_ns.models.get(VersionQuery.__name__)) + @console_ns.response(200, "Version details retrieved successfully") + @console_ns.response(404, "Target or version not found") + @setup_required + @login_required + @account_initialization_required + @get_evaluation_target + def get(self, target: Union[App, CustomizedSnippet], target_type: str): + """ + Get evaluation target version details. + + Returns the workflow graph for the specified version. + """ + version = request.args.get("version") + + if not version: + return {"message": "version parameter is required"}, 400 + + # TODO: Implement actual version detail retrieval + # For now, return the current graph if available + graph = {} + if target_type == "snippets" and isinstance(target, CustomizedSnippet): + graph = target.graph_dict + + return { + "graph": graph, + } diff --git a/api/controllers/console/snippets/payloads.py b/api/controllers/console/snippets/payloads.py new file mode 100644 index 0000000000..06bd4d00fe --- /dev/null +++ b/api/controllers/console/snippets/payloads.py @@ -0,0 +1,75 @@ +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +class SnippetListQuery(BaseModel): + """Query parameters for listing snippets.""" + + page: int = Field(default=1, ge=1, le=99999) + limit: int = Field(default=20, ge=1, le=100) + keyword: str | None = None + + +class IconInfo(BaseModel): + """Icon information model.""" + + icon: str | None = None + icon_type: Literal["emoji", "image"] | None = None + icon_background: str | None = None + icon_url: str | None = None + + +class InputFieldDefinition(BaseModel): + """Input field definition for snippet parameters.""" + + default: str | None = None + hint: bool | None = None + label: str | None = None + max_length: int | None = None + options: list[str] | None = None + placeholder: str | None = None + required: bool | None = None + type: str | None = None # e.g., "text-input" + + +class CreateSnippetPayload(BaseModel): + """Payload for creating a new snippet.""" + + name: str = Field(..., min_length=1, max_length=255) + description: str | None = Field(default=None, max_length=2000) + type: Literal["node", "group"] = "node" + icon_info: IconInfo | None = None + graph: dict[str, Any] | None = None + input_fields: list[InputFieldDefinition] | None = Field(default_factory=list) + + +class UpdateSnippetPayload(BaseModel): + """Payload for updating a snippet.""" + + name: str | None = Field(default=None, min_length=1, max_length=255) + description: str | None = Field(default=None, max_length=2000) + icon_info: IconInfo | None = None + + +class SnippetDraftSyncPayload(BaseModel): + """Payload for syncing snippet draft workflow.""" + + graph: dict[str, Any] + hash: str | None = None + environment_variables: list[dict[str, Any]] | None = None + conversation_variables: list[dict[str, Any]] | None = None + input_variables: list[dict[str, Any]] | None = None + + +class WorkflowRunQuery(BaseModel): + """Query parameters for workflow runs.""" + + last_id: str | None = None + limit: int = Field(default=20, ge=1, le=100) + + +class PublishWorkflowPayload(BaseModel): + """Payload for publishing snippet workflow.""" + + knowledge_base_setting: dict[str, Any] | None = None diff --git a/api/controllers/console/snippets/snippet_workflow.py b/api/controllers/console/snippets/snippet_workflow.py new file mode 100644 index 0000000000..eef04bd740 --- /dev/null +++ b/api/controllers/console/snippets/snippet_workflow.py @@ -0,0 +1,306 @@ +import logging +from collections.abc import Callable +from functools import wraps +from typing import ParamSpec, TypeVar + +from flask import request +from flask_restx import Resource, marshal_with +from sqlalchemy.orm import Session +from werkzeug.exceptions import NotFound + +from controllers.common.schema import register_schema_models +from controllers.console import console_ns +from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync +from controllers.console.app.workflow import workflow_model +from controllers.console.app.workflow_run import ( + workflow_run_detail_model, + workflow_run_node_execution_list_model, + workflow_run_pagination_model, +) +from controllers.console.snippets.payloads import ( + PublishWorkflowPayload, + SnippetDraftSyncPayload, + WorkflowRunQuery, +) +from controllers.console.wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, +) +from extensions.ext_database import db +from factories import variable_factory +from libs.helper import TimestampField +from libs.login import current_account_with_tenant, login_required +from models.snippet import CustomizedSnippet +from services.errors.app import WorkflowHashNotEqualError +from services.snippet_service import SnippetService + +logger = logging.getLogger(__name__) + +P = ParamSpec("P") +R = TypeVar("R") + +# Register Pydantic models with Swagger +register_schema_models( + console_ns, + SnippetDraftSyncPayload, + WorkflowRunQuery, + PublishWorkflowPayload, +) + + +class SnippetNotFoundError(Exception): + """Snippet not found error.""" + + pass + + +def get_snippet(view_func: Callable[P, R]): + """Decorator to fetch and validate snippet access.""" + + @wraps(view_func) + def decorated_view(*args: P.args, **kwargs: P.kwargs): + if not kwargs.get("snippet_id"): + raise ValueError("missing snippet_id in path parameters") + + _, current_tenant_id = current_account_with_tenant() + + snippet_id = str(kwargs.get("snippet_id")) + del kwargs["snippet_id"] + + snippet = SnippetService.get_snippet_by_id( + snippet_id=snippet_id, + tenant_id=current_tenant_id, + ) + + if not snippet: + raise NotFound("Snippet not found") + + kwargs["snippet"] = snippet + + return view_func(*args, **kwargs) + + return decorated_view + + +@console_ns.route("/snippets//workflows/draft") +class SnippetDraftWorkflowApi(Resource): + @console_ns.doc("get_snippet_draft_workflow") + @console_ns.response(200, "Draft workflow retrieved successfully", workflow_model) + @console_ns.response(404, "Snippet or draft workflow not found") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @edit_permission_required + @marshal_with(workflow_model) + def get(self, snippet: CustomizedSnippet): + """Get draft workflow for snippet.""" + snippet_service = SnippetService() + workflow = snippet_service.get_draft_workflow(snippet=snippet) + + if not workflow: + raise DraftWorkflowNotExist() + + return workflow + + @console_ns.doc("sync_snippet_draft_workflow") + @console_ns.expect(console_ns.models.get(SnippetDraftSyncPayload.__name__)) + @console_ns.response(200, "Draft workflow synced successfully") + @console_ns.response(400, "Hash mismatch") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @edit_permission_required + def post(self, snippet: CustomizedSnippet): + """Sync draft workflow for snippet.""" + current_user, _ = current_account_with_tenant() + + payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {}) + + try: + environment_variables_list = payload.environment_variables or [] + environment_variables = [ + variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list + ] + conversation_variables_list = payload.conversation_variables or [] + conversation_variables = [ + variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list + ] + snippet_service = SnippetService() + workflow = snippet_service.sync_draft_workflow( + snippet=snippet, + graph=payload.graph, + unique_hash=payload.hash, + account=current_user, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + input_variables=payload.input_variables, + ) + except WorkflowHashNotEqualError: + raise DraftWorkflowNotSync() + + return { + "result": "success", + "hash": workflow.unique_hash, + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), + } + + +@console_ns.route("/snippets//workflows/draft/config") +class SnippetDraftConfigApi(Resource): + @console_ns.doc("get_snippet_draft_config") + @console_ns.response(200, "Draft config retrieved successfully") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @edit_permission_required + def get(self, snippet: CustomizedSnippet): + """Get snippet draft workflow configuration limits.""" + return { + "parallel_depth_limit": 3, + } + + +@console_ns.route("/snippets//workflows/publish") +class SnippetPublishedWorkflowApi(Resource): + @console_ns.doc("get_snippet_published_workflow") + @console_ns.response(200, "Published workflow retrieved successfully", workflow_model) + @console_ns.response(404, "Snippet not found") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @edit_permission_required + @marshal_with(workflow_model) + def get(self, snippet: CustomizedSnippet): + """Get published workflow for snippet.""" + if not snippet.is_published: + return None + + snippet_service = SnippetService() + workflow = snippet_service.get_published_workflow(snippet=snippet) + + return workflow + + @console_ns.doc("publish_snippet_workflow") + @console_ns.expect(console_ns.models.get(PublishWorkflowPayload.__name__)) + @console_ns.response(200, "Workflow published successfully") + @console_ns.response(400, "No draft workflow found") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @edit_permission_required + def post(self, snippet: CustomizedSnippet): + """Publish snippet workflow.""" + current_user, _ = current_account_with_tenant() + snippet_service = SnippetService() + + with Session(db.engine) as session: + snippet = session.merge(snippet) + try: + workflow = snippet_service.publish_workflow( + session=session, + snippet=snippet, + account=current_user, + ) + workflow_created_at = TimestampField().format(workflow.created_at) + session.commit() + except ValueError as e: + return {"message": str(e)}, 400 + + return { + "result": "success", + "created_at": workflow_created_at, + } + + +@console_ns.route("/snippets//workflows/default-workflow-block-configs") +class SnippetDefaultBlockConfigsApi(Resource): + @console_ns.doc("get_snippet_default_block_configs") + @console_ns.response(200, "Default block configs retrieved successfully") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @edit_permission_required + def get(self, snippet: CustomizedSnippet): + """Get default block configurations for snippet workflow.""" + snippet_service = SnippetService() + return snippet_service.get_default_block_configs() + + +@console_ns.route("/snippets//workflow-runs") +class SnippetWorkflowRunsApi(Resource): + @console_ns.doc("list_snippet_workflow_runs") + @console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model) + @setup_required + @login_required + @account_initialization_required + @get_snippet + @marshal_with(workflow_run_pagination_model) + def get(self, snippet: CustomizedSnippet): + """List workflow runs for snippet.""" + query = WorkflowRunQuery.model_validate( + { + "last_id": request.args.get("last_id"), + "limit": request.args.get("limit", type=int, default=20), + } + ) + args = { + "last_id": query.last_id, + "limit": query.limit, + } + + snippet_service = SnippetService() + result = snippet_service.get_snippet_workflow_runs(snippet=snippet, args=args) + + return result + + +@console_ns.route("/snippets//workflow-runs/") +class SnippetWorkflowRunDetailApi(Resource): + @console_ns.doc("get_snippet_workflow_run_detail") + @console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model) + @console_ns.response(404, "Workflow run not found") + @setup_required + @login_required + @account_initialization_required + @get_snippet + @marshal_with(workflow_run_detail_model) + def get(self, snippet: CustomizedSnippet, run_id): + """Get workflow run detail for snippet.""" + run_id = str(run_id) + + snippet_service = SnippetService() + workflow_run = snippet_service.get_snippet_workflow_run(snippet=snippet, run_id=run_id) + + if not workflow_run: + raise NotFound("Workflow run not found") + + return workflow_run + + +@console_ns.route("/snippets//workflow-runs//node-executions") +class SnippetWorkflowRunNodeExecutionsApi(Resource): + @console_ns.doc("list_snippet_workflow_run_node_executions") + @console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model) + @setup_required + @login_required + @account_initialization_required + @get_snippet + @marshal_with(workflow_run_node_execution_list_model) + def get(self, snippet: CustomizedSnippet, run_id): + """List node executions for a workflow run.""" + run_id = str(run_id) + + snippet_service = SnippetService() + node_executions = snippet_service.get_snippet_workflow_run_node_executions( + snippet=snippet, + run_id=run_id, + ) + + return {"data": node_executions} diff --git a/api/controllers/console/workspace/snippets.py b/api/controllers/console/workspace/snippets.py new file mode 100644 index 0000000000..6a7992f889 --- /dev/null +++ b/api/controllers/console/workspace/snippets.py @@ -0,0 +1,202 @@ +import logging + +from flask import request +from flask_restx import Resource, marshal, marshal_with +from sqlalchemy.orm import Session +from werkzeug.exceptions import NotFound + +from controllers.common.schema import register_schema_models +from controllers.console import console_ns +from controllers.console.snippets.payloads import ( + CreateSnippetPayload, + SnippetListQuery, + UpdateSnippetPayload, +) +from controllers.console.wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, +) +from extensions.ext_database import db +from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields +from libs.login import current_account_with_tenant, login_required +from models.snippet import SnippetType +from services.snippet_service import SnippetService + +logger = logging.getLogger(__name__) + +# Register Pydantic models with Swagger +register_schema_models( + console_ns, + SnippetListQuery, + CreateSnippetPayload, + UpdateSnippetPayload, +) + +# Create namespace models for marshaling +snippet_model = console_ns.model("Snippet", snippet_fields) +snippet_list_model = console_ns.model("SnippetList", snippet_list_fields) +snippet_pagination_model = console_ns.model("SnippetPagination", snippet_pagination_fields) + + +@console_ns.route("/workspaces/current/customized-snippets") +class CustomizedSnippetsApi(Resource): + @console_ns.doc("list_customized_snippets") + @console_ns.expect(console_ns.models.get(SnippetListQuery.__name__)) + @console_ns.response(200, "Snippets retrieved successfully", snippet_pagination_model) + @setup_required + @login_required + @account_initialization_required + def get(self): + """List customized snippets with pagination and search.""" + _, current_tenant_id = current_account_with_tenant() + + query_params = request.args.to_dict() + query = SnippetListQuery.model_validate(query_params) + + snippets, total, has_more = SnippetService.get_snippets( + tenant_id=current_tenant_id, + page=query.page, + limit=query.limit, + keyword=query.keyword, + ) + + return { + "data": marshal(snippets, snippet_list_fields), + "page": query.page, + "limit": query.limit, + "total": total, + "has_more": has_more, + }, 200 + + @console_ns.doc("create_customized_snippet") + @console_ns.expect(console_ns.models.get(CreateSnippetPayload.__name__)) + @console_ns.response(201, "Snippet created successfully", snippet_model) + @console_ns.response(400, "Invalid request or name already exists") + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + def post(self): + """Create a new customized snippet.""" + current_user, current_tenant_id = current_account_with_tenant() + + payload = CreateSnippetPayload.model_validate(console_ns.payload or {}) + + try: + snippet_type = SnippetType(payload.type) + except ValueError: + snippet_type = SnippetType.NODE + + try: + snippet = SnippetService.create_snippet( + tenant_id=current_tenant_id, + name=payload.name, + description=payload.description, + snippet_type=snippet_type, + icon_info=payload.icon_info.model_dump() if payload.icon_info else None, + graph=payload.graph, + input_fields=[f.model_dump() for f in payload.input_fields] if payload.input_fields else None, + account=current_user, + ) + except ValueError as e: + return {"message": str(e)}, 400 + + return marshal(snippet, snippet_fields), 201 + + +@console_ns.route("/workspaces/current/customized-snippets/") +class CustomizedSnippetDetailApi(Resource): + @console_ns.doc("get_customized_snippet") + @console_ns.response(200, "Snippet retrieved successfully", snippet_model) + @console_ns.response(404, "Snippet not found") + @setup_required + @login_required + @account_initialization_required + def get(self, snippet_id: str): + """Get customized snippet details.""" + _, current_tenant_id = current_account_with_tenant() + + snippet = SnippetService.get_snippet_by_id( + snippet_id=str(snippet_id), + tenant_id=current_tenant_id, + ) + + if not snippet: + raise NotFound("Snippet not found") + + return marshal(snippet, snippet_fields), 200 + + @console_ns.doc("update_customized_snippet") + @console_ns.expect(console_ns.models.get(UpdateSnippetPayload.__name__)) + @console_ns.response(200, "Snippet updated successfully", snippet_model) + @console_ns.response(400, "Invalid request or name already exists") + @console_ns.response(404, "Snippet not found") + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + def patch(self, snippet_id: str): + """Update customized snippet.""" + current_user, current_tenant_id = current_account_with_tenant() + + snippet = SnippetService.get_snippet_by_id( + snippet_id=str(snippet_id), + tenant_id=current_tenant_id, + ) + + if not snippet: + raise NotFound("Snippet not found") + + payload = UpdateSnippetPayload.model_validate(console_ns.payload or {}) + update_data = payload.model_dump(exclude_unset=True) + + if "icon_info" in update_data and update_data["icon_info"] is not None: + update_data["icon_info"] = payload.icon_info.model_dump() if payload.icon_info else None + + if not update_data: + return {"message": "No valid fields to update"}, 400 + + try: + with Session(db.engine, expire_on_commit=False) as session: + snippet = session.merge(snippet) + snippet = SnippetService.update_snippet( + session=session, + snippet=snippet, + account_id=current_user.id, + data=update_data, + ) + session.commit() + except ValueError as e: + return {"message": str(e)}, 400 + + return marshal(snippet, snippet_fields), 200 + + @console_ns.doc("delete_customized_snippet") + @console_ns.response(204, "Snippet deleted successfully") + @console_ns.response(404, "Snippet not found") + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + def delete(self, snippet_id: str): + """Delete customized snippet.""" + _, current_tenant_id = current_account_with_tenant() + + snippet = SnippetService.get_snippet_by_id( + snippet_id=str(snippet_id), + tenant_id=current_tenant_id, + ) + + if not snippet: + raise NotFound("Snippet not found") + + with Session(db.engine) as session: + snippet = session.merge(snippet) + SnippetService.delete_snippet( + session=session, + snippet=snippet, + ) + session.commit() + + return "", 204 diff --git a/api/fields/snippet_fields.py b/api/fields/snippet_fields.py new file mode 100644 index 0000000000..2562d7cff0 --- /dev/null +++ b/api/fields/snippet_fields.py @@ -0,0 +1,45 @@ +from flask_restx import fields + +from fields.member_fields import simple_account_fields +from libs.helper import TimestampField + +# Snippet list item fields (lightweight for list display) +snippet_list_fields = { + "id": fields.String, + "name": fields.String, + "description": fields.String, + "type": fields.String, + "version": fields.Integer, + "use_count": fields.Integer, + "is_published": fields.Boolean, + "icon_info": fields.Raw, + "created_at": TimestampField, + "updated_at": TimestampField, +} + +# Full snippet fields (includes creator info and graph data) +snippet_fields = { + "id": fields.String, + "name": fields.String, + "description": fields.String, + "type": fields.String, + "version": fields.Integer, + "use_count": fields.Integer, + "is_published": fields.Boolean, + "icon_info": fields.Raw, + "graph": fields.Raw(attribute="graph_dict"), + "input_fields": fields.Raw(attribute="input_fields_list"), + "created_by": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_at": TimestampField, + "updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True), + "updated_at": TimestampField, +} + +# Pagination response fields +snippet_pagination_fields = { + "data": fields.List(fields.Nested(snippet_list_fields)), + "page": fields.Integer, + "limit": fields.Integer, + "total": fields.Integer, + "has_more": fields.Boolean, +} diff --git a/api/migrations/versions/2026_01_29_1200-1c05e80d2380_add_customized_snippets_table.py b/api/migrations/versions/2026_01_29_1200-1c05e80d2380_add_customized_snippets_table.py new file mode 100644 index 0000000000..5d487b121b --- /dev/null +++ b/api/migrations/versions/2026_01_29_1200-1c05e80d2380_add_customized_snippets_table.py @@ -0,0 +1,83 @@ +"""add_customized_snippets_table + +Revision ID: 1c05e80d2380 +Revises: 788d3099ae3a +Create Date: 2026-01-29 12:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +import models as models + + +def _is_pg(conn): + return conn.dialect.name == "postgresql" + + +# revision identifiers, used by Alembic. +revision = "1c05e80d2380" +down_revision = "788d3099ae3a" +branch_labels = None +depends_on = None + + +def upgrade(): + conn = op.get_bind() + + if _is_pg(conn): + op.create_table( + "customized_snippets", + sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False), + sa.Column("tenant_id", models.types.StringUUID(), nullable=False), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False), + sa.Column("workflow_id", models.types.StringUUID(), nullable=True), + sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False), + sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False), + sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False), + sa.Column("icon_info", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("graph", sa.Text(), nullable=True), + sa.Column("input_fields", sa.Text(), nullable=True), + sa.Column("created_by", models.types.StringUUID(), nullable=True), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("updated_by", models.types.StringUUID(), nullable=True), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"), + sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"), + ) + else: + op.create_table( + "customized_snippets", + sa.Column("id", models.types.StringUUID(), nullable=False), + sa.Column("tenant_id", models.types.StringUUID(), nullable=False), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column("description", models.types.LongText(), nullable=True), + sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False), + sa.Column("workflow_id", models.types.StringUUID(), nullable=True), + sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False), + sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False), + sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False), + sa.Column("icon_info", models.types.AdjustedJSON(astext_type=sa.Text()), nullable=True), + sa.Column("graph", models.types.LongText(), nullable=True), + sa.Column("input_fields", models.types.LongText(), nullable=True), + sa.Column("created_by", models.types.StringUUID(), nullable=True), + sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False), + sa.Column("updated_by", models.types.StringUUID(), nullable=True), + sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False), + sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"), + sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"), + ) + + with op.batch_alter_table("customized_snippets", schema=None) as batch_op: + batch_op.create_index("customized_snippet_tenant_idx", ["tenant_id"], unique=False) + + +def downgrade(): + with op.batch_alter_table("customized_snippets", schema=None) as batch_op: + batch_op.drop_index("customized_snippet_tenant_idx") + + op.drop_table("customized_snippets") diff --git a/api/models/__init__.py b/api/models/__init__.py index 74b33130ef..c11df11a11 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -79,6 +79,7 @@ from .provider import ( TenantDefaultModel, TenantPreferredModelProvider, ) +from .snippet import CustomizedSnippet, SnippetType from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding from .task import CeleryTask, CeleryTaskSet from .tools import ( @@ -138,6 +139,7 @@ __all__ = [ "Conversation", "ConversationVariable", "CreatorUserRole", + "CustomizedSnippet", "DataSourceApiKeyAuthBinding", "DataSourceOauthBinding", "Dataset", @@ -179,6 +181,7 @@ __all__ = [ "RecommendedApp", "SavedMessage", "Site", + "SnippetType", "Tag", "TagBinding", "Tenant", diff --git a/api/models/snippet.py b/api/models/snippet.py new file mode 100644 index 0000000000..977a6c868d --- /dev/null +++ b/api/models/snippet.py @@ -0,0 +1,96 @@ +import json +from datetime import datetime +from enum import StrEnum +from typing import Any + +import sqlalchemy as sa +from sqlalchemy import DateTime, String, func +from sqlalchemy.orm import Mapped, mapped_column + +from libs.uuid_utils import uuidv7 + +from .account import Account +from .base import Base +from .engine import db +from .types import AdjustedJSON, LongText, StringUUID + + +class SnippetType(StrEnum): + """Snippet Type Enum""" + + NODE = "node" + GROUP = "group" + + +class CustomizedSnippet(Base): + """ + Customized Snippet Model + + Stores reusable workflow components (nodes or node groups) that can be + shared across applications within a workspace. + """ + + __tablename__ = "customized_snippets" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"), + sa.Index("customized_snippet_tenant_idx", "tenant_id"), + sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"), + ) + + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7())) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[str | None] = mapped_column(LongText, nullable=True) + type: Mapped[str] = mapped_column(String(50), nullable=False, server_default=sa.text("'node'")) + + # Workflow reference for published version + workflow_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + + # State flags + is_published: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + version: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("1")) + use_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) + + # Visual customization + icon_info: Mapped[dict | None] = mapped_column(AdjustedJSON, nullable=True) + + # Snippet configuration (stored as JSON text) + graph: Mapped[str | None] = mapped_column(LongText, nullable=True) + input_fields: Mapped[str | None] = mapped_column(LongText, nullable=True) + + # Audit fields + created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + + @property + def graph_dict(self) -> dict[str, Any]: + """Parse graph JSON to dict.""" + return json.loads(self.graph) if self.graph else {} + + @property + def input_fields_list(self) -> list[dict[str, Any]]: + """Parse input_fields JSON to list.""" + return json.loads(self.input_fields) if self.input_fields else [] + + @property + def created_by_account(self) -> Account | None: + """Get the account that created this snippet.""" + if self.created_by: + return db.session.get(Account, self.created_by) + return None + + @property + def updated_by_account(self) -> Account | None: + """Get the account that last updated this snippet.""" + if self.updated_by: + return db.session.get(Account, self.updated_by) + return None + + @property + def version_str(self) -> str: + """Get version as string for API response.""" + return str(self.version) diff --git a/api/models/workflow.py b/api/models/workflow.py index df83228c2a..708218801a 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -65,6 +65,7 @@ class WorkflowType(StrEnum): WORKFLOW = "workflow" CHAT = "chat" RAG_PIPELINE = "rag-pipeline" + SNIPPET = "snippet" @classmethod def value_of(cls, value: str) -> "WorkflowType": diff --git a/api/services/snippet_service.py b/api/services/snippet_service.py new file mode 100644 index 0000000000..3664ed2617 --- /dev/null +++ b/api/services/snippet_service.py @@ -0,0 +1,542 @@ +import json +import logging +from collections.abc import Mapping, Sequence +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy import func, select +from sqlalchemy.orm import Session, sessionmaker + +from core.variables.variables import VariableBase +from core.workflow.enums import NodeType +from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING +from extensions.ext_database import db +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models import Account +from models.enums import WorkflowRunTriggeredFrom +from models.snippet import CustomizedSnippet, SnippetType +from models.workflow import ( + Workflow, + WorkflowNodeExecutionModel, + WorkflowRun, + WorkflowType, +) +from repositories.factory import DifyAPIRepositoryFactory +from services.errors.app import WorkflowHashNotEqualError + +logger = logging.getLogger(__name__) + + +class SnippetService: + """Service for managing customized snippets.""" + + def __init__(self, session_maker: sessionmaker | None = None): + """Initialize SnippetService with repository dependencies.""" + if session_maker is None: + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker + ) + self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + + # --- CRUD Operations --- + + @staticmethod + def get_snippets( + *, + tenant_id: str, + page: int = 1, + limit: int = 20, + keyword: str | None = None, + ) -> tuple[Sequence[CustomizedSnippet], int, bool]: + """ + Get paginated list of snippets with optional search. + + :param tenant_id: Tenant ID + :param page: Page number (1-indexed) + :param limit: Number of items per page + :param keyword: Optional search keyword for name/description + :return: Tuple of (snippets list, total count, has_more flag) + """ + stmt = ( + select(CustomizedSnippet) + .where(CustomizedSnippet.tenant_id == tenant_id) + .order_by(CustomizedSnippet.created_at.desc()) + ) + + if keyword: + stmt = stmt.where( + CustomizedSnippet.name.ilike(f"%{keyword}%") | CustomizedSnippet.description.ilike(f"%{keyword}%") + ) + + # Get total count + count_stmt = select(func.count()).select_from(stmt.subquery()) + total = db.session.scalar(count_stmt) or 0 + + # Apply pagination + stmt = stmt.limit(limit + 1).offset((page - 1) * limit) + snippets = list(db.session.scalars(stmt).all()) + + has_more = len(snippets) > limit + if has_more: + snippets = snippets[:-1] + + return snippets, total, has_more + + @staticmethod + def get_snippet_by_id( + *, + snippet_id: str, + tenant_id: str, + ) -> CustomizedSnippet | None: + """ + Get snippet by ID with tenant isolation. + + :param snippet_id: Snippet ID + :param tenant_id: Tenant ID + :return: CustomizedSnippet or None + """ + return ( + db.session.query(CustomizedSnippet) + .where( + CustomizedSnippet.id == snippet_id, + CustomizedSnippet.tenant_id == tenant_id, + ) + .first() + ) + + @staticmethod + def create_snippet( + *, + tenant_id: str, + name: str, + description: str | None, + snippet_type: SnippetType, + icon_info: dict | None, + graph: dict | None, + input_fields: list[dict] | None, + account: Account, + ) -> CustomizedSnippet: + """ + Create a new snippet. + + :param tenant_id: Tenant ID + :param name: Snippet name (must be unique per tenant) + :param description: Snippet description + :param snippet_type: Type of snippet (node or group) + :param icon_info: Icon information + :param graph: Workflow graph structure + :param input_fields: Input field definitions + :param account: Creator account + :return: Created CustomizedSnippet + :raises ValueError: If name already exists + """ + # Check if name already exists for this tenant + existing = ( + db.session.query(CustomizedSnippet) + .where( + CustomizedSnippet.tenant_id == tenant_id, + CustomizedSnippet.name == name, + ) + .first() + ) + if existing: + raise ValueError(f"Snippet with name '{name}' already exists") + + snippet = CustomizedSnippet( + tenant_id=tenant_id, + name=name, + description=description or "", + type=snippet_type.value, + icon_info=icon_info, + graph=json.dumps(graph) if graph else None, + input_fields=json.dumps(input_fields) if input_fields else None, + created_by=account.id, + ) + + db.session.add(snippet) + db.session.commit() + + return snippet + + @staticmethod + def update_snippet( + *, + session: Session, + snippet: CustomizedSnippet, + account_id: str, + data: dict, + ) -> CustomizedSnippet: + """ + Update snippet attributes. + + :param session: Database session + :param snippet: Snippet to update + :param account_id: ID of account making the update + :param data: Dictionary of fields to update + :return: Updated CustomizedSnippet + """ + if "name" in data: + # Check if new name already exists for this tenant + existing = ( + session.query(CustomizedSnippet) + .where( + CustomizedSnippet.tenant_id == snippet.tenant_id, + CustomizedSnippet.name == data["name"], + CustomizedSnippet.id != snippet.id, + ) + .first() + ) + if existing: + raise ValueError(f"Snippet with name '{data['name']}' already exists") + snippet.name = data["name"] + + if "description" in data: + snippet.description = data["description"] + + if "icon_info" in data: + snippet.icon_info = data["icon_info"] + + snippet.updated_by = account_id + snippet.updated_at = datetime.now(UTC).replace(tzinfo=None) + + session.add(snippet) + return snippet + + @staticmethod + def delete_snippet( + *, + session: Session, + snippet: CustomizedSnippet, + ) -> bool: + """ + Delete a snippet. + + :param session: Database session + :param snippet: Snippet to delete + :return: True if deleted successfully + """ + session.delete(snippet) + return True + + # --- Workflow Operations --- + + def get_draft_workflow(self, snippet: CustomizedSnippet) -> Workflow | None: + """ + Get draft workflow for snippet. + + :param snippet: CustomizedSnippet instance + :return: Draft Workflow or None + """ + workflow = ( + db.session.query(Workflow) + .where( + Workflow.tenant_id == snippet.tenant_id, + Workflow.app_id == snippet.id, + Workflow.type == WorkflowType.SNIPPET.value, + Workflow.version == "draft", + ) + .first() + ) + return workflow + + def get_published_workflow(self, snippet: CustomizedSnippet) -> Workflow | None: + """ + Get published workflow for snippet. + + :param snippet: CustomizedSnippet instance + :return: Published Workflow or None + """ + if not snippet.workflow_id: + return None + + workflow = ( + db.session.query(Workflow) + .where( + Workflow.tenant_id == snippet.tenant_id, + Workflow.app_id == snippet.id, + Workflow.type == WorkflowType.SNIPPET.value, + Workflow.id == snippet.workflow_id, + ) + .first() + ) + return workflow + + def sync_draft_workflow( + self, + *, + snippet: CustomizedSnippet, + graph: dict, + unique_hash: str | None, + account: Account, + environment_variables: Sequence[VariableBase], + conversation_variables: Sequence[VariableBase], + input_variables: list[dict] | None = None, + ) -> Workflow: + """ + Sync draft workflow for snippet. + + :param snippet: CustomizedSnippet instance + :param graph: Workflow graph configuration + :param unique_hash: Hash for conflict detection + :param account: Account making the change + :param environment_variables: Environment variables + :param conversation_variables: Conversation variables + :param input_variables: Input variables for snippet + :return: Synced Workflow + :raises WorkflowHashNotEqualError: If hash mismatch + """ + workflow = self.get_draft_workflow(snippet=snippet) + + if workflow and workflow.unique_hash != unique_hash: + raise WorkflowHashNotEqualError() + + # Create draft workflow if not found + if not workflow: + workflow = Workflow( + tenant_id=snippet.tenant_id, + app_id=snippet.id, + features="{}", + type=WorkflowType.SNIPPET.value, + version="draft", + graph=json.dumps(graph), + created_by=account.id, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + ) + db.session.add(workflow) + db.session.flush() + else: + # Update existing draft workflow + workflow.graph = json.dumps(graph) + workflow.updated_by = account.id + workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) + workflow.environment_variables = environment_variables + workflow.conversation_variables = conversation_variables + + # Update snippet's input_fields if provided + if input_variables is not None: + snippet.input_fields = json.dumps(input_variables) + snippet.updated_by = account.id + snippet.updated_at = datetime.now(UTC).replace(tzinfo=None) + + db.session.commit() + return workflow + + def publish_workflow( + self, + *, + session: Session, + snippet: CustomizedSnippet, + account: Account, + ) -> Workflow: + """ + Publish the draft workflow as a new version. + + :param session: Database session + :param snippet: CustomizedSnippet instance + :param account: Account making the change + :return: Published Workflow + :raises ValueError: If no draft workflow exists + """ + draft_workflow_stmt = select(Workflow).where( + Workflow.tenant_id == snippet.tenant_id, + Workflow.app_id == snippet.id, + Workflow.type == WorkflowType.SNIPPET.value, + Workflow.version == "draft", + ) + draft_workflow = session.scalar(draft_workflow_stmt) + if not draft_workflow: + raise ValueError("No valid workflow found.") + + # Create new published workflow + workflow = Workflow.new( + tenant_id=snippet.tenant_id, + app_id=snippet.id, + type=draft_workflow.type, + version=str(datetime.now(UTC).replace(tzinfo=None)), + graph=draft_workflow.graph, + features=draft_workflow.features, + created_by=account.id, + environment_variables=draft_workflow.environment_variables, + conversation_variables=draft_workflow.conversation_variables, + marked_name="", + marked_comment="", + ) + session.add(workflow) + + # Update snippet version + snippet.version += 1 + snippet.is_published = True + snippet.workflow_id = workflow.id + snippet.updated_by = account.id + session.add(snippet) + + return workflow + + def get_all_published_workflows( + self, + *, + session: Session, + snippet: CustomizedSnippet, + page: int, + limit: int, + ) -> tuple[Sequence[Workflow], bool]: + """ + Get all published workflow versions for snippet. + + :param session: Database session + :param snippet: CustomizedSnippet instance + :param page: Page number + :param limit: Items per page + :return: Tuple of (workflows list, has_more flag) + """ + if not snippet.workflow_id: + return [], False + + stmt = ( + select(Workflow) + .where( + Workflow.app_id == snippet.id, + Workflow.type == WorkflowType.SNIPPET.value, + Workflow.version != "draft", + ) + .order_by(Workflow.version.desc()) + .limit(limit + 1) + .offset((page - 1) * limit) + ) + + workflows = list(session.scalars(stmt).all()) + has_more = len(workflows) > limit + if has_more: + workflows = workflows[:-1] + + return workflows, has_more + + # --- Default Block Configs --- + + def get_default_block_configs(self) -> list[dict]: + """ + Get default block configurations for all node types. + + :return: List of default configurations + """ + default_block_configs: list[dict[str, Any]] = [] + for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values(): + node_class = node_class_mapping[LATEST_VERSION] + default_config = node_class.get_default_config() + if default_config: + default_block_configs.append(dict(default_config)) + + return default_block_configs + + def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None: + """ + Get default config for specific node type. + + :param node_type: Node type string + :param filters: Optional filters + :return: Default configuration or None + """ + node_type_enum = NodeType(node_type) + + if node_type_enum not in NODE_TYPE_CLASSES_MAPPING: + return None + + node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION] + default_config = node_class.get_default_config(filters=filters) + if not default_config: + return None + + return default_config + + # --- Workflow Run Operations --- + + def get_snippet_workflow_runs( + self, + *, + snippet: CustomizedSnippet, + args: dict, + ) -> InfiniteScrollPagination: + """ + Get paginated workflow runs for snippet. + + :param snippet: CustomizedSnippet instance + :param args: Request arguments (last_id, limit) + :return: InfiniteScrollPagination result + """ + limit = int(args.get("limit", 20)) + last_id = args.get("last_id") + + triggered_from_values = [ + WorkflowRunTriggeredFrom.DEBUGGING, + ] + + return self._workflow_run_repo.get_paginated_workflow_runs( + tenant_id=snippet.tenant_id, + app_id=snippet.id, + triggered_from=triggered_from_values, + limit=limit, + last_id=last_id, + ) + + def get_snippet_workflow_run( + self, + *, + snippet: CustomizedSnippet, + run_id: str, + ) -> WorkflowRun | None: + """ + Get workflow run details. + + :param snippet: CustomizedSnippet instance + :param run_id: Workflow run ID + :return: WorkflowRun or None + """ + return self._workflow_run_repo.get_workflow_run_by_id( + tenant_id=snippet.tenant_id, + app_id=snippet.id, + run_id=run_id, + ) + + def get_snippet_workflow_run_node_executions( + self, + *, + snippet: CustomizedSnippet, + run_id: str, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get workflow run node execution list. + + :param snippet: CustomizedSnippet instance + :param run_id: Workflow run ID + :return: List of WorkflowNodeExecutionModel + """ + workflow_run = self.get_snippet_workflow_run(snippet=snippet, run_id=run_id) + if not workflow_run: + return [] + + node_executions = self._node_execution_service_repo.get_executions_by_workflow_run( + tenant_id=snippet.tenant_id, + app_id=snippet.id, + workflow_run_id=workflow_run.id, + ) + + return node_executions + + # --- Use Count --- + + @staticmethod + def increment_use_count( + *, + session: Session, + snippet: CustomizedSnippet, + ) -> None: + """ + Increment the use_count when snippet is used. + + :param session: Database session + :param snippet: CustomizedSnippet instance + """ + snippet.use_count += 1 + session.add(snippet)