# Files
# dify/api/tests/unit_tests/tasks/test_ops_trace_task.py
# Blackoutta 7fc40e6c9e feat: improve phoenix workflow tracing (#35605)
# Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
# Co-authored-by: -LAN- <laipz8200@outlook.com>
# 2026-05-11 08:37:17 +00:00
#
# 302 lines
# 12 KiB
# Python

import json
import sys
from contextlib import contextmanager
from types import ModuleType
from unittest.mock import MagicMock, patch
import pytest
from celery.exceptions import Retry
from core.ops.entities.config_entity import OPS_TRACE_FAILED_KEY
from core.ops.exceptions import RetryableTraceDispatchError
from tasks.ops_trace_task import process_trace_tasks
@contextmanager
def fake_app_context():
    """Stand-in application context that enters and exits without side effects.

    Yields:
        None: nothing is bound by ``with fake_app_context() as ctx``.
    """
    yield
class FakeCurrentApp:
    """Minimal replacement for ``current_app`` exposing only ``app_context()``."""

    def app_context(self):
        """Return a fresh context manager produced by ``fake_app_context``."""
        return fake_app_context()
def _install_trace_manager(
trace_instance: MagicMock,
*,
enterprise_enabled: bool = False,
enterprise_trace_cls: MagicMock | None = None,
) -> dict[str, ModuleType]:
ops_trace_manager_module = ModuleType("core.ops.ops_trace_manager")
class StubOpsTraceManager:
@staticmethod
def get_ops_trace_instance(app_id: str) -> MagicMock:
return trace_instance
telemetry_module = ModuleType("extensions.ext_enterprise_telemetry")
telemetry_module.is_enabled = lambda: enterprise_enabled
ops_trace_manager_module.OpsTraceManager = StubOpsTraceManager
modules = {
"core.ops.ops_trace_manager": ops_trace_manager_module,
"extensions.ext_enterprise_telemetry": telemetry_module,
}
if enterprise_trace_cls is not None:
enterprise_module = ModuleType("enterprise")
enterprise_telemetry_module = ModuleType("enterprise.telemetry")
enterprise_trace_module = ModuleType("enterprise.telemetry.enterprise_trace")
enterprise_trace_module.EnterpriseOtelTrace = enterprise_trace_cls
modules.update(
{
"enterprise": enterprise_module,
"enterprise.telemetry": enterprise_telemetry_module,
"enterprise.telemetry.enterprise_trace": enterprise_trace_module,
}
)
return modules
def _make_payload() -> str:
    """Return a JSON payload with an empty ``trace_info`` and a null ``trace_info_type``."""
    return json.dumps({"trace_info": {}, "trace_info_type": None})
def _decode_saved_payload(payload: bytes | str) -> dict[str, object]:
if isinstance(payload, bytes):
payload = payload.decode("utf-8")
return json.loads(payload)
def _retryable_dispatch_error() -> RetryableTraceDispatchError:
    """Create a fresh ``RetryableTraceDispatchError`` to use as a mock side effect."""
    return RetryableTraceDispatchError("transient trace dispatch failure")
def _run_task(file_info: dict[str, str], retries: int = 0) -> None:
    """Run ``process_trace_tasks`` in-process with a pushed request context.

    Args:
        file_info: Argument forwarded to ``process_trace_tasks.run``.
        retries: Retry count placed on the pushed request.

    Any exception raised by ``run`` propagates; the pushed request is always
    popped again.
    """
    process_trace_tasks.push_request(retries=retries)
    try:
        process_trace_tasks.run(file_info)
    finally:
        # Pop even when ``run`` raises so later tests start from a clean stack.
        process_trace_tasks.pop_request()
def test_process_trace_tasks_retries_retryable_dispatch_failure_and_preserves_payload():
    """A retryable dispatch error triggers ``retry`` and leaves the payload alone.

    Expects ``retry`` to be called once with the original error and the task's
    default delay, and neither ``storage.delete`` nor ``redis_client.incr`` to
    be called.
    """
    file_info = {"app_id": "app-id", "file_id": "file-id"}
    trace_instance = MagicMock()
    pending_error = _retryable_dispatch_error()
    trace_instance.trace.side_effect = pending_error
    retry_error = Retry()

    with (
        patch.dict(sys.modules, _install_trace_manager(trace_instance)),
        patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
        patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
        patch("tasks.ops_trace_task.storage.delete") as mock_delete,
        patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
        patch.object(process_trace_tasks, "retry", side_effect=retry_error) as mock_retry,
        pytest.raises(Retry),
    ):
        _run_task(file_info)

    # Assertions live outside the ``with`` block: code after the raising call
    # inside ``pytest.raises`` would never execute.
    mock_retry.assert_called_once_with(
        exc=pending_error,
        countdown=process_trace_tasks.default_retry_delay,
    )
    mock_delete.assert_not_called()
    mock_incr.assert_not_called()
def test_process_trace_tasks_marks_enterprise_trace_dispatched_before_retryable_dispatch_retry():
    """A successful enterprise trace is recorded in the payload before retrying.

    With enterprise telemetry enabled and the regular trace raising a retryable
    error, expects the enterprise tracer to be called once, the payload to be
    saved with ``_enterprise_trace_dispatched`` set to ``True``, ``retry`` to be
    called, and no delete or failure-counter increment.
    """
    file_info = {"app_id": "app-id", "file_id": "file-id"}
    trace_instance = MagicMock()
    pending_error = _retryable_dispatch_error()
    trace_instance.trace.side_effect = pending_error
    retry_error = Retry()
    enterprise_tracer = MagicMock()
    enterprise_trace_cls = MagicMock(return_value=enterprise_tracer)

    with (
        patch.dict(
            sys.modules,
            _install_trace_manager(
                trace_instance,
                enterprise_enabled=True,
                enterprise_trace_cls=enterprise_trace_cls,
            ),
        ),
        patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
        patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
        patch("tasks.ops_trace_task.storage.save") as mock_save,
        patch("tasks.ops_trace_task.storage.delete") as mock_delete,
        patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
        patch.object(process_trace_tasks, "retry", side_effect=retry_error) as mock_retry,
        pytest.raises(Retry),
    ):
        _run_task(file_info)

    enterprise_tracer.trace.assert_called_once_with({})

    # Fail with a clear message if nothing was saved; otherwise ``call_args``
    # is None and the unpacking below raises an opaque AttributeError.
    mock_save.assert_called()
    saved_path, saved_payload = mock_save.call_args.args
    assert saved_path == "ops_trace/app-id/file-id.json"
    assert _decode_saved_payload(saved_payload)["_enterprise_trace_dispatched"] is True

    mock_retry.assert_called_once_with(
        exc=pending_error,
        countdown=process_trace_tasks.default_retry_delay,
    )
    mock_delete.assert_not_called()
    mock_incr.assert_not_called()
def test_process_trace_tasks_does_not_mark_failed_enterprise_trace_as_dispatched_before_retry():
    """A failing enterprise trace must not be persisted as dispatched.

    The enterprise tracer raises ``RuntimeError`` and the regular trace raises
    a retryable error. Expects the enterprise tracer to have been attempted
    once, ``storage.save`` not to be called, ``retry`` to be called, and no
    delete or failure-counter increment.
    """
    file_info = {"app_id": "app-id", "file_id": "file-id"}
    trace_instance = MagicMock()
    pending_error = _retryable_dispatch_error()
    trace_instance.trace.side_effect = pending_error
    retry_error = Retry()
    enterprise_tracer = MagicMock()
    enterprise_tracer.trace.side_effect = RuntimeError("enterprise trace failed")
    enterprise_trace_cls = MagicMock(return_value=enterprise_tracer)

    with (
        patch.dict(
            sys.modules,
            _install_trace_manager(
                trace_instance,
                enterprise_enabled=True,
                enterprise_trace_cls=enterprise_trace_cls,
            ),
        ),
        patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
        patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
        patch("tasks.ops_trace_task.storage.save") as mock_save,
        patch("tasks.ops_trace_task.storage.delete") as mock_delete,
        patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
        patch.object(process_trace_tasks, "retry", side_effect=retry_error) as mock_retry,
        pytest.raises(Retry),
    ):
        _run_task(file_info)

    enterprise_tracer.trace.assert_called_once_with({})
    mock_save.assert_not_called()
    mock_retry.assert_called_once_with(
        exc=pending_error,
        countdown=process_trace_tasks.default_retry_delay,
    )
    mock_delete.assert_not_called()
    mock_incr.assert_not_called()
def test_process_trace_tasks_skips_enterprise_trace_when_retry_payload_was_already_dispatched():
    """A payload already flagged as dispatched skips the enterprise trace.

    The loaded payload carries ``_enterprise_trace_dispatched: true``. Expects
    the enterprise trace class never to be instantiated, the regular trace to
    run once, no save, the payload to be deleted, and no failure-counter
    increment.
    """
    file_info = {"app_id": "app-id", "file_id": "file-id"}
    trace_instance = MagicMock()
    enterprise_trace_cls = MagicMock()
    payload = json.dumps({"trace_info": {}, "trace_info_type": None, "_enterprise_trace_dispatched": True})

    with (
        patch.dict(
            sys.modules,
            _install_trace_manager(
                trace_instance,
                enterprise_enabled=True,
                enterprise_trace_cls=enterprise_trace_cls,
            ),
        ),
        patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
        patch("tasks.ops_trace_task.storage.load", return_value=payload),
        patch("tasks.ops_trace_task.storage.save") as mock_save,
        patch("tasks.ops_trace_task.storage.delete") as mock_delete,
        patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
    ):
        _run_task(file_info)

    enterprise_trace_cls.assert_not_called()
    trace_instance.trace.assert_called_once_with({})
    mock_save.assert_not_called()
    mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json")
    mock_incr.assert_not_called()
def test_process_trace_tasks_default_retry_window_covers_parent_span_context_ttl():
    """The total retry window (``max_retries * default_retry_delay``) is at least 300."""
    # NOTE(review): 300 is presumably the parent span context TTL in seconds,
    # going by the test name -- confirm against the tracing configuration.
    required_window = 300
    retry_window = process_trace_tasks.max_retries * process_trace_tasks.default_retry_delay
    assert retry_window >= required_window
def test_process_trace_tasks_deletes_payload_on_success():
    """A successful trace deletes the stored payload and counts no failure."""
    file_info = {"app_id": "app-id", "file_id": "file-id"}
    trace_instance = MagicMock()

    with (
        patch.dict(sys.modules, _install_trace_manager(trace_instance)),
        patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
        patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
        patch("tasks.ops_trace_task.storage.delete") as mock_delete,
        patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
    ):
        _run_task(file_info)

    trace_instance.trace.assert_called_once_with({})
    mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json")
    mock_incr.assert_not_called()
def test_process_trace_tasks_deletes_payload_and_counts_terminal_failure():
    """A non-retryable trace error deletes the payload and bumps the failure counter.

    The trace raises a plain ``RuntimeError``; the task is expected not to
    propagate it, to delete the payload, and to increment
    ``{OPS_TRACE_FAILED_KEY}_app-id`` once.
    """
    file_info = {"app_id": "app-id", "file_id": "file-id"}
    trace_instance = MagicMock()
    trace_instance.trace.side_effect = RuntimeError("trace failed")

    with (
        patch.dict(sys.modules, _install_trace_manager(trace_instance)),
        patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
        patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
        patch("tasks.ops_trace_task.storage.delete") as mock_delete,
        patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
    ):
        _run_task(file_info)

    mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json")
    mock_incr.assert_called_once_with(f"{OPS_TRACE_FAILED_KEY}_app-id")
def test_process_trace_tasks_treats_retry_enqueue_failure_as_terminal_failure():
    """If scheduling the retry itself fails, the failure is handled as terminal.

    The trace raises a retryable error and the patched ``retry`` raises
    ``RuntimeError``. Expects ``retry`` to have been attempted once, no
    exception to escape the task, the payload to be deleted, and the failure
    counter to be incremented.
    """
    file_info = {"app_id": "app-id", "file_id": "file-id"}
    trace_instance = MagicMock()
    pending_error = _retryable_dispatch_error()
    retry_enqueue_error = RuntimeError("retry enqueue failed")
    trace_instance.trace.side_effect = pending_error

    with (
        patch.dict(sys.modules, _install_trace_manager(trace_instance)),
        patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
        patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
        patch("tasks.ops_trace_task.storage.delete") as mock_delete,
        patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
        patch.object(process_trace_tasks, "retry", side_effect=retry_enqueue_error) as mock_retry,
    ):
        _run_task(file_info)

    mock_retry.assert_called_once_with(
        exc=pending_error,
        countdown=process_trace_tasks.default_retry_delay,
    )
    mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json")
    mock_incr.assert_called_once_with(f"{OPS_TRACE_FAILED_KEY}_app-id")
def test_process_trace_tasks_deletes_payload_and_counts_exhausted_retryable_dispatch_failure():
    """A retryable error on the last allowed attempt is handled as terminal.

    The task runs with ``retries`` already equal to ``max_retries``. Expects
    ``retry`` not to be called, the payload to be deleted, and the failure
    counter to be incremented.
    """
    file_info = {"app_id": "app-id", "file_id": "file-id"}
    trace_instance = MagicMock()
    pending_error = _retryable_dispatch_error()
    trace_instance.trace.side_effect = pending_error

    with (
        patch.dict(sys.modules, _install_trace_manager(trace_instance)),
        patch("tasks.ops_trace_task.current_app", FakeCurrentApp()),
        patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()),
        patch("tasks.ops_trace_task.storage.delete") as mock_delete,
        patch("tasks.ops_trace_task.redis_client.incr") as mock_incr,
        patch.object(process_trace_tasks, "retry") as mock_retry,
    ):
        _run_task(file_info, retries=process_trace_tasks.max_retries)

    mock_retry.assert_not_called()
    mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json")
    mock_incr.assert_called_once_with(f"{OPS_TRACE_FAILED_KEY}_app-id")