feat: add GraphEngine layer node execution hooks (#28583)
@@ -140,6 +140,10 @@ class GraphEngine:
         pause_handler = PauseCommandHandler()
         self._command_processor.register_handler(PauseCommand, pause_handler)
 
+        # === Extensibility ===
+        # Layers allow plugins to extend engine functionality
+        self._layers: list[GraphEngineLayer] = []
+
         # === Worker Pool Setup ===
         # Capture Flask app context for worker threads
         flask_app: Flask | None = None
@@ -158,6 +162,7 @@ class GraphEngine:
             ready_queue=self._ready_queue,
             event_queue=self._event_queue,
             graph=self._graph,
+            layers=self._layers,
             flask_app=flask_app,
             context_vars=context_vars,
             min_workers=self._min_workers,
@@ -196,10 +201,6 @@ class GraphEngine:
             event_emitter=self._event_manager,
         )
 
-        # === Extensibility ===
-        # Layers allow plugins to extend engine functionality
-        self._layers: list[GraphEngineLayer] = []
-
         # === Validation ===
         # Ensure all nodes share the same GraphRuntimeState instance
         self._validate_graph_state_consistency()
@@ -8,9 +8,11 @@ with middleware-like components that can observe events and interact with execut
 from .base import GraphEngineLayer
 from .debug_logging import DebugLoggingLayer
 from .execution_limits import ExecutionLimitsLayer
+from .observability import ObservabilityLayer
 
 __all__ = [
     "DebugLoggingLayer",
     "ExecutionLimitsLayer",
     "GraphEngineLayer",
+    "ObservabilityLayer",
 ]
@@ -9,6 +9,7 @@ from abc import ABC, abstractmethod
 
 from core.workflow.graph_engine.protocols.command_channel import CommandChannel
 from core.workflow.graph_events import GraphEngineEvent
+from core.workflow.nodes.base.node import Node
 from core.workflow.runtime import ReadOnlyGraphRuntimeState
 
 
@@ -83,3 +84,29 @@ class GraphEngineLayer(ABC):
             error: The exception that caused execution to fail, or None if successful
         """
         pass
+
+    def on_node_run_start(self, node: Node) -> None:  # noqa: B027
+        """
+        Called immediately before a node begins execution.
+
+        Layers can override to inject behavior (e.g., start spans) prior to node execution.
+        The node's execution ID is available via `node._node_execution_id` and will be
+        consistent with all events emitted by this node execution.
+
+        Args:
+            node: The node instance about to be executed
+        """
+        pass
+
+    def on_node_run_end(self, node: Node, error: Exception | None) -> None:  # noqa: B027
+        """
+        Called after a node finishes execution.
+
+        The node's execution ID is available via `node._node_execution_id` and matches
+        the `id` field in all events emitted by this node execution.
+
+        Args:
+            node: The node instance that just finished execution
+            error: Exception instance if the node failed, otherwise None
+        """
+        pass
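For reference, a minimal sketch of a custom layer built on the two new hooks. This is illustrative only: NodeTimingLayer is a hypothetical class, not part of this change, and it simply fills in the other GraphEngineLayer callbacks whose signatures appear elsewhere in this diff.

import logging
import time

from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent
from core.workflow.nodes.base.node import Node

logger = logging.getLogger(__name__)


class NodeTimingLayer(GraphEngineLayer):
    """Logs wall-clock duration for each node execution."""

    def __init__(self) -> None:
        super().__init__()
        self._started_at: dict[str, float] = {}

    def on_graph_start(self) -> None:
        self._started_at.clear()

    def on_node_run_start(self, node: Node) -> None:
        # The worker assigns the execution ID before invoking this hook.
        self._started_at[node.execution_id] = time.perf_counter()

    def on_node_run_end(self, node: Node, error: Exception | None) -> None:
        started = self._started_at.pop(node.execution_id, None)
        if started is not None:
            elapsed = time.perf_counter() - started
            logger.info("node %s finished in %.3fs (error=%r)", node.id, elapsed, error)

    def on_event(self, event: GraphEngineEvent) -> None:
        pass

    def on_graph_end(self, error: Exception | None) -> None:
        pass

Because the worker wraps each hook call in its own try/except (see the worker changes below), an exception raised inside such a layer is swallowed and never interrupts node execution.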
api/core/workflow/graph_engine/layers/node_parsers.py (new file, 61 lines)
@@ -0,0 +1,61 @@
+"""
+Node-level OpenTelemetry parser interfaces and defaults.
+"""
+
+import json
+from typing import Protocol
+
+from opentelemetry.trace import Span
+from opentelemetry.trace.status import Status, StatusCode
+
+from core.workflow.nodes.base.node import Node
+from core.workflow.nodes.tool.entities import ToolNodeData
+
+
+class NodeOTelParser(Protocol):
+    """Parser interface for node-specific OpenTelemetry enrichment."""
+
+    def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: ...
+
+
+class DefaultNodeOTelParser:
+    """Fallback parser used when no node-specific parser is registered."""
+
+    def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
+        span.set_attribute("node.id", node.id)
+        if node.execution_id:
+            span.set_attribute("node.execution_id", node.execution_id)
+        if hasattr(node, "node_type") and node.node_type:
+            span.set_attribute("node.type", node.node_type.value)
+
+        if error:
+            span.record_exception(error)
+            span.set_status(Status(StatusCode.ERROR, str(error)))
+        else:
+            span.set_status(Status(StatusCode.OK))
+
+
+class ToolNodeOTelParser:
+    """Parser for tool nodes that captures tool-specific metadata."""
+
+    def __init__(self) -> None:
+        self._delegate = DefaultNodeOTelParser()
+
+    def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
+        self._delegate.parse(node=node, span=span, error=error)
+
+        tool_data = getattr(node, "_node_data", None)
+        if not isinstance(tool_data, ToolNodeData):
+            return
+
+        span.set_attribute("tool.provider.id", tool_data.provider_id)
+        span.set_attribute("tool.provider.type", tool_data.provider_type.value)
+        span.set_attribute("tool.provider.name", tool_data.provider_name)
+        span.set_attribute("tool.name", tool_data.tool_name)
+        span.set_attribute("tool.label", tool_data.tool_label)
+        if tool_data.plugin_unique_identifier:
+            span.set_attribute("tool.plugin.id", tool_data.plugin_unique_identifier)
+        if tool_data.credential_id:
+            span.set_attribute("tool.credential.id", tool_data.credential_id)
+        if tool_data.tool_configurations:
+            span.set_attribute("tool.config", json.dumps(tool_data.tool_configurations, ensure_ascii=False))
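The Protocol keeps node-specific enrichment pluggable. A sketch of a parser for another node type follows; HttpRequestNodeOTelParser is hypothetical, and the url attribute it reads is an assumption about that node's data model, not something this PR defines.

from opentelemetry.trace import Span

from core.workflow.graph_engine.layers.node_parsers import DefaultNodeOTelParser
from core.workflow.nodes.base.node import Node


class HttpRequestNodeOTelParser:
    """Example parser that adds HTTP-specific attributes on top of the defaults."""

    def __init__(self) -> None:
        self._delegate = DefaultNodeOTelParser()

    def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
        # Reuse the default attributes and status handling first.
        self._delegate.parse(node=node, span=span, error=error)
        node_data = getattr(node, "_node_data", None)
        url = getattr(node_data, "url", None)  # assumed attribute, for illustration only
        if url:
            span.set_attribute("http.url", str(url))

Wiring such a parser in would mean adding an entry for its node type to the registry built in ObservabilityLayer._build_parser_registry (see the next file).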
api/core/workflow/graph_engine/layers/observability.py (new file, 169 lines)
@@ -0,0 +1,169 @@
+"""
+Observability layer for GraphEngine.
+
+This layer creates OpenTelemetry spans for node execution, enabling distributed
+tracing of workflow execution. It establishes OTel context during node execution
+so that automatic instrumentation (HTTP requests, DB queries, etc.) automatically
+associates with the node span.
+"""
+
+import logging
+from dataclasses import dataclass
+from typing import cast, final
+
+from opentelemetry import context as context_api
+from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
+from typing_extensions import override
+
+from configs import dify_config
+from core.workflow.enums import NodeType
+from core.workflow.graph_engine.layers.base import GraphEngineLayer
+from core.workflow.graph_engine.layers.node_parsers import (
+    DefaultNodeOTelParser,
+    NodeOTelParser,
+    ToolNodeOTelParser,
+)
+from core.workflow.nodes.base.node import Node
+from extensions.otel.runtime import is_instrument_flag_enabled
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(slots=True)
+class _NodeSpanContext:
+    span: "Span"
+    token: object
+
+
+@final
+class ObservabilityLayer(GraphEngineLayer):
+    """
+    Layer that creates OpenTelemetry spans for node execution.
+
+    This layer:
+    - Creates a span when a node starts execution
+    - Establishes OTel context so automatic instrumentation associates with the span
+    - Sets complete attributes and status when node execution ends
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._node_contexts: dict[str, _NodeSpanContext] = {}
+        self._parsers: dict[NodeType, NodeOTelParser] = {}
+        self._default_parser: NodeOTelParser = cast(NodeOTelParser, DefaultNodeOTelParser())
+        self._is_disabled: bool = False
+        self._tracer: Tracer | None = None
+        self._build_parser_registry()
+        self._init_tracer()
+
+    def _init_tracer(self) -> None:
+        """Initialize OpenTelemetry tracer in constructor."""
+        if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()):
+            self._is_disabled = True
+            return
+
+        try:
+            self._tracer = get_tracer(__name__)
+        except Exception as e:
+            logger.warning("Failed to get OpenTelemetry tracer: %s", e)
+            self._is_disabled = True
+
+    def _build_parser_registry(self) -> None:
+        """Initialize parser registry for node types."""
+        self._parsers = {
+            NodeType.TOOL: ToolNodeOTelParser(),
+        }
+
+    def _get_parser(self, node: Node) -> NodeOTelParser:
+        node_type = getattr(node, "node_type", None)
+        if isinstance(node_type, NodeType):
+            return self._parsers.get(node_type, self._default_parser)
+        return self._default_parser
+
+    @override
+    def on_graph_start(self) -> None:
+        """Called when graph execution starts."""
+        self._node_contexts.clear()
+
+    @override
+    def on_node_run_start(self, node: Node) -> None:
+        """
+        Called when a node starts execution.
+
+        Creates a span and establishes OTel context for automatic instrumentation.
+        """
+        if self._is_disabled:
+            return
+
+        try:
+            if not self._tracer:
+                return
+
+            execution_id = node.execution_id
+            if not execution_id:
+                return
+
+            parent_context = context_api.get_current()
+            span = self._tracer.start_span(
+                f"{node.title}",
+                kind=SpanKind.INTERNAL,
+                context=parent_context,
+            )
+
+            new_context = set_span_in_context(span)
+            token = context_api.attach(new_context)
+
+            self._node_contexts[execution_id] = _NodeSpanContext(span=span, token=token)
+
+        except Exception as e:
+            logger.warning("Failed to create OpenTelemetry span for node %s: %s", node.id, e)
+
+    @override
+    def on_node_run_end(self, node: Node, error: Exception | None) -> None:
+        """
+        Called when a node finishes execution.
+
+        Sets complete attributes, records exceptions, and ends the span.
+        """
+        if self._is_disabled:
+            return
+
+        try:
+            execution_id = node.execution_id
+            if not execution_id:
+                return
+            node_context = self._node_contexts.get(execution_id)
+            if not node_context:
+                return
+
+            span = node_context.span
+            parser = self._get_parser(node)
+            try:
+                parser.parse(node=node, span=span, error=error)
+                span.end()
+            finally:
+                token = node_context.token
+                if token is not None:
+                    try:
+                        context_api.detach(token)
+                    except Exception:
+                        logger.warning("Failed to detach OpenTelemetry token: %s", token)
+                self._node_contexts.pop(execution_id, None)
+
+        except Exception as e:
+            logger.warning("Failed to end OpenTelemetry span for node %s: %s", node.id, e)
+
+    @override
+    def on_event(self, event) -> None:
+        """Not used in this layer."""
+        pass
+
+    @override
+    def on_graph_end(self, error: Exception | None) -> None:
+        """Called when graph execution ends."""
+        if self._node_contexts:
+            logger.warning(
+                "ObservabilityLayer: %d node spans were not properly ended",
+                len(self._node_contexts),
+            )
+            self._node_contexts.clear()
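For context, enabling the layer on an engine only requires the existing layer() registration call. A sketch that mirrors the WorkflowEntry change later in this diff; engine stands for an already-constructed GraphEngine instance.

from configs import dify_config
from core.workflow.graph_engine.layers import ObservabilityLayer
from extensions.otel.runtime import is_instrument_flag_enabled

# Register the layer only when OTel is active, matching the WorkflowEntry guard.
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():
    engine.layer(ObservabilityLayer())

The layer also checks the same two flags in its own constructor, so registering it unconditionally would merely produce a disabled, no-op layer.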
@@ -9,6 +9,7 @@ import contextvars
 import queue
 import threading
 import time
+from collections.abc import Sequence
 from datetime import datetime
 from typing import final
 from uuid import uuid4
@@ -17,6 +18,7 @@ from flask import Flask
 from typing_extensions import override
 
 from core.workflow.graph import Graph
+from core.workflow.graph_engine.layers.base import GraphEngineLayer
 from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
 from core.workflow.nodes.base.node import Node
 from libs.flask_utils import preserve_flask_contexts
@@ -39,6 +41,7 @@ class Worker(threading.Thread):
         ready_queue: ReadyQueue,
         event_queue: queue.Queue[GraphNodeEventBase],
         graph: Graph,
+        layers: Sequence[GraphEngineLayer],
         worker_id: int = 0,
         flask_app: Flask | None = None,
         context_vars: contextvars.Context | None = None,
@@ -50,6 +53,7 @@ class Worker(threading.Thread):
             ready_queue: Ready queue containing node IDs ready for execution
             event_queue: Queue for pushing execution events
             graph: Graph containing nodes to execute
+            layers: Graph engine layers for node execution hooks
             worker_id: Unique identifier for this worker
             flask_app: Optional Flask application for context preservation
             context_vars: Optional context variables to preserve in worker thread
@@ -63,6 +67,7 @@ class Worker(threading.Thread):
         self._context_vars = context_vars
         self._stop_event = threading.Event()
         self._last_task_time = time.time()
+        self._layers = layers if layers is not None else []
 
     def stop(self) -> None:
         """Signal the worker to stop processing."""
@@ -122,20 +127,51 @@ class Worker(threading.Thread):
         Args:
             node: The node instance to execute
         """
-        # Execute the node with preserved context if Flask app is provided
+        node.ensure_execution_id()
+
+        error: Exception | None = None
+
         if self._flask_app and self._context_vars:
             with preserve_flask_contexts(
                 flask_app=self._flask_app,
                 context_vars=self._context_vars,
             ):
-                # Execute the node
+                self._invoke_node_run_start_hooks(node)
+                try:
+                    node_events = node.run()
+                    for event in node_events:
+                        self._event_queue.put(event)
+                except Exception as exc:
+                    error = exc
+                    raise
+                finally:
+                    self._invoke_node_run_end_hooks(node, error)
+        else:
+            self._invoke_node_run_start_hooks(node)
+            try:
                 node_events = node.run()
                 for event in node_events:
-                    # Forward event to dispatcher immediately for streaming
                     self._event_queue.put(event)
-        else:
-            # Execute without context preservation
-            node_events = node.run()
-            for event in node_events:
-                # Forward event to dispatcher immediately for streaming
-                self._event_queue.put(event)
+            except Exception as exc:
+                error = exc
+                raise
+            finally:
+                self._invoke_node_run_end_hooks(node, error)
+
+    def _invoke_node_run_start_hooks(self, node: Node) -> None:
+        """Invoke on_node_run_start hooks for all layers."""
+        for layer in self._layers:
+            try:
+                layer.on_node_run_start(node)
+            except Exception:
+                # Silently ignore layer errors to prevent disrupting node execution
+                continue
+
+    def _invoke_node_run_end_hooks(self, node: Node, error: Exception | None) -> None:
+        """Invoke on_node_run_end hooks for all layers."""
+        for layer in self._layers:
+            try:
+                layer.on_node_run_end(node, error)
+            except Exception:
+                # Silently ignore layer errors to prevent disrupting node execution
+                continue
@@ -14,6 +14,7 @@ from configs import dify_config
 from core.workflow.graph import Graph
 from core.workflow.graph_events import GraphNodeEventBase
 
+from ..layers.base import GraphEngineLayer
 from ..ready_queue import ReadyQueue
 from ..worker import Worker
 
@@ -39,6 +40,7 @@ class WorkerPool:
         ready_queue: ReadyQueue,
         event_queue: queue.Queue[GraphNodeEventBase],
         graph: Graph,
+        layers: list[GraphEngineLayer],
         flask_app: "Flask | None" = None,
         context_vars: "Context | None" = None,
         min_workers: int | None = None,
@@ -53,6 +55,7 @@ class WorkerPool:
             ready_queue: Ready queue for nodes ready for execution
             event_queue: Queue for worker events
             graph: The workflow graph
+            layers: Graph engine layers for node execution hooks
             flask_app: Optional Flask app for context preservation
             context_vars: Optional context variables
             min_workers: Minimum number of workers
@@ -65,6 +68,7 @@ class WorkerPool:
         self._graph = graph
         self._flask_app = flask_app
         self._context_vars = context_vars
+        self._layers = layers
 
         # Scaling parameters with defaults
         self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
@@ -144,6 +148,7 @@ class WorkerPool:
             ready_queue=self._ready_queue,
             event_queue=self._event_queue,
             graph=self._graph,
+            layers=self._layers,
             worker_id=worker_id,
             flask_app=self._flask_app,
             context_vars=self._context_vars,
@@ -244,6 +244,15 @@ class Node(Generic[NodeDataT]):
     def graph_init_params(self) -> "GraphInitParams":
         return self._graph_init_params
 
+    @property
+    def execution_id(self) -> str:
+        return self._node_execution_id
+
+    def ensure_execution_id(self) -> str:
+        if not self._node_execution_id:
+            self._node_execution_id = str(uuid4())
+        return self._node_execution_id
+
     def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
         return cast(NodeDataT, self._node_data_type.model_validate(data))
 
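The intended contract, shown as a small usage sketch (node stands for any already-constructed Node instance): ensure_execution_id() assigns an ID on the first call and is a no-op afterwards, so the worker hooks and every event emitted by run() observe the same value.

# Idempotence of the new execution-ID API (illustrative usage only).
first = node.ensure_execution_id()
second = node.ensure_execution_id()
assert first == second == node.execution_id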
@@ -256,14 +265,12 @@ class Node(Generic[NodeDataT]):
         raise NotImplementedError
 
     def run(self) -> Generator[GraphNodeEventBase, None, None]:
-        # Generate a single node execution ID to use for all events
-        if not self._node_execution_id:
-            self._node_execution_id = str(uuid4())
+        execution_id = self.ensure_execution_id()
         self._start_at = naive_utc_now()
 
         # Create and push start event with required fields
         start_event = NodeRunStartedEvent(
-            id=self._node_execution_id,
+            id=execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             node_title=self.title,
@@ -321,7 +328,7 @@ class Node(Generic[NodeDataT]):
                 if isinstance(event, NodeEventBase):  # pyright: ignore[reportUnnecessaryIsInstance]
                     yield self._dispatch(event)
                 elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id:  # pyright: ignore[reportUnnecessaryIsInstance]
-                    event.id = self._node_execution_id
+                    event.id = self.execution_id
                     yield event
                 else:
                     yield event
@@ -333,7 +340,7 @@ class Node(Generic[NodeDataT]):
                 error_type="WorkflowNodeError",
             )
             yield NodeRunFailedEvent(
-                id=self._node_execution_id,
+                id=self.execution_id,
                 node_id=self._node_id,
                 node_type=self.node_type,
                 start_at=self._start_at,
@@ -512,7 +519,7 @@ class Node(Generic[NodeDataT]):
         match result.status:
             case WorkflowNodeExecutionStatus.FAILED:
                 return NodeRunFailedEvent(
-                    id=self._node_execution_id,
+                    id=self.execution_id,
                     node_id=self.id,
                     node_type=self.node_type,
                     start_at=self._start_at,
@@ -521,7 +528,7 @@ class Node(Generic[NodeDataT]):
                 )
             case WorkflowNodeExecutionStatus.SUCCEEDED:
                 return NodeRunSucceededEvent(
-                    id=self._node_execution_id,
+                    id=self.execution_id,
                     node_id=self.id,
                     node_type=self.node_type,
                     start_at=self._start_at,
@@ -537,7 +544,7 @@ class Node(Generic[NodeDataT]):
     @_dispatch.register
     def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
         return NodeRunStreamChunkEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             selector=event.selector,
@@ -550,7 +557,7 @@ class Node(Generic[NodeDataT]):
         match event.node_run_result.status:
             case WorkflowNodeExecutionStatus.SUCCEEDED:
                 return NodeRunSucceededEvent(
-                    id=self._node_execution_id,
+                    id=self.execution_id,
                     node_id=self._node_id,
                     node_type=self.node_type,
                     start_at=self._start_at,
@@ -558,7 +565,7 @@ class Node(Generic[NodeDataT]):
                 )
             case WorkflowNodeExecutionStatus.FAILED:
                 return NodeRunFailedEvent(
-                    id=self._node_execution_id,
+                    id=self.execution_id,
                    node_id=self._node_id,
                     node_type=self.node_type,
                     start_at=self._start_at,
@@ -573,7 +580,7 @@ class Node(Generic[NodeDataT]):
     @_dispatch.register
     def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent:
         return NodeRunPauseRequestedEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
@@ -583,7 +590,7 @@ class Node(Generic[NodeDataT]):
     @_dispatch.register
     def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
         return NodeRunAgentLogEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             message_id=event.message_id,
@@ -599,7 +606,7 @@ class Node(Generic[NodeDataT]):
     @_dispatch.register
     def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
         return NodeRunLoopStartedEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             node_title=self.node_data.title,
@@ -612,7 +619,7 @@ class Node(Generic[NodeDataT]):
     @_dispatch.register
     def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
         return NodeRunLoopNextEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             node_title=self.node_data.title,
@@ -623,7 +630,7 @@ class Node(Generic[NodeDataT]):
     @_dispatch.register
     def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
         return NodeRunLoopSucceededEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             node_title=self.node_data.title,
@@ -637,7 +644,7 @@ class Node(Generic[NodeDataT]):
     @_dispatch.register
     def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
         return NodeRunLoopFailedEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             node_title=self.node_data.title,
@@ -652,7 +659,7 @@ class Node(Generic[NodeDataT]):
     @_dispatch.register
     def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
         return NodeRunIterationStartedEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             node_title=self.node_data.title,
@@ -665,7 +672,7 @@ class Node(Generic[NodeDataT]):
     @_dispatch.register
     def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
         return NodeRunIterationNextEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             node_title=self.node_data.title,
@@ -676,7 +683,7 @@ class Node(Generic[NodeDataT]):
     @_dispatch.register
     def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
         return NodeRunIterationSucceededEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             node_title=self.node_data.title,
@@ -690,7 +697,7 @@ class Node(Generic[NodeDataT]):
     @_dispatch.register
     def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
         return NodeRunIterationFailedEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             node_title=self.node_data.title,
@@ -705,7 +712,7 @@ class Node(Generic[NodeDataT]):
     @_dispatch.register
     def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
         return NodeRunRetrieverResourceEvent(
-            id=self._node_execution_id,
+            id=self.execution_id,
             node_id=self._node_id,
             node_type=self.node_type,
             retriever_resources=event.retriever_resources,
@@ -14,7 +14,7 @@ from core.workflow.errors import WorkflowNodeRunFailedError
 from core.workflow.graph import Graph
 from core.workflow.graph_engine import GraphEngine
 from core.workflow.graph_engine.command_channels import InMemoryChannel
-from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
+from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer, ObservabilityLayer
 from core.workflow.graph_engine.protocols.command_channel import CommandChannel
 from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
 from core.workflow.nodes import NodeType
@@ -23,6 +23,7 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
 from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
+from extensions.otel.runtime import is_instrument_flag_enabled
 from factories import file_factory
 from models.enums import UserFrom
 from models.workflow import Workflow
@@ -98,6 +99,10 @@ class WorkflowEntry:
         )
         self.graph_engine.layer(limits_layer)
 
+        # Add observability layer when OTel is enabled
+        if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():
+            self.graph_engine.layer(ObservabilityLayer())
+
     def run(self) -> Generator[GraphEngineEvent, None, None]:
         graph_engine = self.graph_engine
 
@@ -1,5 +1,4 @@
 import functools
-import os
 from collections.abc import Callable
 from typing import Any, TypeVar, cast
 
@@ -7,22 +6,13 @@ from opentelemetry.trace import get_tracer
 
 from configs import dify_config
 from extensions.otel.decorators.handler import SpanHandler
+from extensions.otel.runtime import is_instrument_flag_enabled
 
 T = TypeVar("T", bound=Callable[..., Any])
 
 _HANDLER_INSTANCES: dict[type[SpanHandler], SpanHandler] = {SpanHandler: SpanHandler()}
 
 
-def _is_instrument_flag_enabled() -> bool:
-    """
-    Check if external instrumentation is enabled via environment variable.
-
-    Third-party non-invasive instrumentation agents set this flag to coordinate
-    with Dify's manual OpenTelemetry instrumentation.
-    """
-    return os.getenv("ENABLE_OTEL_FOR_INSTRUMENT", "").strip().lower() == "true"
-
-
 def _get_handler_instance(handler_class: type[SpanHandler]) -> SpanHandler:
     """Get or create a singleton instance of the handler class."""
     if handler_class not in _HANDLER_INSTANCES:
@@ -43,7 +33,7 @@ def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[T],
     def decorator(func: T) -> T:
         @functools.wraps(func)
         def wrapper(*args: Any, **kwargs: Any) -> Any:
-            if not (dify_config.ENABLE_OTEL or _is_instrument_flag_enabled()):
+            if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()):
                 return func(*args, **kwargs)
 
             handler = _get_handler_instance(handler_class or SpanHandler)
@@ -1,4 +1,5 @@
 import logging
+import os
 import sys
 from typing import Union
 
@@ -71,3 +72,13 @@ def init_celery_worker(*args, **kwargs):
     if dify_config.DEBUG:
         logger.info("Initializing OpenTelemetry for Celery worker")
     CeleryInstrumentor(tracer_provider=tracer_provider, meter_provider=metric_provider).instrument()
+
+
+def is_instrument_flag_enabled() -> bool:
+    """
+    Check if external instrumentation is enabled via environment variable.
+
+    Third-party non-invasive instrumentation agents set this flag to coordinate
+    with Dify's manual OpenTelemetry instrumentation.
+    """
+    return os.getenv("ENABLE_OTEL_FOR_INSTRUMENT", "").strip().lower() == "true"
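A small sketch of how the flag is consumed. This is illustrative only; in practice the variable is exported by an external instrumentation agent before the process starts, rather than set from Python.

import os

os.environ["ENABLE_OTEL_FOR_INSTRUMENT"] = "true"

from extensions.otel.runtime import is_instrument_flag_enabled

# The helper reads the environment at call time.
assert is_instrument_flag_enabled() is True

With the flag set, trace_span and ObservabilityLayer activate even when dify_config.ENABLE_OTEL is false, since both guard on the same disjunction.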
@@ -0,0 +1,101 @@
+"""
+Shared fixtures for ObservabilityLayer tests.
+"""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+from opentelemetry.trace import set_tracer_provider
+
+from core.workflow.enums import NodeType
+
+
+@pytest.fixture
+def memory_span_exporter():
+    """Provide an in-memory span exporter for testing."""
+    return InMemorySpanExporter()
+
+
+@pytest.fixture
+def tracer_provider_with_memory_exporter(memory_span_exporter):
+    """Provide a TracerProvider configured with memory exporter."""
+    import opentelemetry.trace as trace_api
+
+    trace_api._TRACER_PROVIDER = None
+    trace_api._TRACER_PROVIDER_SET_ONCE._done = False
+
+    provider = TracerProvider()
+    processor = SimpleSpanProcessor(memory_span_exporter)
+    provider.add_span_processor(processor)
+    set_tracer_provider(provider)
+
+    yield provider
+
+    provider.force_flush()
+
+
+@pytest.fixture
+def mock_start_node():
+    """Create a mock Start Node."""
+    node = MagicMock()
+    node.id = "test-start-node-id"
+    node.title = "Start Node"
+    node.execution_id = "test-start-execution-id"
+    node.node_type = NodeType.START
+    return node
+
+
+@pytest.fixture
+def mock_llm_node():
+    """Create a mock LLM Node."""
+    node = MagicMock()
+    node.id = "test-llm-node-id"
+    node.title = "LLM Node"
+    node.execution_id = "test-llm-execution-id"
+    node.node_type = NodeType.LLM
+    return node
+
+
+@pytest.fixture
+def mock_tool_node():
+    """Create a mock Tool Node with tool-specific attributes."""
+    from core.tools.entities.tool_entities import ToolProviderType
+    from core.workflow.nodes.tool.entities import ToolNodeData
+
+    node = MagicMock()
+    node.id = "test-tool-node-id"
+    node.title = "Test Tool Node"
+    node.execution_id = "test-tool-execution-id"
+    node.node_type = NodeType.TOOL
+
+    tool_data = ToolNodeData(
+        title="Test Tool Node",
+        desc=None,
+        provider_id="test-provider-id",
+        provider_type=ToolProviderType.BUILT_IN,
+        provider_name="test-provider",
+        tool_name="test-tool",
+        tool_label="Test Tool",
+        tool_configurations={},
+        tool_parameters={},
+    )
+    node._node_data = tool_data
+
+    return node
+
+
+@pytest.fixture
+def mock_is_instrument_flag_enabled_false():
+    """Mock is_instrument_flag_enabled to return False."""
+    with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=False):
+        yield
+
+
+@pytest.fixture
+def mock_is_instrument_flag_enabled_true():
+    """Mock is_instrument_flag_enabled to return True."""
+    with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=True):
+        yield
@@ -0,0 +1,219 @@
+"""
+Tests for ObservabilityLayer.
+
+Test coverage:
+- Initialization and enable/disable logic
+- Node span lifecycle (start, end, error handling)
+- Parser integration (default and tool-specific)
+- Graph lifecycle management
+- Disabled mode behavior
+"""
+
+from unittest.mock import patch
+
+import pytest
+from opentelemetry.trace import StatusCode
+
+from core.workflow.enums import NodeType
+from core.workflow.graph_engine.layers.observability import ObservabilityLayer
+
+
+class TestObservabilityLayerInitialization:
+    """Test ObservabilityLayer initialization logic."""
+
+    @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+    @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
+    def test_initialization_when_otel_enabled(self, tracer_provider_with_memory_exporter):
+        """Test that layer initializes correctly when OTel is enabled."""
+        layer = ObservabilityLayer()
+        assert not layer._is_disabled
+        assert layer._tracer is not None
+        assert NodeType.TOOL in layer._parsers
+        assert layer._default_parser is not None
+
+    @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False)
+    @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_true")
+    def test_initialization_when_instrument_flag_enabled(self, tracer_provider_with_memory_exporter):
+        """Test that layer enables when instrument flag is enabled."""
+        layer = ObservabilityLayer()
+        assert not layer._is_disabled
+        assert layer._tracer is not None
+        assert NodeType.TOOL in layer._parsers
+        assert layer._default_parser is not None
+
+
+class TestObservabilityLayerNodeSpanLifecycle:
+    """Test node span creation and lifecycle management."""
+
+    @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+    @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
+    def test_node_span_created_and_ended(
+        self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
+    ):
+        """Test that span is created on node start and ended on node end."""
+        layer = ObservabilityLayer()
+        layer.on_graph_start()
+
+        layer.on_node_run_start(mock_llm_node)
+        layer.on_node_run_end(mock_llm_node, None)
+
+        spans = memory_span_exporter.get_finished_spans()
+        assert len(spans) == 1
+        assert spans[0].name == mock_llm_node.title
+        assert spans[0].status.status_code == StatusCode.OK
+
+    @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+    @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
+    def test_node_error_recorded_in_span(
+        self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
+    ):
+        """Test that node execution errors are recorded in span."""
+        layer = ObservabilityLayer()
+        layer.on_graph_start()
+
+        error = ValueError("Test error")
+        layer.on_node_run_start(mock_llm_node)
+        layer.on_node_run_end(mock_llm_node, error)
+
+        spans = memory_span_exporter.get_finished_spans()
+        assert len(spans) == 1
+        assert spans[0].status.status_code == StatusCode.ERROR
+        assert len(spans[0].events) > 0
+        assert any("exception" in event.name.lower() for event in spans[0].events)
+
+    @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+    @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
+    def test_node_end_without_start_handled_gracefully(
+        self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
+    ):
+        """Test that ending a node without start doesn't crash."""
+        layer = ObservabilityLayer()
+        layer.on_graph_start()
+
+        layer.on_node_run_end(mock_llm_node, None)
+
+        spans = memory_span_exporter.get_finished_spans()
+        assert len(spans) == 0
+
+
+class TestObservabilityLayerParserIntegration:
+    """Test parser integration for different node types."""
+
+    @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+    @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
+    def test_default_parser_used_for_regular_node(
+        self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node
+    ):
+        """Test that default parser is used for non-tool nodes."""
+        layer = ObservabilityLayer()
+        layer.on_graph_start()
+
+        layer.on_node_run_start(mock_start_node)
+        layer.on_node_run_end(mock_start_node, None)
+
+        spans = memory_span_exporter.get_finished_spans()
+        assert len(spans) == 1
+        attrs = spans[0].attributes
+        assert attrs["node.id"] == mock_start_node.id
+        assert attrs["node.execution_id"] == mock_start_node.execution_id
+        assert attrs["node.type"] == mock_start_node.node_type.value
+
+    @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+    @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
+    def test_tool_parser_used_for_tool_node(
+        self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_tool_node
+    ):
+        """Test that tool parser is used for tool nodes."""
+        layer = ObservabilityLayer()
+        layer.on_graph_start()
+
+        layer.on_node_run_start(mock_tool_node)
+        layer.on_node_run_end(mock_tool_node, None)
+
+        spans = memory_span_exporter.get_finished_spans()
+        assert len(spans) == 1
+        attrs = spans[0].attributes
+        assert attrs["node.id"] == mock_tool_node.id
+        assert attrs["tool.provider.id"] == mock_tool_node._node_data.provider_id
+        assert attrs["tool.provider.type"] == mock_tool_node._node_data.provider_type.value
+        assert attrs["tool.name"] == mock_tool_node._node_data.tool_name
+
+
+class TestObservabilityLayerGraphLifecycle:
+    """Test graph lifecycle management."""
+
+    @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+    @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
+    def test_on_graph_start_clears_contexts(self, tracer_provider_with_memory_exporter, mock_llm_node):
+        """Test that on_graph_start clears node contexts."""
+        layer = ObservabilityLayer()
+        layer.on_graph_start()
+
+        layer.on_node_run_start(mock_llm_node)
+        assert len(layer._node_contexts) == 1
+
+        layer.on_graph_start()
+        assert len(layer._node_contexts) == 0
+
+    @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+    @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
+    def test_on_graph_end_with_no_unfinished_spans(
+        self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
+    ):
+        """Test that on_graph_end handles normal completion."""
+        layer = ObservabilityLayer()
+        layer.on_graph_start()
+
+        layer.on_node_run_start(mock_llm_node)
+        layer.on_node_run_end(mock_llm_node, None)
+        layer.on_graph_end(None)
+
+        spans = memory_span_exporter.get_finished_spans()
+        assert len(spans) == 1
+
+    @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
+    @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
+    def test_on_graph_end_with_unfinished_spans_logs_warning(
+        self, tracer_provider_with_memory_exporter, mock_llm_node, caplog
+    ):
+        """Test that on_graph_end logs warning for unfinished spans."""
+        layer = ObservabilityLayer()
+        layer.on_graph_start()
+
+        layer.on_node_run_start(mock_llm_node)
+        assert len(layer._node_contexts) == 1
+
+        layer.on_graph_end(None)
+
+        assert len(layer._node_contexts) == 0
+        assert "node spans were not properly ended" in caplog.text
+
+
+class TestObservabilityLayerDisabledMode:
+    """Test behavior when layer is disabled."""
+
+    @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False)
+    @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
+    def test_disabled_mode_skips_node_start(self, memory_span_exporter, mock_start_node):
+        """Test that disabled layer doesn't create spans on node start."""
+        layer = ObservabilityLayer()
+        assert layer._is_disabled
+
+        layer.on_graph_start()
+        layer.on_node_run_start(mock_start_node)
+        layer.on_node_run_end(mock_start_node, None)
+
+        spans = memory_span_exporter.get_finished_spans()
+        assert len(spans) == 0
+
+    @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False)
+    @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
+    def test_disabled_mode_skips_node_end(self, memory_span_exporter, mock_llm_node):
+        """Test that disabled layer doesn't process node end."""
+        layer = ObservabilityLayer()
+        assert layer._is_disabled
+
+        layer.on_node_run_end(mock_llm_node, None)
+
+        spans = memory_span_exporter.get_finished_spans()
+        assert len(spans) == 0