# Files
# dify/api/services/evaluation_service.py
# 2026-03-04 19:20:08 +08:00
#
# 502 lines
# 17 KiB
# Python

import io
import json
import logging
from typing import Any, Union
from openpyxl import Workbook, load_workbook
from openpyxl.styles import Alignment, Border, Font, PatternFill, Side
from openpyxl.utils import get_column_letter
from sqlalchemy.orm import Session
from configs import dify_config
from core.evaluation.entities.evaluation_entity import (
DefaultMetric,
EvaluationCategory,
EvaluationConfigData,
EvaluationItemInput,
EvaluationRunData,
EvaluationRunRequest,
)
from core.evaluation.evaluation_manager import EvaluationManager
from models.evaluation import (
EvaluationConfiguration,
EvaluationRun,
EvaluationRunItem,
EvaluationRunStatus,
)
from models.model import App, AppMode
from models.snippet import CustomizedSnippet
from services.errors.evaluation import (
EvaluationDatasetInvalidError,
EvaluationFrameworkNotConfiguredError,
EvaluationMaxConcurrentRunsError,
EvaluationNotFoundError,
)
from services.snippet_service import SnippetService
from services.workflow_service import WorkflowService
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class EvaluationService:
    """
    Service for evaluation-related operations.

    Groups three areas of functionality, all exposed as classmethods:

    - generating XLSX evaluation dataset templates from the input
      parameters of an App or Snippet,
    - reading and saving ``EvaluationConfiguration`` records,
    - creating, listing, inspecting and cancelling ``EvaluationRun`` records.
    """

    # Excluded app modes that don't support evaluation templates
    EXCLUDED_APP_MODES = {AppMode.RAG_PIPELINE}

    # ---- Dataset Template Generation ----

    @classmethod
    def generate_dataset_template(
        cls,
        target: Union[App, CustomizedSnippet],
        target_type: str,
    ) -> tuple[bytes, str]:
        """
        Generate evaluation dataset template as XLSX bytes.

        The header row is "index" followed by one column per input field of
        the evaluation target.

        :param target: App or CustomizedSnippet instance
        :param target_type: Target type string ("app" or "snippet")
        :return: Tuple of (xlsx_content_bytes, filename)
        :raises ValueError: If the target type is not supported, the target
            instance does not match the target type, or the app mode is excluded
        """
        if target_type == "app":
            if not isinstance(target, App):
                raise ValueError("Invalid target: expected App instance")
            if AppMode.value_of(target.mode) in cls.EXCLUDED_APP_MODES:
                raise ValueError(f"App mode '{target.mode}' does not support evaluation templates")
            input_fields = cls._get_app_input_fields(target)
        elif target_type == "snippet":
            if not isinstance(target, CustomizedSnippet):
                raise ValueError("Invalid target: expected CustomizedSnippet instance")
            input_fields = cls._get_snippet_input_fields(target)
        else:
            raise ValueError(f"Unsupported target type: {target_type}")

        xlsx_content = cls._generate_xlsx_template(input_fields, target.name)

        # Names longer than 10 characters are cut to 10 and suffixed with "...".
        truncated_name = target.name[:10] + "..." if len(target.name) > 10 else target.name
        filename = f"{truncated_name}-evaluation-dataset.xlsx"
        return xlsx_content, filename

    @classmethod
    def _get_app_input_fields(cls, app: App) -> list[dict]:
        """
        Get input fields from App's workflow.

        Uses the published workflow, falling back to the draft workflow.

        :param app: App instance
        :return: The workflow's user input form, or an empty list when the
            app has neither a published nor a draft workflow
        """
        workflow_service = WorkflowService()
        workflow = workflow_service.get_published_workflow(app_model=app)
        if not workflow:
            workflow = workflow_service.get_draft_workflow(app_model=app)
        if not workflow:
            return []
        return workflow.user_input_form()

    @classmethod
    def _get_snippet_input_fields(cls, snippet: CustomizedSnippet) -> list[dict]:
        """
        Get input fields from Snippet.

        Prefers the snippet's own ``input_fields_list``; when that is empty,
        falls back to the user input form of the snippet's published (then
        draft) workflow.

        :param snippet: CustomizedSnippet instance
        :return: List of input field definitions, empty if none can be found
        """
        input_fields = snippet.input_fields_list
        if input_fields:
            return input_fields

        snippet_service = SnippetService()
        workflow = snippet_service.get_published_workflow(snippet=snippet)
        if not workflow:
            workflow = snippet_service.get_draft_workflow(snippet=snippet)
        if workflow:
            return workflow.user_input_form()
        return []

    @classmethod
    def _generate_xlsx_template(cls, input_fields: list[dict], target_name: str) -> bytes:
        """
        Generate XLSX template file content.

        Creates a single-sheet workbook named "Evaluation Dataset" with:

        - a styled header row: "index" followed by each field's ``label``
          (or ``variable`` when the label is missing/empty),
        - one bordered sample row whose index cell is pre-filled with 1.

        Fields that provide neither a label nor a variable are skipped, since
        they would otherwise produce a header cell with no name.

        :param input_fields: List of input field definitions
        :param target_name: Name of the target; accepted for interface
            compatibility but not used by the current implementation
        :return: XLSX file content as bytes
        """
        wb = Workbook()
        ws = wb.active
        ws.title = "Evaluation Dataset"

        header_font = Font(bold=True, color="FFFFFF")
        header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
        header_alignment = Alignment(horizontal="center", vertical="center")
        thin_border = Border(
            left=Side(style="thin"),
            right=Side(style="thin"),
            top=Side(style="thin"),
            bottom=Side(style="thin"),
        )

        # Build header row: "index" first, then one named column per field.
        headers = ["index"]
        for field in input_fields:
            field_label = field.get("label") or field.get("variable")
            if field_label:
                headers.append(field_label)

        for col_idx, header in enumerate(headers, start=1):
            cell = ws.cell(row=1, column=col_idx, value=header)
            cell.font = header_font
            cell.fill = header_fill
            cell.alignment = header_alignment
            cell.border = thin_border

        # Narrow index column, wider input columns.
        ws.column_dimensions["A"].width = 10
        for col_idx in range(2, len(headers) + 1):
            ws.column_dimensions[get_column_letter(col_idx)].width = 20

        # One sample row so users can see where data starts.
        for col_idx in range(1, len(headers) + 1):
            cell = ws.cell(row=2, column=col_idx, value="")
            cell.border = thin_border
            if col_idx == 1:
                cell.value = 1
                cell.alignment = Alignment(horizontal="center")

        output = io.BytesIO()
        wb.save(output)
        # getvalue() returns the whole buffer regardless of stream position.
        return output.getvalue()

    # ---- Evaluation Configuration CRUD ----

    @classmethod
    def get_evaluation_config(
        cls,
        session: Session,
        tenant_id: str,
        target_type: str,
        target_id: str,
    ) -> EvaluationConfiguration | None:
        """
        Fetch the evaluation configuration of a target.

        :return: The first configuration matching tenant, target type and
            target id, or None when there is none
        """
        return (
            session.query(EvaluationConfiguration)
            .filter_by(tenant_id=tenant_id, target_type=target_type, target_id=target_id)
            .first()
        )

    @classmethod
    def save_evaluation_config(
        cls,
        session: Session,
        tenant_id: str,
        target_type: str,
        target_id: str,
        account_id: str,
        data: EvaluationConfigData,
    ) -> EvaluationConfiguration:
        """
        Create or update the evaluation configuration of a target.

        An existing configuration is updated in place; otherwise a new one is
        added. The session is committed and the record refreshed before
        returning.

        :param account_id: Recorded as ``updated_by`` (and ``created_by`` for
            a new record)
        :param data: Model, metric and judgment settings to persist
        :return: The persisted configuration
        """
        config = cls.get_evaluation_config(session, tenant_id, target_type, target_id)
        if config is None:
            config = EvaluationConfiguration(
                tenant_id=tenant_id,
                target_type=target_type,
                target_id=target_id,
                created_by=account_id,
                updated_by=account_id,
            )
            session.add(config)

        config.evaluation_model_provider = data.evaluation_model_provider
        config.evaluation_model = data.evaluation_model
        # Metrics and judgment settings are stored as JSON strings.
        config.metrics_config = json.dumps({
            "default_metrics": [m.model_dump() for m in data.default_metrics],
            "customized_metrics": data.customized_metrics.model_dump() if data.customized_metrics else None,
        })
        config.judgement_conditions = json.dumps(
            data.judgment_config.model_dump() if data.judgment_config else {}
        )
        config.updated_by = account_id

        session.commit()
        session.refresh(config)
        return config

    # ---- Evaluation Run Management ----

    @classmethod
    def start_evaluation_run(
        cls,
        session: Session,
        tenant_id: str,
        target_type: str,
        target_id: str,
        account_id: str,
        dataset_file_content: bytes,
        run_request: EvaluationRunRequest,
    ) -> EvaluationRun:
        """
        Validate dataset, create run record, dispatch Celery task.

        Saves the provided parameters as the latest EvaluationConfiguration
        before creating the run.

        :param dataset_file_content: Raw XLSX bytes of the evaluation dataset
        :param run_request: Model, metric and judgment settings for this run
        :return: The created run, with ``celery_task_id`` set
        :raises EvaluationFrameworkNotConfiguredError: If no evaluation
            instance is available
        :raises EvaluationMaxConcurrentRunsError: If the tenant already has
            the maximum number of pending/running runs
        :raises EvaluationDatasetInvalidError: If the dataset cannot be parsed
            or exceeds the configured maximum number of rows
        """
        evaluation_instance = EvaluationManager.get_evaluation_instance()
        if evaluation_instance is None:
            raise EvaluationFrameworkNotConfiguredError()

        evaluation_category = cls._resolve_evaluation_category(run_request.default_metrics)

        # NOTE(review): the configuration is committed here, before the
        # concurrency and dataset checks below, so it is persisted even when
        # the request is subsequently rejected - confirm this is intended.
        config = cls.save_evaluation_config(
            session=session,
            tenant_id=tenant_id,
            target_type=target_type,
            target_id=target_id,
            account_id=account_id,
            data=run_request,
        )

        # Concurrency limit is per tenant, counting PENDING and RUNNING runs.
        active_runs = (
            session.query(EvaluationRun)
            .filter_by(tenant_id=tenant_id)
            .filter(EvaluationRun.status.in_([EvaluationRunStatus.PENDING, EvaluationRunStatus.RUNNING]))
            .count()
        )
        max_concurrent = dify_config.EVALUATION_MAX_CONCURRENT_RUNS
        if active_runs >= max_concurrent:
            raise EvaluationMaxConcurrentRunsError(
                f"Maximum concurrent runs ({max_concurrent}) reached."
            )

        items = cls._parse_dataset(dataset_file_content)
        max_rows = dify_config.EVALUATION_MAX_DATASET_ROWS
        if len(items) > max_rows:
            raise EvaluationDatasetInvalidError(f"Dataset has {len(items)} rows, max is {max_rows}.")

        evaluation_run = EvaluationRun(
            tenant_id=tenant_id,
            target_type=target_type,
            target_id=target_id,
            evaluation_config_id=config.id,
            status=EvaluationRunStatus.PENDING,
            total_items=len(items),
            created_by=account_id,
        )
        session.add(evaluation_run)
        # Commit first so the run has an id to hand to the task.
        session.commit()
        session.refresh(evaluation_run)

        run_data = EvaluationRunData(
            evaluation_run_id=evaluation_run.id,
            tenant_id=tenant_id,
            target_type=target_type,
            target_id=target_id,
            evaluation_category=evaluation_category,
            evaluation_model_provider=run_request.evaluation_model_provider,
            evaluation_model=run_request.evaluation_model,
            default_metrics=[m.model_dump() for m in run_request.default_metrics],
            customized_metrics=(
                run_request.customized_metrics.model_dump() if run_request.customized_metrics else None
            ),
            judgment_config=run_request.judgment_config,
            items=items,
        )

        # Imported at call time rather than at module import time.
        from tasks.evaluation_task import run_evaluation

        # NOTE(review): if dispatch raises, the run committed above stays
        # PENDING and keeps counting against the concurrency limit - confirm
        # whether a cleanup path exists elsewhere.
        task = run_evaluation.delay(run_data.model_dump())
        evaluation_run.celery_task_id = task.id
        session.commit()
        return evaluation_run

    @classmethod
    def get_evaluation_runs(
        cls,
        session: Session,
        tenant_id: str,
        target_type: str,
        target_id: str,
        page: int = 1,
        page_size: int = 20,
    ) -> tuple[list[EvaluationRun], int]:
        """
        Query evaluation run history with pagination.

        Runs are ordered newest first. A ``page`` below 1 is treated as the
        first page instead of producing a negative offset.

        :param page: 1-based page number
        :param page_size: Maximum number of runs per page
        :return: Tuple of (runs on the requested page, total matching runs)
        """
        query = (
            session.query(EvaluationRun)
            .filter_by(tenant_id=tenant_id, target_type=target_type, target_id=target_id)
            .order_by(EvaluationRun.created_at.desc())
        )
        total = query.count()
        offset = max(page - 1, 0) * page_size
        runs = query.offset(offset).limit(page_size).all()
        return runs, total

    @classmethod
    def get_evaluation_run_detail(
        cls,
        session: Session,
        tenant_id: str,
        run_id: str,
    ) -> EvaluationRun:
        """
        Fetch a single evaluation run belonging to a tenant.

        :raises EvaluationNotFoundError: If no run with this id exists for
            the tenant
        """
        run = (
            session.query(EvaluationRun)
            .filter_by(id=run_id, tenant_id=tenant_id)
            .first()
        )
        if not run:
            raise EvaluationNotFoundError("Evaluation run not found.")
        return run

    @classmethod
    def get_evaluation_run_items(
        cls,
        session: Session,
        run_id: str,
        page: int = 1,
        page_size: int = 50,
    ) -> tuple[list[EvaluationRunItem], int]:
        """
        Query evaluation run items with pagination.

        Items are ordered by ascending ``item_index``. A ``page`` below 1 is
        treated as the first page instead of producing a negative offset.

        NOTE(review): this query filters by run id only, not by tenant;
        callers presumably verify run ownership first (e.g. via
        ``get_evaluation_run_detail``) - confirm.

        :param page: 1-based page number
        :param page_size: Maximum number of items per page
        :return: Tuple of (items on the requested page, total items of the run)
        """
        query = (
            session.query(EvaluationRunItem)
            .filter_by(evaluation_run_id=run_id)
            .order_by(EvaluationRunItem.item_index.asc())
        )
        total = query.count()
        offset = max(page - 1, 0) * page_size
        items = query.offset(offset).limit(page_size).all()
        return items, total

    @classmethod
    def cancel_evaluation_run(
        cls,
        session: Session,
        tenant_id: str,
        run_id: str,
    ) -> EvaluationRun:
        """
        Cancel a pending or running evaluation run.

        Marks the run as CANCELLED and, when a Celery task id is recorded,
        asks Celery to revoke and terminate that task. Revocation is best
        effort: a failure is logged and the cancellation is still committed.

        :raises EvaluationNotFoundError: If the run does not exist for the tenant
        :raises ValueError: If the run is not in PENDING or RUNNING status
        """
        run = cls.get_evaluation_run_detail(session, tenant_id, run_id)
        if run.status not in (EvaluationRunStatus.PENDING, EvaluationRunStatus.RUNNING):
            raise ValueError(f"Cannot cancel evaluation run in status: {run.status}")

        run.status = EvaluationRunStatus.CANCELLED

        if run.celery_task_id:
            try:
                from celery import current_app as celery_app

                celery_app.control.revoke(run.celery_task_id, terminate=True)
            except Exception:
                # Deliberately best effort: the status change must still be saved.
                logger.exception("Failed to revoke Celery task %s", run.celery_task_id)

        session.commit()
        return run

    @classmethod
    def get_supported_metrics(cls, category: EvaluationCategory) -> list[str]:
        """Return the metrics EvaluationManager reports for the given category."""
        return EvaluationManager.get_supported_metrics(category)

    # ---- Category Resolution ----

    @classmethod
    def _resolve_evaluation_category(cls, default_metrics: list[DefaultMetric]) -> EvaluationCategory:
        """
        Derive evaluation category from default_metrics node_info types.

        Returns the category of the first node_info whose ``type`` is a valid
        ``EvaluationCategory`` value; node types that are not valid
        categories are skipped. Falls back to LLM if none qualifies.
        """
        for metric in default_metrics:
            for node_info in metric.node_info_list:
                try:
                    return EvaluationCategory(node_info.type)
                except ValueError:
                    continue
        return EvaluationCategory.LLM

    # ---- Dataset Parsing ----

    @classmethod
    def _parse_dataset(cls, xlsx_content: bytes) -> list[EvaluationItemInput]:
        """
        Parse evaluation dataset from XLSX bytes.

        Expected layout of the active worksheet:

        - row 1 is the header row and its first cell is "index"
          (case-insensitive),
        - every following non-empty row is one evaluation item.

        Cell values are converted to strings (empty cells become ""). The
        optional "expected_output" and "context" columns are removed from the
        inputs and passed separately; "context" is split on ";". Columns with
        a blank header are ignored. When the index cell is not an integer,
        the 1-based position of the data row is used instead.

        :param xlsx_content: Raw XLSX file bytes
        :return: Parsed items in sheet order
        :raises EvaluationDatasetInvalidError: If the bytes are not a readable
            XLSX workbook, there is no active worksheet, fewer than two rows
            exist, or the first header is not "index"
        """
        # Uploaded content is untrusted: any failure to open or read it is
        # reported as an invalid dataset rather than leaking parser errors.
        try:
            wb = load_workbook(io.BytesIO(xlsx_content), read_only=True)
            try:
                ws = wb.active
                rows = None if ws is None else list(ws.iter_rows(values_only=True))
            finally:
                # Read-only workbooks must be closed explicitly, also on error.
                wb.close()
        except Exception as exc:
            raise EvaluationDatasetInvalidError("Dataset file is not a valid XLSX workbook.") from exc

        if rows is None:
            raise EvaluationDatasetInvalidError("XLSX file has no active worksheet.")
        if len(rows) < 2:
            raise EvaluationDatasetInvalidError("Dataset must have at least a header row and one data row.")

        headers = [str(h).strip() if h is not None else "" for h in rows[0]]
        if not headers or headers[0].lower() != "index":
            raise EvaluationDatasetInvalidError("First column header must be 'index'.")

        # (column position, header name) for every named input column;
        # position 0 is the index column and blank headers carry no name.
        input_columns = [
            (position, name)
            for position, name in enumerate(headers[1:], start=1)
            if name
        ]

        items = []
        for row_idx, row in enumerate(rows[1:], start=1):
            values = list(row)
            if all(v is None or str(v).strip() == "" for v in values):
                continue  # Skip empty rows

            try:
                index = int(values[0])
            except (TypeError, ValueError, OverflowError):
                index = row_idx

            inputs: dict[str, Any] = {}
            for position, name in input_columns:
                # Rows may be shorter than the header row.
                val = values[position] if position < len(values) else None
                inputs[name] = str(val) if val is not None else ""

            expected_output = inputs.pop("expected_output", None)
            context_str = inputs.pop("context", None)
            context = context_str.split(";") if context_str else None

            items.append(
                EvaluationItemInput(
                    index=index,
                    inputs=inputs,
                    expected_output=expected_output,
                    context=context,
                )
            )
        return items