mirror of
https://github.com/langgenius/dify.git
synced 2025-12-25 01:00:42 -05:00
fix: workflow token usage (#26723)
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
This commit is contained in:
@@ -7,6 +7,7 @@ from collections.abc import Mapping
|
||||
from functools import singledispatchmethod
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType
|
||||
from core.workflow.graph import Graph
|
||||
@@ -125,6 +126,7 @@ class EventHandler:
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
is_initial_attempt = node_execution.retry_count == 0
|
||||
node_execution.mark_started(event.id)
|
||||
self._graph_runtime_state.increment_node_run_steps()
|
||||
|
||||
# Track in response coordinator for stream ordering
|
||||
self._response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
@@ -163,6 +165,8 @@ class EventHandler:
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||
|
||||
# Store outputs in variable pool
|
||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||
|
||||
@@ -212,6 +216,8 @@ class EventHandler:
|
||||
node_execution.mark_failed(event.error)
|
||||
self._graph_execution.record_node_failure()
|
||||
|
||||
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||
|
||||
result = self._error_handler.handle_node_failure(event)
|
||||
|
||||
if result:
|
||||
@@ -235,6 +241,8 @@ class EventHandler:
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||
|
||||
# Persist outputs produced by the exception strategy (e.g. default values)
|
||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||
|
||||
@@ -286,6 +294,19 @@ class EventHandler:
|
||||
self._state_manager.enqueue_node(event.node_id)
|
||||
self._state_manager.start_execution(event.node_id)
|
||||
|
||||
def _accumulate_node_usage(self, usage: LLMUsage) -> None:
|
||||
"""Accumulate token usage into the shared runtime state."""
|
||||
if usage.total_tokens <= 0:
|
||||
return
|
||||
|
||||
self._graph_runtime_state.add_tokens(usage.total_tokens)
|
||||
|
||||
current_usage = self._graph_runtime_state.llm_usage
|
||||
if current_usage.total_tokens == 0:
|
||||
self._graph_runtime_state.llm_usage = usage
|
||||
else:
|
||||
self._graph_runtime_state.llm_usage = current_usage.plus(usage)
|
||||
|
||||
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
|
||||
"""
|
||||
Store node outputs in the variable pool.
|
||||
|
||||
Reference in New Issue
Block a user