feat: add x-trace-id to http responses and logs (#29015)

Introduce trace id to http responses and logs to facilitate debugging process.
This commit is contained in:
wangxiaolei
2025-12-02 17:22:34 +08:00
committed by GitHub
parent f8b10c2272
commit f48522e923
6 changed files with 132 additions and 6 deletions

View File

@@ -1,6 +1,8 @@
import logging
import time
from opentelemetry.trace import get_current_span
from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from dify_app import DifyApp
@@ -26,8 +28,25 @@ def create_flask_app_with_configs() -> DifyApp:
# add an unique identifier to each request
RecyclableContextVar.increment_thread_recycles()
# add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
@dify_app.after_request
def add_trace_id_header(response):
try:
span = get_current_span()
ctx = span.get_span_context() if span else None
if ctx and ctx.is_valid:
trace_id_hex = format(ctx.trace_id, "032x")
# Avoid duplicates if some middleware added it
if "X-Trace-Id" not in response.headers:
response.headers["X-Trace-Id"] = trace_id_hex
except Exception:
# Never break the response due to tracing header injection
logger.warning("Failed to add trace ID to response header", exc_info=True)
return response
# Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request
_ = add_trace_id_header
return dify_app

View File

@@ -553,7 +553,10 @@ class LoggingConfig(BaseSettings):
LOG_FORMAT: str = Field(
description="Format string for log messages",
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
default=(
"%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] "
"[%(filename)s:%(lineno)d] %(trace_id)s - %(message)s"
),
)
LOG_DATEFORMAT: str | None = Field(

View File

@@ -6,6 +6,7 @@ BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEAD
SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization")
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
def init_app(app: DifyApp):
@@ -25,6 +26,7 @@ def init_app(app: DifyApp):
service_api_bp,
allow_headers=list(SERVICE_API_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=list(EXPOSED_HEADERS),
)
app.register_blueprint(service_api_bp)
@@ -34,7 +36,7 @@ def init_app(app: DifyApp):
supports_credentials=True,
allow_headers=list(AUTHENTICATED_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
expose_headers=list(EXPOSED_HEADERS),
)
app.register_blueprint(web_bp)
@@ -44,7 +46,7 @@ def init_app(app: DifyApp):
supports_credentials=True,
allow_headers=list(AUTHENTICATED_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
expose_headers=list(EXPOSED_HEADERS),
)
app.register_blueprint(console_app_bp)
@@ -52,6 +54,7 @@ def init_app(app: DifyApp):
files_bp,
allow_headers=list(FILES_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=list(EXPOSED_HEADERS),
)
app.register_blueprint(files_bp)
@@ -63,5 +66,6 @@ def init_app(app: DifyApp):
trigger_bp,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH", "HEAD"],
expose_headers=list(EXPOSED_HEADERS),
)
app.register_blueprint(trigger_bp)

View File

@@ -7,6 +7,7 @@ from logging.handlers import RotatingFileHandler
import flask
from configs import dify_config
from core.helper.trace_id_helper import get_trace_id_from_otel_context
from dify_app import DifyApp
@@ -76,7 +77,9 @@ class RequestIdFilter(logging.Filter):
# the logging format. Note that we're checking if we're in a request
# context, as we may want to log things before Flask is fully loaded.
def filter(self, record):
trace_id = get_trace_id_from_otel_context() or ""
record.req_id = get_request_id() if flask.has_request_context() else ""
record.trace_id = trace_id
return True
@@ -84,6 +87,8 @@ class RequestIdFormatter(logging.Formatter):
def format(self, record):
if not hasattr(record, "req_id"):
record.req_id = ""
if not hasattr(record, "trace_id"):
record.trace_id = ""
return super().format(record)

View File

@@ -1,12 +1,14 @@
import json
import logging
import time
import flask
import werkzeug.http
from flask import Flask
from flask import Flask, g
from flask.signals import request_finished, request_started
from configs import dify_config
from core.helper.trace_id_helper import get_trace_id_from_otel_context
logger = logging.getLogger(__name__)
@@ -20,6 +22,9 @@ def _is_content_type_json(content_type: str) -> bool:
def _log_request_started(_sender, **_extra):
"""Log the start of a request."""
# Record start time for access logging
g.__request_started_ts = time.perf_counter()
if not logger.isEnabledFor(logging.DEBUG):
return
@@ -42,8 +47,39 @@ def _log_request_started(_sender, **_extra):
def _log_request_finished(_sender, response, **_extra):
"""Log the end of a request."""
if not logger.isEnabledFor(logging.DEBUG) or response is None:
"""Log the end of a request.
Safe to call with or without an active Flask request context.
"""
if response is None:
return
# Always emit a compact access line at INFO with trace_id so it can be grepped
has_ctx = flask.has_request_context()
start_ts = getattr(g, "__request_started_ts", None) if has_ctx else None
duration_ms = None
if start_ts is not None:
duration_ms = round((time.perf_counter() - start_ts) * 1000, 3)
# Request attributes are available only when a request context exists
if has_ctx:
req_method = flask.request.method
req_path = flask.request.path
else:
req_method = "-"
req_path = "-"
trace_id = get_trace_id_from_otel_context() or response.headers.get("X-Trace-Id") or ""
logger.info(
"%s %s %s %s %s",
req_method,
req_path,
getattr(response, "status_code", "-"),
duration_ms if duration_ms is not None else "-",
trace_id,
)
if not logger.isEnabledFor(logging.DEBUG):
return
if not _is_content_type_json(response.content_type):

View File

@@ -263,3 +263,62 @@ class TestResponseUnmodified:
)
assert response.text == _RESPONSE_NEEDLE
assert response.status_code == 200
class TestRequestFinishedInfoAccessLine:
def test_info_access_log_includes_method_path_status_duration_trace_id(self, monkeypatch, caplog):
"""Ensure INFO access line contains expected fields with computed duration and trace id."""
app = _get_test_app()
# Push a real request context so flask.request and g are available
with app.test_request_context("/foo", method="GET"):
# Seed start timestamp via the extension's own start hook and control perf_counter deterministically
seq = iter([100.0, 100.123456])
monkeypatch.setattr(ext_request_logging.time, "perf_counter", lambda: next(seq))
# Provide a deterministic trace id
monkeypatch.setattr(
ext_request_logging,
"get_trace_id_from_otel_context",
lambda: "trace-xyz",
)
# Simulate request_started to record start timestamp on g
ext_request_logging._log_request_started(app)
# Capture logs from the real logger at INFO level only (skip DEBUG branch)
caplog.set_level(logging.INFO, logger=ext_request_logging.__name__)
response = Response(json.dumps({"ok": True}), mimetype="application/json", status=200)
_log_request_finished(app, response)
# Verify a single INFO record with the five fields in order
info_records = [rec for rec in caplog.records if rec.levelno == logging.INFO]
assert len(info_records) == 1
msg = info_records[0].getMessage()
# Expected format: METHOD PATH STATUS DURATION_MS TRACE_ID
assert "GET" in msg
assert "/foo" in msg
assert "200" in msg
assert "123.456" in msg # rounded to 3 decimals
assert "trace-xyz" in msg
def test_info_access_log_uses_dash_without_start_timestamp(self, monkeypatch, caplog):
app = _get_test_app()
with app.test_request_context("/bar", method="POST"):
# No g.__request_started_ts set -> duration should be '-'
monkeypatch.setattr(
ext_request_logging,
"get_trace_id_from_otel_context",
lambda: "tid-no-start",
)
caplog.set_level(logging.INFO, logger=ext_request_logging.__name__)
response = Response("OK", mimetype="text/plain", status=204)
_log_request_finished(app, response)
info_records = [rec for rec in caplog.records if rec.levelno == logging.INFO]
assert len(info_records) == 1
msg = info_records[0].getMessage()
assert "POST" in msg
assert "/bar" in msg
assert "204" in msg
# Duration placeholder
# The fields are space separated; ensure a standalone '-' appears
assert " - " in msg or msg.endswith(" -")
assert "tid-no-start" in msg