feat: Implement snippet_generate_service.py.

This commit is contained in:
FFXN
2026-03-06 14:28:08 +08:00
parent b88195c7d9
commit 9340ee8af4
5 changed files with 303 additions and 1 deletion

View File

@@ -11,6 +11,7 @@ class EvaluationCategory(StrEnum):
RETRIEVAL = "knowledge_retrieval"
AGENT = "agent"
WORKFLOW = "workflow"
SNIPPET = "snippet"
RETRIEVAL_TEST = "retrieval_test"

View File

@@ -1,7 +1,7 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from models import Account, App, TenantAccountJoin
from models import Account, App, CustomizedSnippet, TenantAccountJoin
def get_service_account_for_app(session: Session, app_id: str) -> Account:
@@ -30,3 +30,31 @@ def get_service_account_for_app(session: Session, app_id: str) -> Account:
account.set_tenant_id(current_tenant.tenant_id)
return account
def get_service_account_for_snippet(session: Session, snippet_id: str) -> Account:
    """Get the creator account for a snippet with tenant context set up.

    Mirrors :func:`get_service_account_for_app` but queries CustomizedSnippet.

    :param session: active SQLAlchemy session used for all lookups
    :param snippet_id: primary key of the CustomizedSnippet to resolve
    :return: the creator ``Account`` with its current tenant id applied via
        ``set_tenant_id``
    :raises ValueError: if the snippet, its ``created_by`` reference, the
        creator account, or the creator's current tenant membership is missing
    """
    snippet = session.scalar(select(CustomizedSnippet).where(CustomizedSnippet.id == snippet_id))
    if not snippet:
        raise ValueError(f"Snippet with id {snippet_id} not found")
    if not snippet.created_by:
        raise ValueError(f"Snippet with id {snippet_id} has no creator")

    account = session.scalar(select(Account).where(Account.id == snippet.created_by))
    if not account:
        raise ValueError(f"Creator account not found for snippet {snippet_id}")

    # Use the same 2.0-style select() as the queries above instead of the
    # legacy session.query() API, so the whole function uses one query style.
    current_tenant = session.scalar(
        select(TenantAccountJoin)
        .where(
            TenantAccountJoin.account_id == account.id,
            TenantAccountJoin.current == True,  # noqa: E712 -- column comparison, matches filter_by(current=True)
        )
        .limit(1)
    )
    if not current_tenant:
        raise ValueError(f"Current tenant not found for account {account.id}")

    account.set_tenant_id(current_tenant.tenant_id)
    return account

View File

@@ -0,0 +1,228 @@
"""Runner for Snippet evaluation.
Executes a published Snippet workflow in non-streaming mode, collects the
actual outputs and per-node execution records, then delegates to the
evaluation instance for metric computation.
"""
import json
import logging
from collections.abc import Mapping, Sequence
from typing import Any
from sqlalchemy import asc, select
from sqlalchemy.orm import Session
from core.evaluation.base_evaluation_instance import BaseEvaluationInstance
from core.evaluation.entities.evaluation_entity import (
EvaluationItemInput,
EvaluationItemResult,
)
from core.evaluation.runners.base_evaluation_runner import BaseEvaluationRunner
from models.snippet import CustomizedSnippet
from models.workflow import WorkflowNodeExecutionModel
logger = logging.getLogger(__name__)
class SnippetEvaluationRunner(BaseEvaluationRunner):
    """Runner for snippet evaluation: executes a published Snippet workflow.

    Construction is inherited from :class:`BaseEvaluationRunner`; this class
    adds no state of its own, so no ``__init__`` override is needed.
    """

    def execute_target(
        self,
        tenant_id: str,
        target_id: str,
        target_type: str,
        item: EvaluationItemInput,
    ) -> EvaluationItemResult:
        """Execute a published Snippet workflow and collect outputs.

        Steps:
        1. Delegate execution to ``SnippetGenerateService.run_published``.
        2. Extract ``workflow_run_id`` from the blocking response.
        3. Query ``workflow_node_executions`` by ``workflow_run_id`` to get
           each node's inputs, outputs, status, elapsed_time, etc.
        4. Return result with actual_output and node_executions metadata.

        :param tenant_id: tenant that owns the snippet's execution records
        :param target_id: id of the CustomizedSnippet to execute
        :param target_type: evaluation target type (part of the runner
            interface; not used by this implementation)
        :param item: one evaluation dataset row carrying workflow inputs
        :return: result with the primary output text and per-node metadata
        :raises ValueError: if the snippet is missing or not published
        """
        # Imported lazily; presumably to avoid circular imports between
        # runners, services, and app entities -- TODO confirm.
        from core.app.entities.app_invoke_entities import InvokeFrom
        from core.evaluation.runners import get_service_account_for_snippet
        from services.snippet_generate_service import SnippetGenerateService

        snippet = self.session.query(CustomizedSnippet).filter_by(id=target_id).first()
        if not snippet:
            raise ValueError(f"Snippet {target_id} not found")
        if not snippet.is_published:
            raise ValueError(f"Snippet {target_id} is not published")

        # Run on behalf of the snippet creator so tenant context is set up.
        service_account = get_service_account_for_snippet(self.session, target_id)
        response = SnippetGenerateService.run_published(
            snippet=snippet,
            user=service_account,
            args={"inputs": item.inputs},
            invoke_from=InvokeFrom.SERVICE_API,
        )
        actual_output = self._extract_output(response)

        # Retrieve per-node execution records from DB (empty when the
        # response carried no run id).
        workflow_run_id = self._extract_workflow_run_id(response)
        if workflow_run_id:
            node_executions = self._query_node_executions(
                tenant_id=tenant_id,
                app_id=target_id,
                workflow_run_id=workflow_run_id,
            )
        else:
            node_executions = []

        return EvaluationItemResult(
            index=item.index,
            actual_output=actual_output,
            metadata={
                "workflow_run_id": workflow_run_id or "",
                "node_executions": node_executions,
            },
        )

    def evaluate_metrics(
        self,
        items: list[EvaluationItemInput],
        results: list[EvaluationItemResult],
        default_metrics: list[dict[str, Any]],
        model_provider: str,
        model_name: str,
        tenant_id: str,
    ) -> list[EvaluationItemResult]:
        """Compute evaluation metrics for snippet outputs.

        Snippets are essentially workflows, so we reuse evaluate_workflow from
        the evaluation instance.

        :param items: original dataset rows (inputs/expected_output/context)
        :param results: Phase-1 execution results, matched to items by index
        :param default_metrics: metric configurations to apply
        :param model_provider: provider of the judging model
        :param model_name: name of the judging model
        :param tenant_id: tenant the evaluation runs under
        :return: results with metrics merged in, preserving Phase-1 metadata
        """
        result_by_index = {r.index: r for r in results}

        # Fold each item's actual output into its context so the workflow
        # evaluator can score it alongside any pre-existing context.
        merged_items = []
        for item in items:
            result = result_by_index.get(item.index)
            context = []
            if result and result.actual_output:
                context.append(result.actual_output)
            merged_items.append(
                EvaluationItemInput(
                    index=item.index,
                    inputs=item.inputs,
                    expected_output=item.expected_output,
                    context=context + (item.context or []),
                )
            )

        evaluated = self.evaluation_instance.evaluate_workflow(
            merged_items, default_metrics, model_provider, model_name, tenant_id
        )

        # Merge metrics back preserving metadata from Phase 1; results with
        # no matching evaluation pass through unchanged.
        eval_by_index = {r.index: r for r in evaluated}
        final_results = []
        for result in results:
            if result.index in eval_by_index:
                eval_result = eval_by_index[result.index]
                final_results.append(
                    EvaluationItemResult(
                        index=result.index,
                        actual_output=result.actual_output,
                        metrics=eval_result.metrics,
                        metadata=result.metadata,
                        error=result.error,
                    )
                )
            else:
                final_results.append(result)
        return final_results

    @staticmethod
    def _extract_output(response: Mapping[str, Any]) -> str:
        """Extract text output from the blocking workflow response.

        The blocking response ``data.outputs`` is a dict of output variables.
        We take the first value as the primary output text.

        :param response: blocking response mapping from the workflow run
        :return: primary output as a string ("" when outputs is empty);
            falls back to ``str()`` of non-mapping outputs or of the whole
            response when the expected shape is absent
        """
        if "data" in response and isinstance(response["data"], Mapping):
            outputs = response["data"].get("outputs", {})
            if isinstance(outputs, Mapping):
                values = list(outputs.values())
                return str(values[0]) if values else ""
            return str(outputs)
        return str(response)

    @staticmethod
    def _extract_workflow_run_id(response: Mapping[str, Any]) -> str | None:
        """Extract workflow_run_id from the blocking response.

        The blocking response has ``workflow_run_id`` at the top level and
        also ``data.id`` (same value).

        :param response: blocking response mapping from the workflow run
        :return: the run id as a string, or None when neither field is set
        """
        wf_run_id = response.get("workflow_run_id")
        if wf_run_id:
            return str(wf_run_id)
        # Fallback to data.id
        data = response.get("data")
        if isinstance(data, Mapping) and data.get("id"):
            return str(data["id"])
        return None

    def _query_node_executions(
        self,
        tenant_id: str,
        app_id: str,
        workflow_run_id: str,
    ) -> list[dict[str, Any]]:
        """Query per-node execution records from the DB after workflow completes.

        Node executions are persisted during workflow execution. We read them
        back via the ``workflow_run_id`` to get each node's inputs, outputs,
        status, elapsed_time, etc.

        :param tenant_id: tenant scoping the execution records
        :param app_id: the snippet id (stored in the app_id column for
            snippet runs -- see execute_target)
        :param workflow_run_id: id of the workflow run to fetch nodes for
        :return: a list of serialisable dicts for storage in ``metadata``,
            ordered by creation time
        """
        stmt = WorkflowNodeExecutionModel.preload_offload_data(
            select(WorkflowNodeExecutionModel)
        ).where(
            WorkflowNodeExecutionModel.tenant_id == tenant_id,
            WorkflowNodeExecutionModel.app_id == app_id,
            WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
        ).order_by(asc(WorkflowNodeExecutionModel.created_at))
        node_models: Sequence[WorkflowNodeExecutionModel] = (
            self.session.execute(stmt).scalars().all()
        )
        return [self._serialize_node_execution(node) for node in node_models]

    @staticmethod
    def _serialize_node_execution(node: WorkflowNodeExecutionModel) -> dict[str, Any]:
        """Convert a WorkflowNodeExecutionModel to a serialisable dict.

        Includes the node's id, type, title, inputs/outputs (parsed from JSON),
        status, error, and elapsed_time. The virtual Start node injected by
        SnippetGenerateService is filtered out by the caller if needed.

        :param node: the persisted node-execution row
        :return: plain dict safe to store as JSON metadata
        """
        def _safe_parse_json(value: str | None) -> Any:
            # Malformed JSON is returned verbatim rather than raising, so a
            # single bad row cannot break metadata collection.
            if not value:
                return None
            try:
                return json.loads(value)
            except (json.JSONDecodeError, TypeError):
                return value

        return {
            "id": node.id,
            "node_id": node.node_id,
            "node_type": node.node_type,
            "title": node.title,
            "inputs": _safe_parse_json(node.inputs),
            "outputs": _safe_parse_json(node.outputs),
            "status": node.status,
            "error": node.error,
            "elapsed_time": node.elapsed_time,
        }

View File

@@ -131,6 +131,48 @@ class SnippetGenerateService:
)
)
@classmethod
def run_published(
    cls,
    snippet: CustomizedSnippet,
    user: Union[Account, EndUser],
    args: Mapping[str, Any],
    invoke_from: InvokeFrom,
) -> Mapping[str, Any]:
    """
    Execute the published workflow of *snippet* synchronously.

    Unlike :meth:`generate`, this targets the published workflow rather than
    the draft and hands back the raw blocking response without SSE wrapping,
    which suits programmatic callers such as evaluation runners.

    :param snippet: the CustomizedSnippet whose published workflow to run
    :param user: Account or EndUser on whose behalf the run happens
    :param args: workflow arguments; must carry an "inputs" key
    :param invoke_from: where the invocation originates from
    :return: blocking response mapping with the workflow outputs
    :raises ValueError: when the snippet has no published workflow
    """
    published = SnippetService().get_published_workflow(snippet)
    if not published:
        raise ValueError("No published workflow found for snippet")

    # Inject a virtual Start node when the graph doesn't have one.
    prepared = cls._ensure_start_node(published, snippet)

    generator = WorkflowAppGenerator()
    result: Mapping[str, Any] = generator.generate(
        app_model=_SnippetAsApp(snippet),  # type: ignore[arg-type]
        workflow=prepared,
        user=user,
        args=args,
        invoke_from=invoke_from,
        streaming=False,
    )
    return result
@classmethod
def _ensure_start_node(cls, workflow: Workflow, snippet: CustomizedSnippet) -> Workflow:
"""

View File

@@ -18,6 +18,7 @@ from core.evaluation.evaluation_manager import EvaluationManager
from core.evaluation.runners.agent_evaluation_runner import AgentEvaluationRunner
from core.evaluation.runners.llm_evaluation_runner import LLMEvaluationRunner
from core.evaluation.runners.retrieval_evaluation_runner import RetrievalEvaluationRunner
from core.evaluation.runners.snippet_evaluation_runner import SnippetEvaluationRunner
from core.evaluation.runners.workflow_evaluation_runner import WorkflowEvaluationRunner
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
@@ -128,6 +129,8 @@ def _create_runner(
return AgentEvaluationRunner(evaluation_instance, session)
case EvaluationCategory.WORKFLOW:
return WorkflowEvaluationRunner(evaluation_instance, session)
case EvaluationCategory.SNIPPET:
return SnippetEvaluationRunner(evaluation_instance, session)
case _:
raise ValueError(f"Unknown evaluation category: {category}")