diff --git a/api/extensions/otel/decorators/base.py b/api/extensions/otel/decorators/base.py index 1dd92caeae..ad83826427 100644 --- a/api/extensions/otel/decorators/base.py +++ b/api/extensions/otel/decorators/base.py @@ -37,12 +37,7 @@ def trace_span[**P, R](handler_class: type[SpanHandler] | None = None) -> Callab handler = _get_handler_instance(handler_class or SpanHandler) tracer = get_tracer(__name__) - return handler.wrapper( - tracer=tracer, - wrapped=func, - args=args, - kwargs=kwargs, - ) + return handler.wrapper(tracer, func, *args, **kwargs) return cast(Callable[P, R], wrapper) diff --git a/api/extensions/otel/decorators/handler.py b/api/extensions/otel/decorators/handler.py index e465a615a6..b0d9fa7af6 100644 --- a/api/extensions/otel/decorators/handler.py +++ b/api/extensions/otel/decorators/handler.py @@ -1,8 +1,8 @@ import inspect -from collections.abc import Callable, Mapping +from collections.abc import Callable from typing import Any -from opentelemetry.trace import SpanKind, Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer class SpanHandler: @@ -16,9 +16,9 @@ class SpanHandler: exceptions. Handlers can override the wrapper method to customize behavior. """ - _signature_cache: dict[Callable[..., Any], inspect.Signature] = {} + _signature_cache: dict[Callable[..., object], inspect.Signature] = {} - def _build_span_name(self, wrapped: Callable[..., Any]) -> str: + def _build_span_name[**P, R](self, wrapped: Callable[P, R]) -> str: """ Build the span name from the wrapped function. @@ -29,11 +29,11 @@ class SpanHandler: """ return f"{wrapped.__module__}.{wrapped.__qualname__}" - def _extract_arguments[T]( + def _extract_arguments[**P, R]( self, - wrapped: Callable[..., T], - args: tuple[object, ...], - kwargs: Mapping[str, object], + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, ) -> dict[str, Any] | None: """ Extract function arguments using inspect.signature. @@ -59,13 +59,13 @@ class SpanHandler: except Exception: return None - def wrapper[T]( + def wrapper[**P, R]( self, - tracer: Any, - wrapped: Callable[..., T], - args: tuple[object, ...], - kwargs: Mapping[str, object], - ) -> T: + tracer: Tracer, + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: """ Fully control the wrapper behavior. diff --git a/api/extensions/otel/decorators/handlers/generate_handler.py b/api/extensions/otel/decorators/handlers/generate_handler.py index cc6c75304f..df5142c310 100644 --- a/api/extensions/otel/decorators/handlers/generate_handler.py +++ b/api/extensions/otel/decorators/handlers/generate_handler.py @@ -1,8 +1,7 @@ import logging -from collections.abc import Callable, Mapping -from typing import Any +from collections.abc import Callable -from opentelemetry.trace import SpanKind, Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer from opentelemetry.util.types import AttributeValue from extensions.otel.decorators.handler import SpanHandler @@ -15,15 +14,15 @@ logger = logging.getLogger(__name__) class AppGenerateHandler(SpanHandler): """Span handler for ``AppGenerateService.generate``.""" - def wrapper[T]( + def wrapper[**P, R]( self, - tracer: Any, - wrapped: Callable[..., T], - args: tuple[object, ...], - kwargs: Mapping[str, object], - ) -> T: + tracer: Tracer, + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: try: - arguments = self._extract_arguments(wrapped, args, kwargs) + arguments = self._extract_arguments(wrapped, *args, **kwargs) if not arguments: return wrapped(*args, **kwargs) diff --git a/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py b/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py index 8abd60197c..6b2112ceb2 100644 --- a/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py +++ b/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py @@ -1,8 +1,7 @@ import logging -from collections.abc import Callable, Mapping -from typing import Any +from collections.abc import Callable -from opentelemetry.trace import SpanKind, Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer from opentelemetry.util.types import AttributeValue from extensions.otel.decorators.handler import SpanHandler @@ -14,15 +13,15 @@ logger = logging.getLogger(__name__) class WorkflowAppRunnerHandler(SpanHandler): """Span handler for ``WorkflowAppRunner.run``.""" - def wrapper( + def wrapper[**P, R]( self, - tracer: Any, - wrapped: Callable[..., Any], - args: tuple[Any, ...], - kwargs: Mapping[str, Any], - ) -> Any: + tracer: Tracer, + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: try: - arguments = self._extract_arguments(wrapped, args, kwargs) + arguments = self._extract_arguments(wrapped, *args, **kwargs) if not arguments: return wrapped(*args, **kwargs) diff --git a/api/tests/unit_tests/extensions/otel/decorators/handlers/test_generate_handler.py b/api/tests/unit_tests/extensions/otel/decorators/handlers/test_generate_handler.py index f7475f2239..12e91f190f 100644 --- a/api/tests/unit_tests/extensions/otel/decorators/handlers/test_generate_handler.py +++ b/api/tests/unit_tests/extensions/otel/decorators/handlers/test_generate_handler.py @@ -39,7 +39,7 @@ class TestAppGenerateHandler: "root_node_id": None, } - arguments = handler._extract_arguments(AppGenerateService.generate, (), kwargs) + arguments = handler._extract_arguments(AppGenerateService.generate, **kwargs) assert arguments is not None, "Failed to extract arguments from AppGenerateService.generate" assert "app_model" in arguments, "Handler uses app_model but parameter is missing" @@ -70,14 +70,11 @@ class TestAppGenerateHandler: handler.wrapper( tracer, dummy_func, - (), - { - "app_model": mock_app_model, - "user": mock_account_user, - "args": {"workflow_id": test_workflow_id}, - "invoke_from": InvokeFrom.DEBUGGER, - "streaming": False, - }, + app_model=mock_app_model, + user=mock_account_user, + args={"workflow_id": test_workflow_id}, + invoke_from=InvokeFrom.DEBUGGER, + streaming=False, ) spans = memory_span_exporter.get_finished_spans() diff --git a/api/tests/unit_tests/extensions/otel/decorators/handlers/test_workflow_app_runner_handler.py b/api/tests/unit_tests/extensions/otel/decorators/handlers/test_workflow_app_runner_handler.py index 500f80fc3c..842e7f55e2 100644 --- a/api/tests/unit_tests/extensions/otel/decorators/handlers/test_workflow_app_runner_handler.py +++ b/api/tests/unit_tests/extensions/otel/decorators/handlers/test_workflow_app_runner_handler.py @@ -63,7 +63,7 @@ class TestWorkflowAppRunnerHandler: def runner_run(self): return "result" - handler.wrapper(tracer, runner_run, (mock_workflow_runner,), {}) + handler.wrapper(tracer, runner_run, mock_workflow_runner) spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 diff --git a/api/tests/unit_tests/extensions/otel/decorators/test_handler.py b/api/tests/unit_tests/extensions/otel/decorators/test_handler.py index 44788bab9a..bf861e3ef7 100644 --- a/api/tests/unit_tests/extensions/otel/decorators/test_handler.py +++ b/api/tests/unit_tests/extensions/otel/decorators/test_handler.py @@ -28,7 +28,7 @@ class TestSpanHandlerExtractArguments: args = (1, 2, 3) kwargs = {} - result = handler._extract_arguments(func, args, kwargs) + result = handler._extract_arguments(func, *args, **kwargs) assert result is not None assert result["a"] == 1 @@ -44,7 +44,7 @@ class TestSpanHandlerExtractArguments: args = () kwargs = {"a": 1, "b": 2, "c": 3} - result = handler._extract_arguments(func, args, kwargs) + result = handler._extract_arguments(func, *args, **kwargs) assert result is not None assert result["a"] == 1 @@ -60,7 +60,7 @@ class TestSpanHandlerExtractArguments: args = (1,) kwargs = {"b": 2, "c": 3} - result = handler._extract_arguments(func, args, kwargs) + result = handler._extract_arguments(func, *args, **kwargs) assert result is not None assert result["a"] == 1 @@ -76,7 +76,7 @@ class TestSpanHandlerExtractArguments: args = (1,) kwargs = {} - result = handler._extract_arguments(func, args, kwargs) + result = handler._extract_arguments(func, *args, **kwargs) assert result is not None assert result["a"] == 1 @@ -94,7 +94,7 @@ class TestSpanHandlerExtractArguments: instance = MyClass() args = (1, 2) kwargs = {} - result = handler._extract_arguments(instance.method, args, kwargs) + result = handler._extract_arguments(instance.method, *args, **kwargs) assert result is not None assert result["a"] == 1 @@ -109,7 +109,7 @@ class TestSpanHandlerExtractArguments: args = (1,) kwargs = {} - result = handler._extract_arguments(func, args, kwargs) + result = handler._extract_arguments(func, *args, **kwargs) assert result is None @@ -122,11 +122,11 @@ class TestSpanHandlerExtractArguments: assert func not in handler._signature_cache - handler._extract_arguments(func, (1, 2), {}) + handler._extract_arguments(func, 1, 2) assert func in handler._signature_cache cached_sig = handler._signature_cache[func] - handler._extract_arguments(func, (3, 4), {}) + handler._extract_arguments(func, 3, 4) assert handler._signature_cache[func] is cached_sig @@ -142,7 +142,7 @@ class TestSpanHandlerWrapper: def test_func(): return "result" - result = handler.wrapper(tracer, test_func, (), {}) + result = handler.wrapper(tracer, test_func) assert result == "result" spans = memory_span_exporter.get_finished_spans() @@ -159,7 +159,7 @@ class TestSpanHandlerWrapper: def test_func(): return "result" - handler.wrapper(tracer, test_func, (), {}) + handler.wrapper(tracer, test_func) spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 @@ -174,7 +174,7 @@ class TestSpanHandlerWrapper: def test_func(): return "result" - handler.wrapper(tracer, test_func, (), {}) + handler.wrapper(tracer, test_func) spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 @@ -190,7 +190,7 @@ class TestSpanHandlerWrapper: raise ValueError("test error") with pytest.raises(ValueError, match="test error"): - handler.wrapper(tracer, test_func, (), {}) + handler.wrapper(tracer, test_func) spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 @@ -208,7 +208,7 @@ class TestSpanHandlerWrapper: raise ValueError("test error") with pytest.raises(ValueError): - handler.wrapper(tracer, test_func, (), {}) + handler.wrapper(tracer, test_func) spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 @@ -225,7 +225,7 @@ class TestSpanHandlerWrapper: raise ValueError("test error") with pytest.raises(ValueError, match="test error"): - handler.wrapper(tracer, test_func, (), {}) + handler.wrapper(tracer, test_func) @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) def test_wrapper_passes_arguments_correctly(self, tracer_provider_with_memory_exporter, memory_span_exporter): @@ -236,7 +236,7 @@ class TestSpanHandlerWrapper: def test_func(a, b, c=10): return a + b + c - result = handler.wrapper(tracer, test_func, (1, 2), {"c": 3}) + result = handler.wrapper(tracer, test_func, 1, 2, c=3) assert result == 6 @@ -249,7 +249,7 @@ class TestSpanHandlerWrapper: def my_function(x): return x * 2 - result = handler.wrapper(tracer, my_function, (5,), {}) + result = handler.wrapper(tracer, my_function, 5) assert result == 10 spans = memory_span_exporter.get_finished_spans()