refactor(otel): replace Any with Tracer and [T] generics (#34883)

This commit is contained in:
corevibe555
2026-04-10 10:37:14 +03:00
committed by GitHub
parent bcd738d2e6
commit af55665ff2
7 changed files with 56 additions and 66 deletions

View File

@@ -37,12 +37,7 @@ def trace_span[**P, R](handler_class: type[SpanHandler] | None = None) -> Callab
handler = _get_handler_instance(handler_class or SpanHandler)
tracer = get_tracer(__name__)
return handler.wrapper(
tracer=tracer,
wrapped=func,
args=args,
kwargs=kwargs,
)
return handler.wrapper(tracer, func, *args, **kwargs)
return cast(Callable[P, R], wrapper)

View File

@@ -1,8 +1,8 @@
import inspect
from collections.abc import Callable, Mapping
from collections.abc import Callable
from typing import Any
from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
class SpanHandler:
@@ -16,9 +16,9 @@ class SpanHandler:
exceptions. Handlers can override the wrapper method to customize behavior.
"""
_signature_cache: dict[Callable[..., Any], inspect.Signature] = {}
_signature_cache: dict[Callable[..., object], inspect.Signature] = {}
def _build_span_name(self, wrapped: Callable[..., Any]) -> str:
def _build_span_name[**P, R](self, wrapped: Callable[P, R]) -> str:
"""
Build the span name from the wrapped function.
@@ -29,11 +29,11 @@ class SpanHandler:
"""
return f"{wrapped.__module__}.{wrapped.__qualname__}"
def _extract_arguments[T](
def _extract_arguments[**P, R](
self,
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
wrapped: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> dict[str, Any] | None:
"""
Extract function arguments using inspect.signature.
@@ -59,13 +59,13 @@ class SpanHandler:
except Exception:
return None
def wrapper[T](
def wrapper[**P, R](
self,
tracer: Any,
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
) -> T:
tracer: Tracer,
wrapped: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
"""
Fully control the wrapper behavior.

View File

@@ -1,8 +1,7 @@
import logging
from collections.abc import Callable, Mapping
from typing import Any
from collections.abc import Callable
from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
from opentelemetry.util.types import AttributeValue
from extensions.otel.decorators.handler import SpanHandler
@@ -15,15 +14,15 @@ logger = logging.getLogger(__name__)
class AppGenerateHandler(SpanHandler):
"""Span handler for ``AppGenerateService.generate``."""
def wrapper[T](
def wrapper[**P, R](
self,
tracer: Any,
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
) -> T:
tracer: Tracer,
wrapped: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
try:
arguments = self._extract_arguments(wrapped, args, kwargs)
arguments = self._extract_arguments(wrapped, *args, **kwargs)
if not arguments:
return wrapped(*args, **kwargs)

View File

@@ -1,8 +1,7 @@
import logging
from collections.abc import Callable, Mapping
from typing import Any
from collections.abc import Callable
from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
from opentelemetry.util.types import AttributeValue
from extensions.otel.decorators.handler import SpanHandler
@@ -14,15 +13,15 @@ logger = logging.getLogger(__name__)
class WorkflowAppRunnerHandler(SpanHandler):
"""Span handler for ``WorkflowAppRunner.run``."""
def wrapper(
def wrapper[**P, R](
self,
tracer: Any,
wrapped: Callable[..., Any],
args: tuple[Any, ...],
kwargs: Mapping[str, Any],
) -> Any:
tracer: Tracer,
wrapped: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
try:
arguments = self._extract_arguments(wrapped, args, kwargs)
arguments = self._extract_arguments(wrapped, *args, **kwargs)
if not arguments:
return wrapped(*args, **kwargs)

View File

@@ -39,7 +39,7 @@ class TestAppGenerateHandler:
"root_node_id": None,
}
arguments = handler._extract_arguments(AppGenerateService.generate, (), kwargs)
arguments = handler._extract_arguments(AppGenerateService.generate, **kwargs)
assert arguments is not None, "Failed to extract arguments from AppGenerateService.generate"
assert "app_model" in arguments, "Handler uses app_model but parameter is missing"
@@ -70,14 +70,11 @@ class TestAppGenerateHandler:
handler.wrapper(
tracer,
dummy_func,
(),
{
"app_model": mock_app_model,
"user": mock_account_user,
"args": {"workflow_id": test_workflow_id},
"invoke_from": InvokeFrom.DEBUGGER,
"streaming": False,
},
app_model=mock_app_model,
user=mock_account_user,
args={"workflow_id": test_workflow_id},
invoke_from=InvokeFrom.DEBUGGER,
streaming=False,
)
spans = memory_span_exporter.get_finished_spans()

View File

@@ -63,7 +63,7 @@ class TestWorkflowAppRunnerHandler:
def runner_run(self):
return "result"
handler.wrapper(tracer, runner_run, (mock_workflow_runner,), {})
handler.wrapper(tracer, runner_run, mock_workflow_runner)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1

View File

@@ -28,7 +28,7 @@ class TestSpanHandlerExtractArguments:
args = (1, 2, 3)
kwargs = {}
result = handler._extract_arguments(func, args, kwargs)
result = handler._extract_arguments(func, *args, **kwargs)
assert result is not None
assert result["a"] == 1
@@ -44,7 +44,7 @@ class TestSpanHandlerExtractArguments:
args = ()
kwargs = {"a": 1, "b": 2, "c": 3}
result = handler._extract_arguments(func, args, kwargs)
result = handler._extract_arguments(func, *args, **kwargs)
assert result is not None
assert result["a"] == 1
@@ -60,7 +60,7 @@ class TestSpanHandlerExtractArguments:
args = (1,)
kwargs = {"b": 2, "c": 3}
result = handler._extract_arguments(func, args, kwargs)
result = handler._extract_arguments(func, *args, **kwargs)
assert result is not None
assert result["a"] == 1
@@ -76,7 +76,7 @@ class TestSpanHandlerExtractArguments:
args = (1,)
kwargs = {}
result = handler._extract_arguments(func, args, kwargs)
result = handler._extract_arguments(func, *args, **kwargs)
assert result is not None
assert result["a"] == 1
@@ -94,7 +94,7 @@ class TestSpanHandlerExtractArguments:
instance = MyClass()
args = (1, 2)
kwargs = {}
result = handler._extract_arguments(instance.method, args, kwargs)
result = handler._extract_arguments(instance.method, *args, **kwargs)
assert result is not None
assert result["a"] == 1
@@ -109,7 +109,7 @@ class TestSpanHandlerExtractArguments:
args = (1,)
kwargs = {}
result = handler._extract_arguments(func, args, kwargs)
result = handler._extract_arguments(func, *args, **kwargs)
assert result is None
@@ -122,11 +122,11 @@ class TestSpanHandlerExtractArguments:
assert func not in handler._signature_cache
handler._extract_arguments(func, (1, 2), {})
handler._extract_arguments(func, 1, 2)
assert func in handler._signature_cache
cached_sig = handler._signature_cache[func]
handler._extract_arguments(func, (3, 4), {})
handler._extract_arguments(func, 3, 4)
assert handler._signature_cache[func] is cached_sig
@@ -142,7 +142,7 @@ class TestSpanHandlerWrapper:
def test_func():
return "result"
result = handler.wrapper(tracer, test_func, (), {})
result = handler.wrapper(tracer, test_func)
assert result == "result"
spans = memory_span_exporter.get_finished_spans()
@@ -159,7 +159,7 @@ class TestSpanHandlerWrapper:
def test_func():
return "result"
handler.wrapper(tracer, test_func, (), {})
handler.wrapper(tracer, test_func)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
@@ -174,7 +174,7 @@ class TestSpanHandlerWrapper:
def test_func():
return "result"
handler.wrapper(tracer, test_func, (), {})
handler.wrapper(tracer, test_func)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
@@ -190,7 +190,7 @@ class TestSpanHandlerWrapper:
raise ValueError("test error")
with pytest.raises(ValueError, match="test error"):
handler.wrapper(tracer, test_func, (), {})
handler.wrapper(tracer, test_func)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
@@ -208,7 +208,7 @@ class TestSpanHandlerWrapper:
raise ValueError("test error")
with pytest.raises(ValueError):
handler.wrapper(tracer, test_func, (), {})
handler.wrapper(tracer, test_func)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
@@ -225,7 +225,7 @@ class TestSpanHandlerWrapper:
raise ValueError("test error")
with pytest.raises(ValueError, match="test error"):
handler.wrapper(tracer, test_func, (), {})
handler.wrapper(tracer, test_func)
@patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True)
def test_wrapper_passes_arguments_correctly(self, tracer_provider_with_memory_exporter, memory_span_exporter):
@@ -236,7 +236,7 @@ class TestSpanHandlerWrapper:
def test_func(a, b, c=10):
return a + b + c
result = handler.wrapper(tracer, test_func, (1, 2), {"c": 3})
result = handler.wrapper(tracer, test_func, 1, 2, c=3)
assert result == 6
@@ -249,7 +249,7 @@ class TestSpanHandlerWrapper:
def my_function(x):
return x * 2
result = handler.wrapper(tracer, my_function, (5,), {})
result = handler.wrapper(tracer, my_function, 5)
assert result == 10
spans = memory_span_exporter.get_finished_spans()