ragflow/agent/trace/trace_formatter.py
0xsatoshi99 697f8138b6 feat: add trace logging for agent completions API (Issue #10081)
Implement comprehensive trace logging system for agent execution that
returns step-by-step execution traces in API responses.

New modules:
- agent/trace/trace_models.py: Data models for trace events, sessions,
  LLM calls, retrievals, and tool calls
- agent/trace/trace_collector.py: Real-time trace event collection with
  subscriber pattern for streaming
- agent/trace/trace_formatter.py: Multiple formatters (streaming, compact,
  detailed) for different output needs
- api/db/services/trace_service.py: Service layer for trace persistence,
  retrieval, and analysis
- api/apps/trace_app.py: REST API endpoints for trace management

Features:
- Real-time trace streaming via SSE
- Multiple trace verbosity levels (minimal, standard, detailed, debug)
- Component execution timing and bottleneck detection
- LLM call tracking with token counts
- Retrieval operation logging with chunk details
- Tool call tracing with arguments and results
- Trace session persistence in Redis
- Analysis and recommendations based on trace data

API Endpoints:
- GET /traces - List trace sessions
- GET /traces/<task_id> - Get trace session
- GET /traces/<task_id>/events - Get filtered events
- GET /traces/<task_id>/summary - Get trace summary
- GET /traces/<task_id>/analysis - Analyze trace
- GET /traces/<task_id>/stream - Stream trace events
- DELETE /traces/<task_id> - Delete trace
- POST /traces/cleanup - Cleanup old traces
- POST /agents/<agent_id>/completions/trace - Completion with trace

Closes #10081
2025-12-03 18:08:14 +01:00

416 lines
16 KiB
Python

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Trace Formatter for Agent Execution Logs
This module provides various formatters for converting trace events and sessions
into different output formats suitable for API responses, logging, and debugging.
"""
import json
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Optional, Generator
from agent.trace.trace_models import (
TraceEventType,
TraceLevel,
TraceEvent,
TraceSession,
LLMCallTrace,
RetrievalTrace,
ToolCallTrace,
)
class TraceFormatter(ABC):
    """Common interface implemented by every trace formatter.

    A concrete formatter must be able to render a single event as a
    dictionary, render a whole session as a dictionary, and render a single
    event as a string intended for SSE streaming.
    """

    @abstractmethod
    def format_event(self, event: TraceEvent) -> dict[str, Any]:
        """Return the dictionary representation of one trace event."""

    @abstractmethod
    def format_session(self, session: TraceSession) -> dict[str, Any]:
        """Return the dictionary representation of a whole trace session."""

    @abstractmethod
    def format_for_stream(self, event: TraceEvent) -> str:
        """Return one trace event rendered as a string for SSE streaming."""
class StreamingTraceFormatter(TraceFormatter):
    """Formatter optimized for real-time SSE streaming.

    Event payloads are kept small: long strings and long lists found in an
    event's inputs/outputs are shortened before the event is emitted.
    """

    # Lists longer than this keep their first items followed by a "..." marker.
    _MAX_LIST_ITEMS = 10

    def __init__(self, include_inputs: bool = True, include_outputs: bool = True):
        """Initialize the streaming formatter.

        Args:
            include_inputs: Emit the (truncated) ``inputs`` of each event.
            include_outputs: Emit the (truncated) ``outputs`` of each event.
        """
        self.include_inputs = include_inputs
        self.include_outputs = include_outputs

    def format_event(self, event: TraceEvent) -> dict[str, Any]:
        """Format a trace event for streaming.

        ``event_id``, ``event_type`` and ``timestamp`` are always present;
        every other key is added only when the event carries a value for it.

        Args:
            event: The trace event to render.

        Returns:
            A dictionary with the event's fields, inputs/outputs truncated.
        """
        result: dict[str, Any] = {
            "event_id": event.event_id,
            "event_type": event.event_type.value,
            "timestamp": event.timestamp.isoformat(),
        }
        if event.component_id:
            result["component_id"] = event.component_id
        if event.component_name:
            result["component_name"] = event.component_name
        if event.component_type:
            result["component_type"] = event.component_type
        if self.include_inputs and event.inputs:
            result["inputs"] = self._truncate_dict(event.inputs, max_length=200)
        if self.include_outputs and event.outputs:
            result["outputs"] = self._truncate_dict(event.outputs, max_length=500)
        if event.error:
            result["error"] = event.error
        if event.elapsed_time is not None:
            result["elapsed_time"] = round(event.elapsed_time, 3)
        if event.thoughts:
            # Slicing is a no-op for short strings, so no length check is needed.
            result["thoughts"] = event.thoughts[:300]
        return result

    def format_session(self, session: TraceSession) -> dict[str, Any]:
        """Format a trace session for a streaming response.

        Args:
            session: The session to render.

        Returns:
            Session identity, status, timing, summary and all formatted events.
        """
        return {
            "session_id": session.session_id,
            "status": session.status,
            "started_at": session.started_at.isoformat(),
            "completed_at": session.completed_at.isoformat() if session.completed_at else None,
            "total_elapsed_time": round(session.total_elapsed_time, 3),
            "summary": session.get_summary(),
            "events": [self.format_event(e) for e in session.events],
        }

    def format_for_stream(self, event: TraceEvent) -> str:
        """Format an event as one SSE ``data:`` frame.

        Args:
            event: The trace event to emit.

        Returns:
            ``data:<json>`` followed by a blank line.
        """
        data = self.format_event(event)
        # inputs/outputs hold arbitrary values; anything json cannot encode
        # natively is emitted via str() so one odd value cannot raise
        # TypeError and abort the whole stream.
        payload = json.dumps({"event": "trace", "data": data}, ensure_ascii=False, default=str)
        return f"data:{payload}\n\n"

    def _truncate_dict(self, d: dict[str, Any], max_length: int = 200) -> dict[str, Any]:
        """Return a copy of ``d`` with every value truncated recursively.

        Args:
            d: Mapping to shorten; it is not modified.
            max_length: Maximum number of characters kept per string.

        Returns:
            A new dictionary with the same keys and shortened values.
        """
        return {key: self._truncate_value(value, max_length) for key, value in d.items()}

    def _truncate_value(self, value: Any, max_length: int) -> Any:
        """Shorten one value: strings are cut, dicts and lists are descended into."""
        if isinstance(value, str):
            return value[:max_length] + "..." if len(value) > max_length else value
        if isinstance(value, dict):
            return self._truncate_dict(value, max_length)
        if isinstance(value, list):
            # Descend into the kept items so long strings or nested dicts
            # inside a list are shortened as well, not only the list length.
            kept = [self._truncate_value(item, max_length) for item in value[: self._MAX_LIST_ITEMS]]
            if len(value) > self._MAX_LIST_ITEMS:
                kept.append("...")
            return kept
        return value
class CompactTraceFormatter(TraceFormatter):
    """Compact formatter for minimal trace output."""

    def __init__(self):
        """Initialize the compact formatter and its event-type icon table."""
        self._event_icons = {
            TraceEventType.WORKFLOW_STARTED: "🚀",
            # The three success/failure icons below were empty strings while
            # every other entry had one; restored so they render like the rest.
            TraceEventType.WORKFLOW_COMPLETED: "✅",
            TraceEventType.WORKFLOW_FAILED: "❌",
            TraceEventType.NODE_STARTED: "▶️",
            TraceEventType.NODE_FINISHED: "✔️",
            TraceEventType.NODE_FAILED: "❌",
            TraceEventType.RETRIEVAL_STARTED: "🔍",
            TraceEventType.RETRIEVAL_COMPLETED: "📚",
            TraceEventType.LLM_CALL_STARTED: "🤖",
            TraceEventType.LLM_CALL_COMPLETED: "💬",
            TraceEventType.TOOL_CALL_STARTED: "🔧",
            TraceEventType.TOOL_CALL_COMPLETED: "⚙️",
            TraceEventType.MESSAGE_GENERATED: "📝",
            TraceEventType.ERROR_OCCURRED: "⚠️",
            TraceEventType.THINKING_STARTED: "💭",
            TraceEventType.THINKING_COMPLETED: "💡",
        }

    def format_event(self, event: TraceEvent) -> dict[str, Any]:
        """Format a trace event in compact form.

        Args:
            event: The trace event to render.

        Returns:
            ``type``, ``icon`` and ``time`` (millisecond precision), plus
            ``component``, ``duration_ms`` and ``error`` (first 100
            characters) when the event carries them.
        """
        icon = self._event_icons.get(event.event_type, "")
        result = {
            "type": event.event_type.value,
            "icon": icon,
            # %f yields microseconds; dropping three digits leaves milliseconds.
            "time": event.timestamp.strftime("%H:%M:%S.%f")[:-3],
        }
        if event.component_name:
            result["component"] = event.component_name
        if event.elapsed_time is not None:
            result["duration_ms"] = round(event.elapsed_time * 1000, 1)
        if event.error:
            result["error"] = event.error[:100]
        return result

    def format_session(self, session: TraceSession) -> dict[str, Any]:
        """Format a trace session in compact form.

        Args:
            session: The session to render.

        Returns:
            Session identity, status, duration, the counters read from
            ``session.get_summary()`` and a timeline of compact events.
        """
        summary = session.get_summary()
        return {
            "session_id": session.session_id,
            "status": session.status,
            "duration_s": round(session.total_elapsed_time, 2),
            "nodes": summary["nodes_executed"],
            "llm_calls": summary["total_llm_calls"],
            "retrievals": summary["total_retrievals"],
            "tool_calls": summary["total_tool_calls"],
            "tokens": summary["total_tokens"],
            "errors": summary["errors_count"],
            "timeline": [self.format_event(e) for e in session.events],
        }

    def format_for_stream(self, event: TraceEvent) -> str:
        """Format an event as one SSE ``data:`` frame in compact form."""
        data = self.format_event(event)
        return f"data:{json.dumps({'event': 'trace_compact', 'data': data}, ensure_ascii=False)}\n\n"

    def format_timeline(self, session: TraceSession) -> list[str]:
        """Format a session as a text timeline, one line per event.

        Args:
            session: The session whose events are listed.

        Returns:
            Lines of the form ``HH:MM:SS <icon> <type> [component] (Nms)``,
            with `` ERROR: <first 50 chars>`` appended for failed events.
        """
        lines = []
        for event in session.events:
            icon = self._event_icons.get(event.event_type, "")
            time_str = event.timestamp.strftime("%H:%M:%S")
            line = f"{time_str} {icon} {event.event_type.value}"
            if event.component_name:
                line += f" [{event.component_name}]"
            # `is not None` (as in format_event) so a measured duration of
            # exactly 0.0 s is still shown as "(0ms)" instead of being dropped.
            if event.elapsed_time is not None:
                line += f" ({event.elapsed_time * 1000:.0f}ms)"
            if event.error:
                line += f" ERROR: {event.error[:50]}"
            lines.append(line)
        return lines
class DetailedTraceFormatter(TraceFormatter):
    """Detailed formatter for comprehensive trace output."""

    # Number of characters kept in a "*_preview" field before "..." is appended.
    _PREVIEW_CHARS = 200

    def __init__(self, include_raw_data: bool = False):
        """Initialize the detailed formatter.

        Args:
            include_raw_data: When true, LLM prompts/responses, retrieved
                chunks and tool results are emitted in full instead of as
                shortened previews.
        """
        self.include_raw_data = include_raw_data

    def format_event(self, event: TraceEvent) -> dict[str, Any]:
        """Format a trace event with full details.

        Args:
            event: The trace event to render.

        Returns:
            Identity and timestamps, plus ``component``, ``inputs``,
            ``outputs``, ``error``, ``timing``, ``thoughts`` and ``metadata``
            when the event carries them. Inputs/outputs are not truncated.
        """
        result = {
            "event_id": event.event_id,
            "event_type": event.event_type.value,
            "timestamp": event.timestamp.isoformat(),
            "timestamp_unix": event.timestamp.timestamp(),
        }
        if event.component_id:
            result["component"] = {
                "id": event.component_id,
                "name": event.component_name,
                "type": event.component_type,
            }
        if event.inputs is not None:
            result["inputs"] = event.inputs
        if event.outputs is not None:
            result["outputs"] = event.outputs
        if event.error:
            result["error"] = {
                "message": event.error,
                "occurred_at": event.timestamp.isoformat(),
            }
        if event.elapsed_time is not None:
            result["timing"] = {
                "elapsed_seconds": round(event.elapsed_time, 4),
                "elapsed_ms": round(event.elapsed_time * 1000, 2),
            }
        if event.thoughts:
            result["thoughts"] = event.thoughts
        if event.metadata:
            result["metadata"] = event.metadata
        return result

    def format_session(self, session: TraceSession) -> dict[str, Any]:
        """Format a complete trace session with all details.

        Args:
            session: The session to render.

        Returns:
            Session metadata, status, timing, summary, every formatted event
            and the formatted LLM calls, retrievals and tool calls; ``error``
            is added when the session recorded one.
        """
        result = {
            "session_id": session.session_id,
            "metadata": session.metadata.to_dict(),
            "status": session.status,
            "timing": {
                "started_at": session.started_at.isoformat(),
                "completed_at": session.completed_at.isoformat() if session.completed_at else None,
                "total_elapsed_seconds": round(session.total_elapsed_time, 4),
            },
            "summary": session.get_summary(),
            "events": [self.format_event(e) for e in session.events],
            "llm_calls": [self._format_llm_call(c) for c in session.llm_calls],
            "retrievals": [self._format_retrieval(r) for r in session.retrievals],
            "tool_calls": [self._format_tool_call(t) for t in session.tool_calls],
        }
        if session.error:
            result["error"] = session.error
        return result

    def format_for_stream(self, event: TraceEvent) -> str:
        """Format an event as one SSE ``data:`` frame with full details."""
        data = self.format_event(event)
        # Inputs/outputs/metadata are passed through untouched and may hold
        # values json cannot encode natively; emit those via str() rather
        # than raising TypeError and aborting the stream.
        payload = json.dumps({"event": "trace_detailed", "data": data}, ensure_ascii=False, default=str)
        return f"data:{payload}\n\n"

    def _preview(self, text: Optional[str]) -> Optional[str]:
        """Return ``text`` cut to the preview length, passing None/empty through."""
        if text and len(text) > self._PREVIEW_CHARS:
            return text[: self._PREVIEW_CHARS] + "..."
        return text

    def _format_llm_call(self, call: LLMCallTrace) -> dict[str, Any]:
        """Format an LLM call trace.

        Args:
            call: The LLM call record to render.

        Returns:
            Call identity, token counts, latency and timestamps, with either
            the raw prompt/response or their previews depending on
            ``include_raw_data``; ``max_tokens`` and ``error`` when set.
        """
        result = {
            "call_id": call.call_id,
            "model_name": call.model_name,
            "tokens": {
                "prompt": call.prompt_tokens,
                "completion": call.completion_tokens,
                "total": call.total_tokens,
            },
            "latency_ms": round(call.latency_ms, 2),
            "temperature": call.temperature,
            "started_at": call.started_at.isoformat(),
            "completed_at": call.completed_at.isoformat() if call.completed_at else None,
        }
        if self.include_raw_data:
            result["prompt"] = call.prompt
            result["response"] = call.response
        else:
            # The prompt gets the same None-tolerant handling as the response;
            # previously a missing prompt raised TypeError on len(None).
            result["prompt_preview"] = self._preview(call.prompt)
            result["response_preview"] = self._preview(call.response)
        if call.max_tokens:
            result["max_tokens"] = call.max_tokens
        if call.error:
            result["error"] = call.error
        return result

    def _format_retrieval(self, retrieval: RetrievalTrace) -> dict[str, Any]:
        """Format a retrieval trace.

        Args:
            retrieval: The retrieval record to render.

        Returns:
            Query, knowledge bases, configuration, result count and a preview
            of the first three chunks (full chunks when ``include_raw_data``
            is set, otherwise only each chunk's ``id`` and ``score``).
        """
        first_chunks = retrieval.chunks[:3]
        if not self.include_raw_data:
            first_chunks = [{"id": c.get("id"), "score": c.get("score")} for c in first_chunks]
        result = {
            "retrieval_id": retrieval.retrieval_id,
            "query": retrieval.query,
            "knowledge_bases": retrieval.knowledge_bases,
            "config": {
                "top_k": retrieval.top_k,
                "similarity_threshold": retrieval.similarity_threshold,
                "rerank_enabled": retrieval.rerank_enabled,
            },
            "results": {
                "chunks_retrieved": retrieval.chunks_retrieved,
                "chunks_preview": first_chunks,
            },
            "latency_ms": round(retrieval.latency_ms, 2),
            "started_at": retrieval.started_at.isoformat(),
            "completed_at": retrieval.completed_at.isoformat() if retrieval.completed_at else None,
        }
        if retrieval.error:
            result["error"] = retrieval.error
        return result

    def _format_tool_call(self, tool: ToolCallTrace) -> dict[str, Any]:
        """Format a tool call trace.

        Args:
            tool: The tool call record to render.

        Returns:
            Call identity, arguments, latency and timestamps, with either the
            raw result or a string preview of it depending on
            ``include_raw_data``; ``error`` when set.
        """
        result = {
            "call_id": tool.call_id,
            "tool_name": tool.tool_name,
            "tool_type": tool.tool_type,
            "arguments": tool.arguments,
            "latency_ms": round(tool.latency_ms, 2),
            "started_at": tool.started_at.isoformat(),
            "completed_at": tool.completed_at.isoformat() if tool.completed_at else None,
        }
        if self.include_raw_data:
            result["result"] = tool.result
        else:
            # Only a missing result maps to None; falsy-but-real results such
            # as 0, False, "" or [] are still shown in the preview.
            result_str = str(tool.result) if tool.result is not None else None
            result["result_preview"] = self._preview(result_str)
        if tool.error:
            result["error"] = tool.error
        return result
class TraceFormatterFactory:
    """Registry-backed factory that builds trace formatters by name."""

    # Maps a format name to the formatter class instantiated for it.
    _formatters = {
        "streaming": StreamingTraceFormatter,
        "compact": CompactTraceFormatter,
        "detailed": DetailedTraceFormatter,
    }

    @classmethod
    def create(cls, format_type: str = "streaming", **kwargs) -> TraceFormatter:
        """Instantiate the formatter registered under ``format_type``.

        Args:
            format_type: Name of a registered formatter.
            **kwargs: Passed unchanged to the formatter's constructor.

        Returns:
            A new formatter instance.

        Raises:
            ValueError: If no formatter is registered under ``format_type``.
        """
        chosen = cls._formatters.get(format_type)
        if chosen:
            return chosen(**kwargs)
        known = list(cls._formatters)
        raise ValueError(f"Unknown formatter type: {format_type}. Available: {known}")

    @classmethod
    def register(cls, name: str, formatter_class: type) -> None:
        """Add (or replace) a formatter class under ``name``.

        Raises:
            TypeError: If ``formatter_class`` does not derive from
                ``TraceFormatter``.
        """
        if issubclass(formatter_class, TraceFormatter):
            cls._formatters[name] = formatter_class
            return
        raise TypeError("Formatter must be a subclass of TraceFormatter")

    @classmethod
    def available_formatters(cls) -> list[str]:
        """Return the names of all registered formatters."""
        return list(cls._formatters)
def format_trace_for_api(
    session: TraceSession,
    format_type: str = "streaming",
    **kwargs
) -> dict[str, Any]:
    """Render a trace session as a dictionary for an API response.

    Args:
        session: The trace session to render.
        format_type: Name of the formatter to build via the factory.
        **kwargs: Forwarded to the formatter's constructor.

    Returns:
        Whatever the selected formatter's ``format_session`` produces.
    """
    return TraceFormatterFactory.create(format_type, **kwargs).format_session(session)
def generate_trace_stream(
    events: Generator[TraceEvent, None, None],
    format_type: str = "streaming",
    **kwargs
) -> Generator[str, None, None]:
    """Lazily turn trace events into SSE-formatted strings.

    Args:
        events: Source of trace events, consumed one at a time.
        format_type: Name of the formatter to build via the factory.
        **kwargs: Forwarded to the formatter's constructor.

    Yields:
        One ``format_for_stream`` string per incoming event, in order.
    """
    to_frame = TraceFormatterFactory.create(format_type, **kwargs).format_for_stream
    for trace_event in events:
        yield to_frame(trace_event)