feat: add trace logging for agent completions API (Issue #10081)

Implement comprehensive trace logging system for agent execution that
returns step-by-step execution traces in API responses.

New modules:
- agent/trace/trace_models.py: Data models for trace events, sessions,
  LLM calls, retrievals, and tool calls
- agent/trace/trace_collector.py: Real-time trace event collection with
  subscriber pattern for streaming
- agent/trace/trace_formatter.py: Multiple formatters (streaming, compact,
  detailed) for different output needs
- api/db/services/trace_service.py: Service layer for trace persistence,
  retrieval, and analysis
- api/apps/trace_app.py: REST API endpoints for trace management

Features:
- Real-time trace streaming via SSE
- Multiple trace verbosity levels (minimal, standard, detailed, debug)
- Component execution timing and bottleneck detection
- LLM call tracking with token counts
- Retrieval operation logging with chunk details
- Tool call tracing with arguments and results
- Trace session persistence in Redis
- Analysis and recommendations based on trace data

API Endpoints:
- GET /traces - List trace sessions
- GET /traces/<task_id> - Get trace session
- GET /traces/<task_id>/events - Get filtered events
- GET /traces/<task_id>/summary - Get trace summary
- GET /traces/<task_id>/analysis - Analyze trace
- GET /traces/<task_id>/stream - Stream trace events
- DELETE /traces/<task_id> - Delete trace
- POST /traces/cleanup - Cleanup old traces
- POST /agents/<agent_id>/completions/trace - Completion with trace

Closes #10081
This commit is contained in:
0xsatoshi99 2025-12-03 18:08:14 +01:00
parent 3c224c817b
commit 697f8138b6
6 changed files with 2327 additions and 0 deletions

67
agent/trace/__init__.py Normal file
View file

@ -0,0 +1,67 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Agent Trace Module
This module provides comprehensive tracing capabilities for agent execution,
including trace models, collectors, formatters, and services.
"""
from agent.trace.trace_models import (
TraceEventType,
TraceLevel,
TraceMetadata,
ComponentInfo,
TraceEvent,
LLMCallTrace,
RetrievalTrace,
ToolCallTrace,
TraceSession,
)
from agent.trace.trace_collector import (
TraceCollector,
get_trace_collector,
create_trace_collector,
)
from agent.trace.trace_formatter import (
TraceFormatter,
StreamingTraceFormatter,
CompactTraceFormatter,
DetailedTraceFormatter,
)
__all__ = [
# Models
"TraceEventType",
"TraceLevel",
"TraceMetadata",
"ComponentInfo",
"TraceEvent",
"LLMCallTrace",
"RetrievalTrace",
"ToolCallTrace",
"TraceSession",
# Collector
"TraceCollector",
"get_trace_collector",
"create_trace_collector",
# Formatters
"TraceFormatter",
"StreamingTraceFormatter",
"CompactTraceFormatter",
"DetailedTraceFormatter",
]

View file

@ -0,0 +1,510 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Trace Collector for Agent Execution
This module provides the TraceCollector class that captures and manages
trace events during agent workflow execution. It supports real-time
event streaming and session management.
"""
import asyncio
import logging
import threading
import time
import uuid
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Callable, Optional, Generator, AsyncGenerator
from agent.trace.trace_models import (
TraceEventType,
TraceLevel,
TraceMetadata,
TraceEvent,
TraceSession,
LLMCallTrace,
RetrievalTrace,
ToolCallTrace,
ComponentInfo,
)
# Process-wide registry of live collectors, keyed by task ID.
_trace_collectors: dict[str, "TraceCollector"] = {}
_collectors_lock = threading.Lock()


def get_trace_collector(task_id: str) -> Optional["TraceCollector"]:
    """Return the collector registered for *task_id*, or None if absent."""
    with _collectors_lock:
        return _trace_collectors.get(task_id)


def create_trace_collector(
    task_id: str,
    agent_id: str,
    session_id: str,
    user_id: str,
    tenant_id: str,
    trace_level: TraceLevel = TraceLevel.STANDARD,
) -> "TraceCollector":
    """Create and register a collector for an agent execution.

    Idempotent per task: if a collector already exists for *task_id* it is
    returned as-is and the supplied arguments (including trace_level) are
    ignored.
    """
    with _collectors_lock:
        existing = _trace_collectors.get(task_id)
        if existing is not None:
            return existing
        collector = TraceCollector(
            task_id=task_id,
            agent_id=agent_id,
            session_id=session_id,
            user_id=user_id,
            tenant_id=tenant_id,
            trace_level=trace_level,
        )
        _trace_collectors[task_id] = collector
        return collector


def remove_trace_collector(task_id: str) -> None:
    """Drop *task_id* from the registry; a no-op when it is not registered."""
    with _collectors_lock:
        _trace_collectors.pop(task_id, None)
class TraceCollector:
    """
    Collects and manages trace events during agent execution.

    Events are appended to an in-memory TraceSession and fanned out to
    registered subscriber callbacks. All state mutation is guarded by a
    non-reentrant lock so one collector can safely be shared across
    threads; subscriber callbacks are invoked OUTSIDE the lock (see
    _record_event) so a callback may call subscribe()/unsubscribe()
    without deadlocking.
    """

    def __init__(
        self,
        task_id: str,
        agent_id: str,
        session_id: str,
        user_id: str,
        tenant_id: str,
        trace_level: TraceLevel = TraceLevel.STANDARD,
    ):
        """Initialize the trace collector and its backing TraceSession."""
        self.task_id = task_id
        self.trace_level = trace_level
        self._lock = threading.Lock()
        # NOTE(review): this queue is never read or written anywhere in this
        # module; presumably reserved for async streaming — confirm before use.
        self._event_queue: asyncio.Queue[TraceEvent] = asyncio.Queue()
        self._subscribers: list[Callable[[TraceEvent], None]] = []
        self._is_active = True
        metadata = TraceMetadata(
            agent_id=agent_id,
            session_id=session_id,
            user_id=user_id,
            tenant_id=tenant_id,
            trace_level=trace_level,
        )
        self.session = TraceSession(
            session_id=session_id,
            metadata=metadata,
        )
        # Bookkeeping for in-flight operations, keyed by component/call id.
        # Start timestamps come from time.perf_counter() (monotonic).
        self._component_start_times: dict[str, float] = {}
        self._llm_call_starts: dict[str, tuple[float, LLMCallTrace]] = {}
        self._retrieval_starts: dict[str, tuple[float, RetrievalTrace]] = {}
        self._tool_call_starts: dict[str, tuple[float, ToolCallTrace]] = {}

    def _should_trace(self, required_level: TraceLevel) -> bool:
        """Return True if the collector's level is at least *required_level*.

        Verbosity is ordered MINIMAL < STANDARD < DETAILED < DEBUG.
        """
        level_order = [TraceLevel.MINIMAL, TraceLevel.STANDARD, TraceLevel.DETAILED, TraceLevel.DEBUG]
        current_idx = level_order.index(self.trace_level)
        required_idx = level_order.index(required_level)
        return current_idx >= required_idx

    def _create_event(
        self,
        event_type: TraceEventType,
        component_id: Optional[str] = None,
        component_name: Optional[str] = None,
        component_type: Optional[str] = None,
        inputs: Optional[dict[str, Any]] = None,
        outputs: Optional[dict[str, Any]] = None,
        error: Optional[str] = None,
        elapsed_time: Optional[float] = None,
        thoughts: Optional[str] = None,
        metadata: Optional[dict[str, Any]] = None,
    ) -> TraceEvent:
        """Build a TraceEvent with a fresh UUID and UTC timestamp."""
        return TraceEvent(
            event_id=str(uuid.uuid4()),
            event_type=event_type,
            timestamp=datetime.utcnow(),
            component_id=component_id,
            component_name=component_name,
            component_type=component_type,
            inputs=inputs,
            outputs=outputs,
            error=error,
            elapsed_time=elapsed_time,
            thoughts=thoughts,
            metadata=metadata or {},
        )

    def _record_event(self, event: TraceEvent) -> None:
        """Record an event and notify subscribers.

        The subscriber list is snapshotted under the lock, but the callbacks
        run outside it: the lock is non-reentrant, so a subscriber calling
        back into subscribe()/unsubscribe() would otherwise deadlock, and a
        slow subscriber would block every other thread trying to record.
        Subscriber exceptions are logged and swallowed so tracing can never
        break the traced workflow.
        """
        with self._lock:
            self.session.add_event(event)
            subscribers = list(self._subscribers)
        for subscriber in subscribers:
            try:
                subscriber(event)
            except Exception as e:
                logging.warning(f"Trace subscriber error: {e}")

    def subscribe(self, callback: Callable[[TraceEvent], None]) -> None:
        """Register *callback* to receive every future trace event."""
        with self._lock:
            self._subscribers.append(callback)

    def unsubscribe(self, callback: Callable[[TraceEvent], None]) -> None:
        """Remove *callback*; a no-op if it was never subscribed."""
        with self._lock:
            if callback in self._subscribers:
                self._subscribers.remove(callback)

    def workflow_started(self, inputs: Optional[dict[str, Any]] = None) -> None:
        """Record workflow start event."""
        event = self._create_event(
            event_type=TraceEventType.WORKFLOW_STARTED,
            inputs=inputs,
        )
        self._record_event(event)

    def workflow_completed(self, outputs: Optional[dict[str, Any]] = None) -> None:
        """Record workflow completion and mark the session complete."""
        event = self._create_event(
            event_type=TraceEventType.WORKFLOW_COMPLETED,
            outputs=outputs,
            # Wall-clock duration since the session started, in seconds.
            elapsed_time=(datetime.utcnow() - self.session.started_at).total_seconds(),
        )
        self._record_event(event)
        self.session.complete()

    def workflow_failed(self, error: str) -> None:
        """Record workflow failure and mark the session complete with *error*."""
        event = self._create_event(
            event_type=TraceEventType.WORKFLOW_FAILED,
            error=error,
            elapsed_time=(datetime.utcnow() - self.session.started_at).total_seconds(),
        )
        self._record_event(event)
        self.session.complete(error=error)

    def node_started(
        self,
        component_id: str,
        component_name: str,
        component_type: str,
        inputs: Optional[dict[str, Any]] = None,
        thoughts: Optional[str] = None,
    ) -> None:
        """Record node start and remember its start time for node_finished."""
        self._component_start_times[component_id] = time.perf_counter()
        event = self._create_event(
            event_type=TraceEventType.NODE_STARTED,
            component_id=component_id,
            component_name=component_name,
            component_type=component_type,
            inputs=inputs,
            thoughts=thoughts,
        )
        self._record_event(event)

    def node_finished(
        self,
        component_id: str,
        component_name: str,
        component_type: str,
        inputs: Optional[dict[str, Any]] = None,
        outputs: Optional[dict[str, Any]] = None,
        error: Optional[str] = None,
    ) -> None:
        """Record node completion (NODE_FAILED when *error* is set).

        Elapsed time is computed from the matching node_started call; if no
        start was recorded (e.g. node_finished called alone), elapsed_time
        is left as None.
        """
        start_time = self._component_start_times.pop(component_id, None)
        # Explicit None check: a perf_counter() value of 0.0 would be falsy.
        elapsed_time = time.perf_counter() - start_time if start_time is not None else None
        event_type = TraceEventType.NODE_FAILED if error else TraceEventType.NODE_FINISHED
        event = self._create_event(
            event_type=event_type,
            component_id=component_id,
            component_name=component_name,
            component_type=component_type,
            inputs=inputs,
            outputs=outputs,
            error=error,
            elapsed_time=elapsed_time,
        )
        self._record_event(event)

    def llm_call_started(
        self,
        call_id: str,
        model_name: str,
        prompt: str,
        temperature: float = 0.0,
        max_tokens: Optional[int] = None,
    ) -> None:
        """Record LLM call start. Only emitted at DETAILED level or above."""
        if not self._should_trace(TraceLevel.DETAILED):
            return
        llm_trace = LLMCallTrace(
            call_id=call_id,
            model_name=model_name,
            prompt=prompt,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        self._llm_call_starts[call_id] = (time.perf_counter(), llm_trace)
        event = self._create_event(
            event_type=TraceEventType.LLM_CALL_STARTED,
            metadata={"call_id": call_id, "model_name": model_name},
        )
        self._record_event(event)

    def llm_call_completed(
        self,
        call_id: str,
        response: str,
        prompt_tokens: int = 0,
        completion_tokens: int = 0,
        error: Optional[str] = None,
    ) -> None:
        """Record LLM call completion, filling in tokens/latency on the trace.

        Skipped below DETAILED level (mirrors llm_call_started, so no
        orphaned start entries accumulate).
        """
        if not self._should_trace(TraceLevel.DETAILED):
            return
        start_data = self._llm_call_starts.pop(call_id, None)
        if start_data:
            start_time, llm_trace = start_data
            llm_trace.response = response
            llm_trace.prompt_tokens = prompt_tokens
            llm_trace.completion_tokens = completion_tokens
            llm_trace.total_tokens = prompt_tokens + completion_tokens
            llm_trace.latency_ms = (time.perf_counter() - start_time) * 1000
            llm_trace.completed_at = datetime.utcnow()
            llm_trace.error = error
            self.session.add_llm_call(llm_trace)
        event = self._create_event(
            event_type=TraceEventType.LLM_CALL_COMPLETED,
            error=error,
            metadata={"call_id": call_id, "tokens": prompt_tokens + completion_tokens},
        )
        self._record_event(event)

    def retrieval_started(
        self,
        retrieval_id: str,
        query: str,
        knowledge_bases: list[str],
        top_k: int = 10,
        similarity_threshold: float = 0.0,
        rerank_enabled: bool = False,
    ) -> None:
        """Record retrieval operation start (always traced, at every level)."""
        retrieval_trace = RetrievalTrace(
            retrieval_id=retrieval_id,
            query=query,
            knowledge_bases=knowledge_bases,
            top_k=top_k,
            similarity_threshold=similarity_threshold,
            rerank_enabled=rerank_enabled,
        )
        self._retrieval_starts[retrieval_id] = (time.perf_counter(), retrieval_trace)
        event = self._create_event(
            event_type=TraceEventType.RETRIEVAL_STARTED,
            # Query is truncated to keep event metadata small.
            metadata={"retrieval_id": retrieval_id, "query": query[:100]},
        )
        self._record_event(event)

    def retrieval_completed(
        self,
        retrieval_id: str,
        chunks: list[dict[str, Any]],
        error: Optional[str] = None,
    ) -> None:
        """Record retrieval completion with the retrieved chunk payloads."""
        start_data = self._retrieval_starts.pop(retrieval_id, None)
        if start_data:
            start_time, retrieval_trace = start_data
            retrieval_trace.chunks = chunks
            retrieval_trace.chunks_retrieved = len(chunks)
            retrieval_trace.latency_ms = (time.perf_counter() - start_time) * 1000
            retrieval_trace.completed_at = datetime.utcnow()
            retrieval_trace.error = error
            self.session.add_retrieval(retrieval_trace)
        event = self._create_event(
            event_type=TraceEventType.RETRIEVAL_COMPLETED,
            error=error,
            metadata={"retrieval_id": retrieval_id, "chunks_count": len(chunks)},
        )
        self._record_event(event)

    def tool_call_started(
        self,
        call_id: str,
        tool_name: str,
        tool_type: str,
        arguments: dict[str, Any],
    ) -> None:
        """Record tool call start with its arguments."""
        tool_trace = ToolCallTrace(
            call_id=call_id,
            tool_name=tool_name,
            tool_type=tool_type,
            arguments=arguments,
        )
        self._tool_call_starts[call_id] = (time.perf_counter(), tool_trace)
        event = self._create_event(
            event_type=TraceEventType.TOOL_CALL_STARTED,
            metadata={"call_id": call_id, "tool_name": tool_name},
        )
        self._record_event(event)

    def tool_call_completed(
        self,
        call_id: str,
        result: Any,
        error: Optional[str] = None,
    ) -> None:
        """Record tool call completion with its result (or error)."""
        start_data = self._tool_call_starts.pop(call_id, None)
        if start_data:
            start_time, tool_trace = start_data
            tool_trace.result = result
            tool_trace.latency_ms = (time.perf_counter() - start_time) * 1000
            tool_trace.completed_at = datetime.utcnow()
            tool_trace.error = error
            self.session.add_tool_call(tool_trace)
        event = self._create_event(
            event_type=TraceEventType.TOOL_CALL_COMPLETED,
            error=error,
            metadata={"call_id": call_id},
        )
        self._record_event(event)

    def message_generated(
        self,
        content: str,
        component_id: Optional[str] = None,
        component_name: Optional[str] = None,
    ) -> None:
        """Record message generation; content is capped at 500 characters."""
        event = self._create_event(
            event_type=TraceEventType.MESSAGE_GENERATED,
            component_id=component_id,
            component_name=component_name,
            outputs={"content": content[:500] if len(content) > 500 else content},
        )
        self._record_event(event)

    def thinking_started(self, component_id: str, thoughts: str) -> None:
        """Record thinking/reasoning start (DETAILED level or above only)."""
        if not self._should_trace(TraceLevel.DETAILED):
            return
        event = self._create_event(
            event_type=TraceEventType.THINKING_STARTED,
            component_id=component_id,
            thoughts=thoughts,
        )
        self._record_event(event)

    def thinking_completed(self, component_id: str, thoughts: str) -> None:
        """Record thinking/reasoning completion (DETAILED level or above only)."""
        if not self._should_trace(TraceLevel.DETAILED):
            return
        event = self._create_event(
            event_type=TraceEventType.THINKING_COMPLETED,
            component_id=component_id,
            thoughts=thoughts,
        )
        self._record_event(event)

    def error_occurred(
        self,
        error: str,
        component_id: Optional[str] = None,
        component_name: Optional[str] = None,
    ) -> None:
        """Record a standalone error event (traced at every level)."""
        event = self._create_event(
            event_type=TraceEventType.ERROR_OCCURRED,
            component_id=component_id,
            component_name=component_name,
            error=error,
        )
        self._record_event(event)

    def get_session(self) -> TraceSession:
        """Return the live TraceSession (not a copy)."""
        return self.session

    def get_events(self) -> list[TraceEvent]:
        """Return the session's event list (live reference, not a copy)."""
        return self.session.events

    def get_summary(self) -> dict[str, Any]:
        """Return the session summary as computed by TraceSession."""
        return self.session.get_summary()

    def close(self) -> None:
        """Deactivate the collector and remove it from the global registry."""
        self._is_active = False
        remove_trace_collector(self.task_id)

    @contextmanager
    def trace_component(
        self,
        component_id: str,
        component_name: str,
        component_type: str,
        inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[None, None, None]:
        """Context manager that brackets component execution with node events.

        An exception inside the block is recorded as a NODE_FAILED event and
        re-raised. Note that outputs are always recorded as None here; use
        node_finished directly when outputs must be captured.
        """
        self.node_started(component_id, component_name, component_type, inputs)
        error = None
        outputs = None
        try:
            yield
        except Exception as e:
            error = str(e)
            raise
        finally:
            self.node_finished(
                component_id, component_name, component_type,
                inputs=inputs, outputs=outputs, error=error
            )

View file

@ -0,0 +1,416 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Trace Formatter for Agent Execution Logs
This module provides various formatters for converting trace events and sessions
into different output formats suitable for API responses, logging, and debugging.
"""
import json
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Optional, Generator
from agent.trace.trace_models import (
TraceEventType,
TraceLevel,
TraceEvent,
TraceSession,
LLMCallTrace,
RetrievalTrace,
ToolCallTrace,
)
class TraceFormatter(ABC):
    """Abstract contract for trace formatters.

    Concrete subclasses must be able to render a single event, a whole
    session, and an SSE-ready string for one event.
    """

    @abstractmethod
    def format_event(self, event: TraceEvent) -> dict[str, Any]:
        """Render one trace event as a JSON-serializable dict."""

    @abstractmethod
    def format_session(self, session: TraceSession) -> dict[str, Any]:
        """Render a complete trace session as a JSON-serializable dict."""

    @abstractmethod
    def format_for_stream(self, event: TraceEvent) -> str:
        """Render one event as an SSE `data:` frame string."""
class StreamingTraceFormatter(TraceFormatter):
    """Formatter optimized for real-time SSE streaming.

    Produces lean payloads: optional fields are omitted when empty, and
    nested input/output values are truncated to keep frames small.
    """

    def __init__(self, include_inputs: bool = True, include_outputs: bool = True):
        """Configure whether event inputs/outputs are included in payloads."""
        self.include_inputs = include_inputs
        self.include_outputs = include_outputs

    def format_event(self, event: TraceEvent) -> dict[str, Any]:
        """Render *event* as a compact, stream-friendly dict."""
        payload = {
            "event_id": event.event_id,
            "event_type": event.event_type.value,
            "timestamp": event.timestamp.isoformat(),
        }
        # Component identity fields are only emitted when present.
        for attr in ("component_id", "component_name", "component_type"):
            value = getattr(event, attr)
            if value:
                payload[attr] = value
        if self.include_inputs and event.inputs:
            payload["inputs"] = self._truncate_dict(event.inputs, max_length=200)
        if self.include_outputs and event.outputs:
            payload["outputs"] = self._truncate_dict(event.outputs, max_length=500)
        if event.error:
            payload["error"] = event.error
        if event.elapsed_time is not None:
            payload["elapsed_time"] = round(event.elapsed_time, 3)
        if event.thoughts:
            # Cap thoughts at 300 chars (no ellipsis marker).
            payload["thoughts"] = event.thoughts[:300]
        return payload

    def format_session(self, session: TraceSession) -> dict[str, Any]:
        """Render a whole session, embedding every event via format_event."""
        return {
            "session_id": session.session_id,
            "status": session.status,
            "started_at": session.started_at.isoformat(),
            "completed_at": session.completed_at.isoformat() if session.completed_at else None,
            "total_elapsed_time": round(session.total_elapsed_time, 3),
            "summary": session.get_summary(),
            "events": [self.format_event(evt) for evt in session.events],
        }

    def format_for_stream(self, event: TraceEvent) -> str:
        """Wrap the event payload in a `trace` envelope as an SSE data frame."""
        envelope = {"event": "trace", "data": self.format_event(event)}
        return f"data:{json.dumps(envelope, ensure_ascii=False)}\n\n"

    def _truncate_dict(self, d: dict[str, Any], max_length: int = 200) -> dict[str, Any]:
        """Return a copy of *d* with long strings and lists shortened.

        Strings over *max_length* get an "..." suffix, nested dicts are
        truncated recursively, and lists longer than 10 items are cut to
        10 elements plus an "..." sentinel. Other values pass through.
        """
        trimmed: dict[str, Any] = {}
        for key, value in d.items():
            if isinstance(value, str) and len(value) > max_length:
                trimmed[key] = value[:max_length] + "..."
            elif isinstance(value, dict):
                trimmed[key] = self._truncate_dict(value, max_length)
            elif isinstance(value, list) and len(value) > 10:
                trimmed[key] = value[:10] + ["..."]
            else:
                trimmed[key] = value
        return trimmed
class CompactTraceFormatter(TraceFormatter):
    """Compact formatter for minimal trace output.

    Emits a small per-event record with a type icon, short timestamp, and
    optional duration/error, plus a session view with headline counters.
    """

    def __init__(self):
        """Initialize the compact formatter with its icon mapping."""
        # Fix: WORKFLOW_COMPLETED/WORKFLOW_FAILED/NODE_FAILED previously
        # mapped to empty strings, so completion and failure events rendered
        # without any marker in compact and timeline output.
        self._event_icons = {
            TraceEventType.WORKFLOW_STARTED: "🚀",
            TraceEventType.WORKFLOW_COMPLETED: "✅",
            TraceEventType.WORKFLOW_FAILED: "❌",
            TraceEventType.NODE_STARTED: "▶️",
            TraceEventType.NODE_FINISHED: "✔️",
            TraceEventType.NODE_FAILED: "❌",
            TraceEventType.RETRIEVAL_STARTED: "🔍",
            TraceEventType.RETRIEVAL_COMPLETED: "📚",
            TraceEventType.LLM_CALL_STARTED: "🤖",
            TraceEventType.LLM_CALL_COMPLETED: "💬",
            TraceEventType.TOOL_CALL_STARTED: "🔧",
            TraceEventType.TOOL_CALL_COMPLETED: "⚙️",
            TraceEventType.MESSAGE_GENERATED: "📝",
            TraceEventType.ERROR_OCCURRED: "⚠️",
            TraceEventType.THINKING_STARTED: "💭",
            TraceEventType.THINKING_COMPLETED: "💡",
        }

    def format_event(self, event: TraceEvent) -> dict[str, Any]:
        """Render one event as a minimal dict: type, icon, time, extras."""
        icon = self._event_icons.get(event.event_type, "")
        result = {
            "type": event.event_type.value,
            "icon": icon,
            # HH:MM:SS.mmm — microseconds trimmed to milliseconds.
            "time": event.timestamp.strftime("%H:%M:%S.%f")[:-3],
        }
        if event.component_name:
            result["component"] = event.component_name
        if event.elapsed_time is not None:
            result["duration_ms"] = round(event.elapsed_time * 1000, 1)
        if event.error:
            # Errors are truncated to keep the compact view compact.
            result["error"] = event.error[:100]
        return result

    def format_session(self, session: TraceSession) -> dict[str, Any]:
        """Render a session as headline counters plus a compact timeline."""
        summary = session.get_summary()
        return {
            "session_id": session.session_id,
            "status": session.status,
            "duration_s": round(session.total_elapsed_time, 2),
            "nodes": summary["nodes_executed"],
            "llm_calls": summary["total_llm_calls"],
            "retrievals": summary["total_retrievals"],
            "tool_calls": summary["total_tool_calls"],
            "tokens": summary["total_tokens"],
            "errors": summary["errors_count"],
            "timeline": [self.format_event(e) for e in session.events],
        }

    def format_for_stream(self, event: TraceEvent) -> str:
        """Render one event as an SSE frame under the `trace_compact` envelope."""
        data = self.format_event(event)
        return f"data:{json.dumps({'event': 'trace_compact', 'data': data}, ensure_ascii=False)}\n\n"

    def format_timeline(self, session: TraceSession) -> list[str]:
        """Render the session as human-readable one-line-per-event text."""
        lines = []
        for event in session.events:
            icon = self._event_icons.get(event.event_type, "")
            time_str = event.timestamp.strftime("%H:%M:%S")
            component = event.component_name or ""
            # Truthiness check: a 0.0 duration is rendered as no duration.
            duration = f" ({event.elapsed_time*1000:.0f}ms)" if event.elapsed_time else ""
            line = f"{time_str} {icon} {event.event_type.value}"
            if component:
                line += f" [{component}]"
            line += duration
            if event.error:
                line += f" ERROR: {event.error[:50]}"
            lines.append(line)
        return lines
class DetailedTraceFormatter(TraceFormatter):
    """Detailed formatter for comprehensive trace output.

    Emits every available field per event/session, including nested LLM
    call, retrieval, and tool call traces. When include_raw_data is False
    (the default), prompts/responses/results are replaced with truncated
    previews to keep payloads bounded.
    """

    def __init__(self, include_raw_data: bool = False):
        """Initialize the detailed formatter.

        include_raw_data: if True, full prompts, responses, chunks, and tool
        results are emitted verbatim instead of 200-char previews.
        """
        self.include_raw_data = include_raw_data

    def format_event(self, event: TraceEvent) -> dict[str, Any]:
        """Render one event with full details; optional sections are omitted
        when their source fields are absent."""
        result = {
            "event_id": event.event_id,
            "event_type": event.event_type.value,
            "timestamp": event.timestamp.isoformat(),
            # Unix epoch seconds alongside ISO for easy numeric sorting.
            "timestamp_unix": event.timestamp.timestamp(),
        }
        if event.component_id:
            result["component"] = {
                "id": event.component_id,
                "name": event.component_name,
                "type": event.component_type,
            }
        # inputs/outputs use `is not None` so empty dicts are still emitted.
        if event.inputs is not None:
            result["inputs"] = event.inputs
        if event.outputs is not None:
            result["outputs"] = event.outputs
        if event.error:
            result["error"] = {
                "message": event.error,
                # NOTE(review): reuses the event timestamp; there is no
                # separate error timestamp on TraceEvent.
                "occurred_at": event.timestamp.isoformat(),
            }
        if event.elapsed_time is not None:
            result["timing"] = {
                "elapsed_seconds": round(event.elapsed_time, 4),
                "elapsed_ms": round(event.elapsed_time * 1000, 2),
            }
        if event.thoughts:
            result["thoughts"] = event.thoughts
        if event.metadata:
            result["metadata"] = event.metadata
        return result

    def format_session(self, session: TraceSession) -> dict[str, Any]:
        """Render the complete session: metadata, timing, summary, all events,
        and the nested LLM/retrieval/tool traces."""
        result = {
            "session_id": session.session_id,
            "metadata": session.metadata.to_dict(),
            "status": session.status,
            "timing": {
                "started_at": session.started_at.isoformat(),
                "completed_at": session.completed_at.isoformat() if session.completed_at else None,
                "total_elapsed_seconds": round(session.total_elapsed_time, 4),
            },
            "summary": session.get_summary(),
            "events": [self.format_event(e) for e in session.events],
            "llm_calls": [self._format_llm_call(c) for c in session.llm_calls],
            "retrievals": [self._format_retrieval(r) for r in session.retrievals],
            "tool_calls": [self._format_tool_call(t) for t in session.tool_calls],
        }
        if session.error:
            result["error"] = session.error
        return result

    def format_for_stream(self, event: TraceEvent) -> str:
        """Render one event as an SSE frame under the `trace_detailed` envelope."""
        data = self.format_event(event)
        return f"data:{json.dumps({'event': 'trace_detailed', 'data': data}, ensure_ascii=False)}\n\n"

    def _format_llm_call(self, call: LLMCallTrace) -> dict[str, Any]:
        """Render an LLM call trace; prompt/response are previews unless
        include_raw_data is set."""
        result = {
            "call_id": call.call_id,
            "model_name": call.model_name,
            "tokens": {
                "prompt": call.prompt_tokens,
                "completion": call.completion_tokens,
                "total": call.total_tokens,
            },
            "latency_ms": round(call.latency_ms, 2),
            "temperature": call.temperature,
            "started_at": call.started_at.isoformat(),
            "completed_at": call.completed_at.isoformat() if call.completed_at else None,
        }
        if self.include_raw_data:
            result["prompt"] = call.prompt
            result["response"] = call.response
        else:
            # 200-char previews; response may be None (guarded below).
            result["prompt_preview"] = call.prompt[:200] + "..." if len(call.prompt) > 200 else call.prompt
            result["response_preview"] = call.response[:200] + "..." if call.response and len(call.response) > 200 else call.response
        if call.max_tokens:
            result["max_tokens"] = call.max_tokens
        if call.error:
            result["error"] = call.error
        return result

    def _format_retrieval(self, retrieval: RetrievalTrace) -> dict[str, Any]:
        """Render a retrieval trace; only the first 3 chunks are included,
        reduced to id/score pairs unless include_raw_data is set."""
        result = {
            "retrieval_id": retrieval.retrieval_id,
            "query": retrieval.query,
            "knowledge_bases": retrieval.knowledge_bases,
            "config": {
                "top_k": retrieval.top_k,
                "similarity_threshold": retrieval.similarity_threshold,
                "rerank_enabled": retrieval.rerank_enabled,
            },
            "results": {
                "chunks_retrieved": retrieval.chunks_retrieved,
                "chunks_preview": retrieval.chunks[:3] if self.include_raw_data else [
                    {"id": c.get("id"), "score": c.get("score")} for c in retrieval.chunks[:3]
                ],
            },
            "latency_ms": round(retrieval.latency_ms, 2),
            "started_at": retrieval.started_at.isoformat(),
            "completed_at": retrieval.completed_at.isoformat() if retrieval.completed_at else None,
        }
        if retrieval.error:
            result["error"] = retrieval.error
        return result

    def _format_tool_call(self, tool: ToolCallTrace) -> dict[str, Any]:
        """Render a tool call trace; the result is stringified and previewed
        unless include_raw_data is set."""
        result = {
            "call_id": tool.call_id,
            "tool_name": tool.tool_name,
            "tool_type": tool.tool_type,
            "arguments": tool.arguments,
            "latency_ms": round(tool.latency_ms, 2),
            "started_at": tool.started_at.isoformat(),
            "completed_at": tool.completed_at.isoformat() if tool.completed_at else None,
        }
        if self.include_raw_data:
            result["result"] = tool.result
        else:
            result_str = str(tool.result) if tool.result else None
            result["result_preview"] = result_str[:200] + "..." if result_str and len(result_str) > 200 else result_str
        if tool.error:
            result["error"] = tool.error
        return result
class TraceFormatterFactory:
    """Registry-backed factory for trace formatters."""

    # Built-in formatter types; extendable at runtime via register().
    _formatters = {
        "streaming": StreamingTraceFormatter,
        "compact": CompactTraceFormatter,
        "detailed": DetailedTraceFormatter,
    }

    @classmethod
    def create(cls, format_type: str = "streaming", **kwargs) -> TraceFormatter:
        """Instantiate the formatter registered under *format_type*.

        Extra keyword arguments are forwarded to the formatter constructor.
        Raises ValueError for an unregistered type.
        """
        formatter_class = cls._formatters.get(format_type)
        if formatter_class is None:
            raise ValueError(f"Unknown formatter type: {format_type}. Available: {list(cls._formatters.keys())}")
        return formatter_class(**kwargs)

    @classmethod
    def register(cls, name: str, formatter_class: type) -> None:
        """Register a custom formatter class under *name*.

        Raises TypeError unless the class derives from TraceFormatter.
        """
        if not issubclass(formatter_class, TraceFormatter):
            raise TypeError("Formatter must be a subclass of TraceFormatter")
        cls._formatters[name] = formatter_class

    @classmethod
    def available_formatters(cls) -> list[str]:
        """Return the currently registered formatter type names."""
        return list(cls._formatters.keys())
def format_trace_for_api(
    session: TraceSession,
    format_type: str = "streaming",
    **kwargs
) -> dict[str, Any]:
    """Format a trace session for an API response using the named formatter."""
    return TraceFormatterFactory.create(format_type, **kwargs).format_session(session)


def generate_trace_stream(
    events: Generator[TraceEvent, None, None],
    format_type: str = "streaming",
    **kwargs
) -> Generator[str, None, None]:
    """Yield SSE frames for each event, rendered by the named formatter."""
    formatter = TraceFormatterFactory.create(format_type, **kwargs)
    yield from map(formatter.format_for_stream, events)

375
agent/trace/trace_models.py Normal file
View file

@ -0,0 +1,375 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Trace Models for Agent Execution Logging
This module provides data models for capturing and representing trace information
during agent execution. It includes models for individual trace events, component
execution details, and complete trace sessions.
"""
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum
from typing import Any, Optional, Union
import json
import uuid
class TraceEventType(Enum):
    """Enumeration of trace event types emitted during agent execution."""

    # Workflow lifecycle
    WORKFLOW_STARTED = "workflow_started"
    WORKFLOW_COMPLETED = "workflow_completed"
    WORKFLOW_FAILED = "workflow_failed"
    # Individual node (component) lifecycle
    NODE_STARTED = "node_started"
    NODE_FINISHED = "node_finished"
    NODE_FAILED = "node_failed"
    # Knowledge-base retrieval operations
    RETRIEVAL_STARTED = "retrieval_started"
    RETRIEVAL_COMPLETED = "retrieval_completed"
    # LLM invocations
    LLM_CALL_STARTED = "llm_call_started"
    LLM_CALL_COMPLETED = "llm_call_completed"
    # Tool invocations
    TOOL_CALL_STARTED = "tool_call_started"
    TOOL_CALL_COMPLETED = "tool_call_completed"
    # Miscellaneous
    MESSAGE_GENERATED = "message_generated"
    ERROR_OCCURRED = "error_occurred"
    THINKING_STARTED = "thinking_started"
    THINKING_COMPLETED = "thinking_completed"
class TraceLevel(Enum):
    """Trace verbosity levels, from least to most verbose."""

    MINIMAL = "minimal"    # lifecycle events only
    STANDARD = "standard"  # default level
    DETAILED = "detailed"  # adds LLM-call and thinking events
    DEBUG = "debug"        # everything
@dataclass
class TraceMetadata:
    """Identity and configuration metadata attached to a trace session."""

    agent_id: str
    session_id: str
    user_id: str
    tenant_id: str
    trace_level: TraceLevel = TraceLevel.STANDARD
    created_at: datetime = field(default_factory=datetime.utcnow)
    tags: list[str] = field(default_factory=list)
    custom_data: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-friendly dict (enum -> value, datetime -> ISO)."""
        out: dict[str, Any] = {
            name: getattr(self, name)
            for name in ("agent_id", "session_id", "user_id", "tenant_id")
        }
        out["trace_level"] = self.trace_level.value
        out["created_at"] = self.created_at.isoformat()
        out["tags"] = self.tags
        out["custom_data"] = self.custom_data
        return out

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TraceMetadata":
        """Build TraceMetadata from a dict, tolerating missing keys."""
        if "created_at" in data:
            created = datetime.fromisoformat(data["created_at"])
        else:
            created = datetime.utcnow()
        return cls(
            agent_id=data.get("agent_id", ""),
            session_id=data.get("session_id", ""),
            user_id=data.get("user_id", ""),
            tenant_id=data.get("tenant_id", ""),
            trace_level=TraceLevel(data.get("trace_level", "standard")),
            created_at=created,
            tags=data.get("tags", []),
            custom_data=data.get("custom_data", {}),
        )
@dataclass
class ComponentInfo:
    """Static description of a single node in the agent workflow graph."""
    component_id: str
    component_name: str
    component_type: str
    params: dict[str, Any] = field(default_factory=dict)
    upstream: list[str] = field(default_factory=list)
    downstream: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Serialize this component description to a plain dictionary."""
        return {
            name: getattr(self, name)
            for name in (
                "component_id",
                "component_name",
                "component_type",
                "params",
                "upstream",
                "downstream",
            )
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ComponentInfo":
        """Rebuild a ComponentInfo from its dictionary form.

        Missing keys fall back to empty defaults rather than raising.
        """
        return cls(
            component_id=data.get("component_id", ""),
            component_name=data.get("component_name", ""),
            component_type=data.get("component_type", ""),
            params=data.get("params", {}),
            upstream=data.get("upstream", []),
            downstream=data.get("downstream", []),
        )
@dataclass
class TraceEvent:
    """A single timestamped event recorded during agent execution."""
    event_id: str
    event_type: TraceEventType
    timestamp: datetime
    component_id: Optional[str] = None
    component_name: Optional[str] = None
    component_type: Optional[str] = None
    inputs: Optional[dict[str, Any]] = None
    outputs: Optional[dict[str, Any]] = None
    error: Optional[str] = None
    elapsed_time: Optional[float] = None
    thoughts: Optional[str] = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Assign a random UUID when the caller did not supply an event_id."""
        if not self.event_id:
            self.event_id = str(uuid.uuid4())

    def to_dict(self) -> dict[str, Any]:
        """Serialize the event, omitting optional fields that are unset."""
        payload: dict[str, Any] = {
            "event_id": self.event_id,
            "event_type": self.event_type.value,
            "timestamp": self.timestamp.isoformat(),
        }
        # Component identifiers are included only when non-empty.
        for key in ("component_id", "component_name", "component_type"):
            value = getattr(self, key)
            if value:
                payload[key] = value
        # inputs/outputs/elapsed_time are included whenever explicitly set,
        # even when falsy ({} or 0.0); the remaining fields only when truthy.
        if self.inputs is not None:
            payload["inputs"] = self.inputs
        if self.outputs is not None:
            payload["outputs"] = self.outputs
        if self.error:
            payload["error"] = self.error
        if self.elapsed_time is not None:
            payload["elapsed_time"] = self.elapsed_time
        if self.thoughts:
            payload["thoughts"] = self.thoughts
        if self.metadata:
            payload["metadata"] = self.metadata
        return payload

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TraceEvent":
        """Rebuild a TraceEvent from its dictionary form."""
        if "timestamp" in data:
            ts = datetime.fromisoformat(data["timestamp"])
        else:
            ts = datetime.utcnow()
        return cls(
            event_id=data.get("event_id", ""),
            event_type=TraceEventType(data.get("event_type", "node_started")),
            timestamp=ts,
            component_id=data.get("component_id"),
            component_name=data.get("component_name"),
            component_type=data.get("component_type"),
            inputs=data.get("inputs"),
            outputs=data.get("outputs"),
            error=data.get("error"),
            elapsed_time=data.get("elapsed_time"),
            thoughts=data.get("thoughts"),
            metadata=data.get("metadata", {}),
        )

    def to_json(self) -> str:
        """Serialize the event to a compact JSON string."""
        return json.dumps(self.to_dict(), ensure_ascii=False)
@dataclass
class LLMCallTrace:
    """Record of a single LLM API invocation, with token and latency stats."""
    call_id: str
    model_name: str
    prompt: str
    response: Optional[str] = None
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    latency_ms: float = 0.0
    temperature: float = 0.0
    max_tokens: Optional[int] = None
    error: Optional[str] = None
    started_at: datetime = field(default_factory=datetime.utcnow)
    completed_at: Optional[datetime] = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize the call record, truncating long prompt/response text."""
        def _clip(text: str) -> str:
            # Keep serialized traces small: cap text at 500 chars + ellipsis.
            return text[:500] + "..." if len(text) > 500 else text

        return {
            "call_id": self.call_id,
            "model_name": self.model_name,
            "prompt": _clip(self.prompt),
            "response": _clip(self.response) if self.response else self.response,
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.completion_tokens,
            "total_tokens": self.total_tokens,
            "latency_ms": self.latency_ms,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "error": self.error,
            "started_at": self.started_at.isoformat(),
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
        }
@dataclass
class RetrievalTrace:
    """Record of one knowledge-base retrieval operation."""
    retrieval_id: str
    query: str
    knowledge_bases: list[str] = field(default_factory=list)
    top_k: int = 10
    similarity_threshold: float = 0.0
    chunks_retrieved: int = 0
    chunks: list[dict[str, Any]] = field(default_factory=list)
    latency_ms: float = 0.0
    rerank_enabled: bool = False
    error: Optional[str] = None
    started_at: datetime = field(default_factory=datetime.utcnow)
    completed_at: Optional[datetime] = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize the retrieval record; only the first 5 chunks are kept."""
        finished = self.completed_at.isoformat() if self.completed_at else None
        return {
            "retrieval_id": self.retrieval_id,
            "query": self.query,
            "knowledge_bases": self.knowledge_bases,
            "top_k": self.top_k,
            "similarity_threshold": self.similarity_threshold,
            "chunks_retrieved": self.chunks_retrieved,
            # Cap the embedded chunk payload to keep serialized traces small;
            # chunks_retrieved still reports the true count.
            "chunks": self.chunks[:5],
            "latency_ms": self.latency_ms,
            "rerank_enabled": self.rerank_enabled,
            "error": self.error,
            "started_at": self.started_at.isoformat(),
            "completed_at": finished,
        }
@dataclass
class ToolCallTrace:
    """Record of one tool/function invocation made by the agent."""
    call_id: str
    tool_name: str
    tool_type: str
    arguments: dict[str, Any] = field(default_factory=dict)
    result: Optional[Any] = None
    latency_ms: float = 0.0
    error: Optional[str] = None
    started_at: datetime = field(default_factory=datetime.utcnow)
    completed_at: Optional[datetime] = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize the tool call; the result is stringified and truncated."""
        # A falsy result (None, 0, "", False, []) is reported as None rather
        # than stringified; anything else is rendered via str() and capped.
        if self.result:
            rendered: Optional[str] = str(self.result)[:500]
        else:
            rendered = None
        return {
            "call_id": self.call_id,
            "tool_name": self.tool_name,
            "tool_type": self.tool_type,
            "arguments": self.arguments,
            "result": rendered,
            "latency_ms": self.latency_ms,
            "error": self.error,
            "started_at": self.started_at.isoformat(),
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
        }
@dataclass
class TraceSession:
    """Aggregated trace of one complete agent execution.

    Collects the event stream plus per-category call records (LLM,
    retrieval, tool) and tracks overall timing and terminal status.
    """
    session_id: str
    metadata: TraceMetadata
    events: list[TraceEvent] = field(default_factory=list)
    llm_calls: list[LLMCallTrace] = field(default_factory=list)
    retrievals: list[RetrievalTrace] = field(default_factory=list)
    tool_calls: list[ToolCallTrace] = field(default_factory=list)
    started_at: datetime = field(default_factory=datetime.utcnow)
    completed_at: Optional[datetime] = None
    total_elapsed_time: float = 0.0
    status: str = "running"
    error: Optional[str] = None

    def add_event(self, event: TraceEvent) -> None:
        """Record a trace event."""
        self.events.append(event)

    def add_llm_call(self, llm_call: LLMCallTrace) -> None:
        """Record an LLM call trace."""
        self.llm_calls.append(llm_call)

    def add_retrieval(self, retrieval: RetrievalTrace) -> None:
        """Record a retrieval trace."""
        self.retrievals.append(retrieval)

    def add_tool_call(self, tool_call: ToolCallTrace) -> None:
        """Record a tool call trace."""
        self.tool_calls.append(tool_call)

    def complete(self, error: Optional[str] = None) -> None:
        """Finalize the session: stamp completion time, compute total
        elapsed seconds, and set terminal status ("failed" when an error
        message is given, "completed" otherwise)."""
        finished = datetime.utcnow()
        self.completed_at = finished
        self.total_elapsed_time = (finished - self.started_at).total_seconds()
        self.status = "failed" if error else "completed"
        self.error = error

    def to_dict(self) -> dict[str, Any]:
        """Serialize the whole session, including a computed summary."""
        return {
            "session_id": self.session_id,
            "metadata": self.metadata.to_dict(),
            "events": [event.to_dict() for event in self.events],
            "llm_calls": [call.to_dict() for call in self.llm_calls],
            "retrievals": [retrieval.to_dict() for retrieval in self.retrievals],
            "tool_calls": [call.to_dict() for call in self.tool_calls],
            "started_at": self.started_at.isoformat(),
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
            "total_elapsed_time": self.total_elapsed_time,
            "status": self.status,
            "error": self.error,
            "summary": self.get_summary(),
        }

    def get_summary(self) -> dict[str, Any]:
        """Return aggregate counts and totals for the session."""
        distinct_nodes = {e.component_id for e in self.events if e.component_id}
        failed_events = [e for e in self.events if e.error]
        return {
            "total_events": len(self.events),
            "total_llm_calls": len(self.llm_calls),
            "total_retrievals": len(self.retrievals),
            "total_tool_calls": len(self.tool_calls),
            "total_tokens": sum(call.total_tokens for call in self.llm_calls),
            "total_chunks_retrieved": sum(r.chunks_retrieved for r in self.retrievals),
            "nodes_executed": len(distinct_nodes),
            "errors_count": len(failed_events),
        }

    def to_json(self) -> str:
        """Serialize the session to a pretty-printed JSON string."""
        return json.dumps(self.to_dict(), ensure_ascii=False, indent=2)

477
api/apps/trace_app.py Normal file
View file

@ -0,0 +1,477 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Agent Trace API Endpoints
Provides REST API endpoints for accessing agent execution traces,
including trace retrieval, filtering, analysis, and management.
This addresses Issue #10081: Add Trace Logging for Agent Completions API.
"""
from datetime import datetime
from quart import request, Response
import json
from api.apps import login_required, current_user
from api.db.services.trace_service import TraceService
from api.utils.api_utils import (
get_data_error_result,
get_json_result,
get_request_json,
server_error_response,
validate_request,
)
from common.constants import RetCode
@manager.route('/traces', methods=['GET'])  # noqa: F821
@login_required
async def list_traces():
    """
    List trace sessions for the current tenant.

    Query parameters:
        - agent_id: Filter by agent ID (optional)
        - user_id: Filter by user ID (optional)
        - status: Filter by status (running, completed, failed) (optional)
        - start_time: Filter by start time ISO format (optional)
        - end_time: Filter by end time ISO format (optional)
        - page: Page number (default: 1)
        - page_size: Items per page (default: 20)

    Returns:
        Paginated list of trace sessions
    """
    try:
        def _parse_dt(name):
            # Absent/empty time filters become None (no filtering).
            raw = request.args.get(name)
            return datetime.fromisoformat(raw) if raw else None

        page = int(request.args.get("page", 1))
        page_size = int(request.args.get("page_size", 20))
        start_time = _parse_dt("start_time")
        end_time = _parse_dt("end_time")
        result = TraceService.list_traces(
            tenant_id=current_user.id,
            agent_id=request.args.get("agent_id"),
            user_id=request.args.get("user_id"),
            status=request.args.get("status"),
            start_time=start_time,
            end_time=end_time,
            page=page,
            page_size=page_size,
        )
        return get_json_result(data=result)
    except Exception as e:
        return server_error_response(e)
@manager.route('/traces/<task_id>', methods=['GET'])  # noqa: F821
@login_required
async def get_trace(task_id):
    """
    Fetch a single trace session by its task ID.

    Path parameters:
        - task_id: The task/trace session ID

    Query parameters:
        - format: Output format (streaming, compact, detailed) (default: streaming)

    Returns:
        Trace session data
    """
    try:
        fmt = request.args.get("format", "streaming")
        formatted = TraceService.format_trace(task_id, fmt)
        if formatted:
            return get_json_result(data=formatted)
        return get_data_error_result(
            message="Trace session not found",
            code=RetCode.DATA_ERROR
        )
    except Exception as e:
        return server_error_response(e)
@manager.route('/traces/<task_id>/events', methods=['GET'])  # noqa: F821
@login_required
async def get_trace_events(task_id):
    """
    Fetch trace events for one session, with optional filtering.

    Path parameters:
        - task_id: The task/trace session ID

    Query parameters:
        - event_types: Comma-separated list of event types to filter (optional)
        - component_id: Filter by component ID (optional)
        - limit: Maximum number of events (default: 100)
        - offset: Number of events to skip (default: 0)

    Returns:
        List of trace events
    """
    try:
        raw_types = request.args.get("event_types")
        # None means "no type filter"; otherwise split the CSV list.
        wanted_types = raw_types.split(",") if raw_types else None
        events = TraceService.get_trace_events(
            task_id=task_id,
            event_types=wanted_types,
            component_id=request.args.get("component_id"),
            limit=int(request.args.get("limit", 100)),
            offset=int(request.args.get("offset", 0)),
        )
        return get_json_result(data={"events": events, "count": len(events)})
    except Exception as e:
        return server_error_response(e)
@manager.route('/traces/<task_id>/summary', methods=['GET'])  # noqa: F821
@login_required
async def get_trace_summary(task_id):
    """
    Fetch the aggregate summary of a trace session.

    Path parameters:
        - task_id: The task/trace session ID

    Returns:
        Trace session summary
    """
    try:
        summary = TraceService.get_trace_summary(task_id)
        if summary:
            return get_json_result(data=summary)
        return get_data_error_result(
            message="Trace session not found",
            code=RetCode.DATA_ERROR
        )
    except Exception as e:
        return server_error_response(e)
@manager.route('/traces/<task_id>/analysis', methods=['GET'])  # noqa: F821
@login_required
async def analyze_trace(task_id):
    """
    Run analysis over a trace session and return insights.

    Path parameters:
        - task_id: The task/trace session ID

    Returns:
        Analysis results including bottlenecks, errors, and recommendations
    """
    try:
        report = TraceService.analyze_trace(task_id)
        if report:
            return get_json_result(data=report)
        return get_data_error_result(
            message="Trace session not found or analysis failed",
            code=RetCode.DATA_ERROR
        )
    except Exception as e:
        return server_error_response(e)
@manager.route('/traces/<task_id>', methods=['DELETE'])  # noqa: F821
@login_required
async def delete_trace(task_id):
    """
    Delete a trace session and its stored events.

    Path parameters:
        - task_id: The task/trace session ID

    Returns:
        Success status
    """
    try:
        ok, message = TraceService.delete_trace(task_id)
        if not ok:
            return get_data_error_result(message=message)
        return get_json_result(data={"task_id": task_id, "message": message})
    except Exception as e:
        return server_error_response(e)
@manager.route('/traces/cleanup', methods=['POST'])  # noqa: F821
@login_required
async def cleanup_traces():
    """
    Delete trace sessions older than the requested retention window.

    Request body:
        {"days": 7}  -- number of days of traces to keep (default: 7)

    Returns:
        Number of deleted traces
    """
    try:
        payload = await get_request_json()
        retention_days = payload.get("days", 7)
        deleted, message = TraceService.cleanup_old_traces(retention_days)
        return get_json_result(data={"deleted": deleted, "message": message})
    except Exception as e:
        return server_error_response(e)
@manager.route('/traces/<task_id>/stream', methods=['GET'])  # noqa: F821
@login_required
async def stream_trace(task_id):
    """
    Stream trace events in real-time using Server-Sent Events.

    Replays all events collected so far, then relays new events as the
    collector publishes them, finishing with a ``data:[DONE]`` sentinel.

    Path parameters:
        - task_id: The task/trace session ID
    Query parameters:
        - format: Output format (streaming, compact, detailed) (default: streaming)
    Returns:
        SSE stream of trace events
    """
    try:
        format_type = request.args.get("format", "streaming")
        # Imported lazily to avoid a circular import at module load time —
        # presumably; confirm against the app's import graph.
        from agent.trace.trace_collector import get_trace_collector
        from agent.trace.trace_formatter import TraceFormatterFactory
        collector = get_trace_collector(task_id)
        if not collector:
            # Streaming only works against a live, in-process collector;
            # persisted (finished) traces are served by GET /traces/<task_id>.
            return get_data_error_result(
                message="Active trace session not found",
                code=RetCode.DATA_ERROR
            )
        formatter = TraceFormatterFactory.create(format_type)
        async def generate():
            import asyncio
            # Phase 1: replay everything collected so far.
            for event in collector.get_events():
                yield formatter.format_for_stream(event)
            # Phase 2: subscribe and relay new events.
            # NOTE(review): events published between get_events() and
            # subscribe() below are neither replayed nor queued — confirm the
            # collector tolerates this gap or closes it internally.
            event_queue = []
            def on_event(event):
                # Subscriber callback: buffer events for the polling loop.
                event_queue.append(event)
            collector.subscribe(on_event)
            try:
                # NOTE(review): polls every 100 ms and reads the collector's
                # private _is_active flag — a public accessor or an
                # asyncio.Queue-based handoff would be cleaner.
                while collector._is_active:
                    while event_queue:
                        event = event_queue.pop(0)
                        yield formatter.format_for_stream(event)
                    await asyncio.sleep(0.1)
            finally:
                # Always detach the subscriber, even if the client disconnects.
                collector.unsubscribe(on_event)
            # SSE termination sentinel expected by clients.
            yield "data:[DONE]\n\n"
        resp = Response(generate(), mimetype="text/event-stream")
        # Standard SSE headers: disable caching and proxy buffering.
        resp.headers.add_header("Cache-control", "no-cache")
        resp.headers.add_header("Connection", "keep-alive")
        resp.headers.add_header("X-Accel-Buffering", "no")
        resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
        return resp
    except Exception as e:
        return server_error_response(e)
@manager.route('/agents/<agent_id>/traces', methods=['GET'])  # noqa: F821
@login_required
async def list_agent_traces(agent_id):
    """
    List trace sessions belonging to one agent.

    Path parameters:
        - agent_id: The agent ID

    Query parameters:
        - status: Filter by status (optional)
        - page: Page number (default: 1)
        - page_size: Items per page (default: 20)

    Returns:
        Paginated list of trace sessions for the agent
    """
    try:
        result = TraceService.list_traces(
            tenant_id=current_user.id,
            agent_id=agent_id,
            status=request.args.get("status"),
            page=int(request.args.get("page", 1)),
            page_size=int(request.args.get("page_size", 20)),
        )
        return get_json_result(data=result)
    except Exception as e:
        return server_error_response(e)
@manager.route('/agents/<agent_id>/completions/trace', methods=['POST'])  # noqa: F821
@login_required
@validate_request("question")
async def agent_completion_with_trace(agent_id):
    """
    Execute agent completion with trace logging enabled.

    This endpoint is similar to /agents/<agent_id>/completions but includes
    trace information in the response, addressing Issue #10081.

    Path parameters:
        - agent_id: The agent ID
    Request body:
        {
            "question": "User question",
            "session_id": "Optional session ID",
            "stream": true,
            "trace_level": "standard",  (minimal, standard, detailed, debug)
            "include_trace": true
        }
    Returns:
        Agent response with trace information (streamed as SSE when
        "stream" is true, otherwise a single JSON result).
    """
    try:
        # Imported lazily — presumably to avoid circular imports; confirm.
        from api.db.services.canvas_service import completion as agent_completion
        req = await get_request_json()
        stream = req.get("stream", True)
        trace_level = req.get("trace_level", "standard")
        include_trace = req.get("include_trace", True)
        from common.misc_utils import get_uuid
        task_id = get_uuid()
        if include_trace:
            # NOTE(review): success/trace_id are never checked — a failed
            # trace-session creation is silently ignored and the later
            # save/format calls will just return nothing.
            success, trace_id = TraceService.create_trace_session(
                task_id=task_id,
                agent_id=agent_id,
                session_id=req.get("session_id", ""),
                user_id=current_user.id,
                # Tenant and user are both the current user's id here.
                tenant_id=current_user.id,
                trace_level=trace_level,
            )
        if stream:
            async def generate():
                full_content = ""
                reference = {}
                # NOTE(review): **req forwards trace-specific keys
                # (trace_level, include_trace, stream) to agent_completion —
                # confirm it tolerates the extra kwargs.
                async for answer in agent_completion(
                    tenant_id=current_user.id,
                    agent_id=agent_id,
                    **req
                ):
                    try:
                        # answer[5:] assumes every chunk starts with the SSE
                        # "data:" prefix; non-JSON chunks are dropped, not
                        # forwarded, because yield sits inside this try.
                        ans = json.loads(answer[5:])
                        if ans["event"] == "message":
                            full_content += ans["data"]["content"]
                            if ans.get("data", {}).get("reference"):
                                reference.update(ans["data"]["reference"])
                        yield answer
                    except Exception:
                        continue
                if include_trace:
                    # Persist the finished trace, then append it as a final
                    # "trace" SSE event before the DONE sentinel.
                    TraceService.save_trace_session(task_id)
                    trace_data = TraceService.format_trace(task_id, "compact")
                    yield f"data:{json.dumps({'event': 'trace', 'data': trace_data}, ensure_ascii=False)}\n\n"
                yield "data:[DONE]\n\n"
            resp = Response(generate(), mimetype="text/event-stream")
            # Standard SSE headers: disable caching and proxy buffering.
            resp.headers.add_header("Cache-control", "no-cache")
            resp.headers.add_header("Connection", "keep-alive")
            resp.headers.add_header("X-Accel-Buffering", "no")
            resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
            return resp
        # Non-streaming path: drain the async generator, accumulating the
        # full message content and merged references.
        full_content = ""
        reference = {}
        final_ans = None
        async for answer in agent_completion(
            tenant_id=current_user.id,
            agent_id=agent_id,
            **req
        ):
            try:
                ans = json.loads(answer[5:])
                if ans["event"] == "message":
                    full_content += ans["data"]["content"]
                    if ans.get("data", {}).get("reference"):
                        reference.update(ans["data"]["reference"])
                # Keep the last successfully parsed chunk as the response shell.
                final_ans = ans
            except Exception:
                continue
        if final_ans:
            final_ans["data"]["content"] = full_content
            final_ans["data"]["reference"] = reference
        if include_trace:
            TraceService.save_trace_session(task_id)
            trace_data = TraceService.format_trace(task_id, "compact")
            if final_ans:
                final_ans["trace"] = trace_data
        return get_json_result(data=final_ans)
    except Exception as e:
        return server_error_response(e)

View file

@ -0,0 +1,482 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Trace Service for Agent Execution Logging
This module provides the TraceService class that manages trace data persistence,
retrieval, and analysis. It integrates with the trace collector and formatter
modules to provide a complete tracing solution.
"""
import json
import logging
import time
from datetime import datetime, timedelta
from typing import Any, Optional, Tuple
from collections import defaultdict
from agent.trace.trace_models import (
TraceEventType,
TraceLevel,
TraceSession,
TraceEvent,
TraceMetadata,
)
from agent.trace.trace_collector import (
TraceCollector,
get_trace_collector,
create_trace_collector,
)
from agent.trace.trace_formatter import (
TraceFormatterFactory,
format_trace_for_api,
)
from rag.utils.redis_conn import REDIS_CONN
# Namespace prefix for all trace data stored in Redis.
TRACE_KEY_PREFIX = "agent_trace:"
# Session documents are kept for 7 days.
TRACE_SESSION_TTL = 86400 * 7
# Raw event lists are bulkier, so they expire sooner: 3 days.
TRACE_EVENT_TTL = 86400 * 3
class TraceService:
    """
    Service for managing agent execution traces.

    Provides methods for creating, storing, retrieving, and analyzing
    trace data from agent executions.

    Storage model: a live execution is tracked by an in-process
    TraceCollector; the session document and event list are persisted as
    JSON strings in Redis under ``agent_trace:session:<task_id>`` and
    ``agent_trace:events:<task_id>`` with fixed TTLs (module constants).
    """
    @staticmethod
    def create_trace_session(
        task_id: str,
        agent_id: str,
        session_id: str,
        user_id: str,
        tenant_id: str,
        trace_level: str = "standard",
    ) -> Tuple[bool, str]:
        """
        Create a new trace session for an agent execution.

        Args:
            task_id: Unique identifier for the task
            agent_id: ID of the agent being executed
            session_id: ID of the conversation session
            user_id: ID of the user
            tenant_id: ID of the tenant
            trace_level: Verbosity level (minimal, standard, detailed, debug)

        Returns:
            Tuple of (success, trace_id or error message)
        """
        # An unrecognized trace_level string silently degrades to STANDARD
        # instead of failing the request.
        try:
            level = TraceLevel(trace_level)
        except ValueError:
            level = TraceLevel.STANDARD
        try:
            # Registers an in-process collector for this task. The handle is
            # unused here — presumably create_trace_collector stores it in a
            # module-level registry keyed by task_id (confirm in
            # trace_collector).
            collector = create_trace_collector(
                task_id=task_id,
                agent_id=agent_id,
                session_id=session_id,
                user_id=user_id,
                tenant_id=tenant_id,
                trace_level=level,
            )
            # Write a lightweight "running" marker immediately so the session
            # appears in listings before execution finishes.
            session_data = {
                "task_id": task_id,
                "agent_id": agent_id,
                "session_id": session_id,
                "user_id": user_id,
                "tenant_id": tenant_id,
                "trace_level": trace_level,
                "created_at": datetime.utcnow().isoformat(),
                "status": "running",
            }
            key = f"{TRACE_KEY_PREFIX}session:{task_id}"
            REDIS_CONN.setex(key, TRACE_SESSION_TTL, json.dumps(session_data))
            return True, task_id
        except Exception as e:
            logging.exception(f"Failed to create trace session: {e}")
            return False, str(e)
    @staticmethod
    def get_trace_collector(task_id: str) -> Optional[TraceCollector]:
        """Get an active (in-process) trace collector by task ID."""
        return get_trace_collector(task_id)
    @staticmethod
    def save_trace_session(task_id: str) -> Tuple[bool, str]:
        """
        Save the current trace session to persistent storage.

        Overwrites the "running" marker written by create_trace_session with
        the full session document, and stores the event list separately
        under its own (shorter) TTL.

        Args:
            task_id: ID of the task/trace session

        Returns:
            Tuple of (success, message)
        """
        try:
            collector = get_trace_collector(task_id)
            if not collector:
                return False, "Trace collector not found"
            session = collector.get_session()
            session_dict = session.to_dict()
            key = f"{TRACE_KEY_PREFIX}session:{task_id}"
            REDIS_CONN.setex(key, TRACE_SESSION_TTL, json.dumps(session_dict))
            events_key = f"{TRACE_KEY_PREFIX}events:{task_id}"
            events_data = json.dumps([e.to_dict() for e in session.events])
            REDIS_CONN.setex(events_key, TRACE_EVENT_TTL, events_data)
            return True, "Trace session saved successfully"
        except Exception as e:
            logging.exception(f"Failed to save trace session: {e}")
            return False, str(e)
    @staticmethod
    def get_trace_session(task_id: str) -> Optional[dict[str, Any]]:
        """
        Retrieve a trace session by task ID.

        Prefers the live in-process collector; falls back to the persisted
        Redis document for finished sessions.

        Args:
            task_id: ID of the task/trace session

        Returns:
            Trace session data or None if not found
        """
        try:
            collector = get_trace_collector(task_id)
            if collector:
                return collector.get_session().to_dict()
            key = f"{TRACE_KEY_PREFIX}session:{task_id}"
            data = REDIS_CONN.get(key)
            if data:
                return json.loads(data)
            return None
        except Exception as e:
            logging.exception(f"Failed to get trace session: {e}")
            return None
    @staticmethod
    def get_trace_events(
        task_id: str,
        event_types: Optional[list[str]] = None,
        component_id: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """
        Retrieve trace events with optional filtering.

        Args:
            task_id: ID of the task/trace session
            event_types: Filter by event types (string values)
            component_id: Filter by component ID
            limit: Maximum number of events to return
            offset: Number of events to skip

        Returns:
            List of trace events (empty on any failure)
        """
        try:
            collector = get_trace_collector(task_id)
            if collector:
                events = collector.get_events()
            else:
                events_key = f"{TRACE_KEY_PREFIX}events:{task_id}"
                data = REDIS_CONN.get(events_key)
                if not data:
                    return []
                events = [TraceEvent.from_dict(e) for e in json.loads(data)]
            if event_types:
                # Set lookup so each event is checked in O(1).
                type_set = set(event_types)
                events = [e for e in events if e.event_type.value in type_set]
            if component_id:
                events = [e for e in events if e.component_id == component_id]
            # Pagination is applied AFTER filtering, so offset/limit index
            # into the filtered list, not the raw event stream.
            events = events[offset:offset + limit]
            return [e.to_dict() for e in events]
        except Exception as e:
            logging.exception(f"Failed to get trace events: {e}")
            return []
    @staticmethod
    def get_trace_summary(task_id: str) -> Optional[dict[str, Any]]:
        """
        Get a summary of the trace session.

        Uses the live collector when available; otherwise reads the
        "summary" field embedded in the persisted session document.

        Args:
            task_id: ID of the task/trace session

        Returns:
            Summary data or None if not found
        """
        try:
            collector = get_trace_collector(task_id)
            if collector:
                return collector.get_summary()
            session_data = TraceService.get_trace_session(task_id)
            if session_data and "summary" in session_data:
                return session_data["summary"]
            return None
        except Exception as e:
            logging.exception(f"Failed to get trace summary: {e}")
            return None
    @staticmethod
    def format_trace(
        task_id: str,
        format_type: str = "streaming",
        **kwargs
    ) -> Optional[dict[str, Any]]:
        """
        Format a trace session using the specified formatter.

        Args:
            task_id: ID of the task/trace session
            format_type: Type of formatter (streaming, compact, detailed)
            **kwargs: Additional formatter options

        Returns:
            Formatted trace data or None if not found

        NOTE(review): for persisted sessions (no live collector) the raw
        stored dict is returned and format_type/kwargs are ignored —
        confirm whether finished traces should also be run through the
        formatter.
        """
        try:
            session_data = TraceService.get_trace_session(task_id)
            if not session_data:
                return None
            collector = get_trace_collector(task_id)
            if collector:
                session = collector.get_session()
                return format_trace_for_api(session, format_type, **kwargs)
            return session_data
        except Exception as e:
            logging.exception(f"Failed to format trace: {e}")
            return None
    @staticmethod
    def list_traces(
        tenant_id: str,
        agent_id: Optional[str] = None,
        user_id: Optional[str] = None,
        status: Optional[str] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        page: int = 1,
        page_size: int = 20,
    ) -> dict[str, Any]:
        """
        List trace sessions with filtering and pagination.

        Args:
            tenant_id: ID of the tenant
            agent_id: Filter by agent ID
            user_id: Filter by user ID
            status: Filter by status (running, completed, failed)
            start_time: Filter by start time
            end_time: Filter by end time
            page: Page number
            page_size: Items per page

        Returns:
            Paginated list of trace sessions (empty result on failure)
        """
        try:
            # NOTE(review): KEYS is O(N) and blocks Redis; prefer SCAN (or a
            # per-tenant index) once trace volume grows.
            pattern = f"{TRACE_KEY_PREFIX}session:*"
            keys = REDIS_CONN.keys(pattern)
            sessions = []
            for key in keys:
                data = REDIS_CONN.get(key)
                if data:
                    session = json.loads(data)
                    # Tenant isolation: only the caller's sessions are listed.
                    if session.get("tenant_id") == tenant_id:
                        sessions.append(session)
            if agent_id:
                sessions = [s for s in sessions if s.get("agent_id") == agent_id]
            if user_id:
                sessions = [s for s in sessions if s.get("user_id") == user_id]
            if status:
                sessions = [s for s in sessions if s.get("status") == status]
            # NOTE(review): a session missing created_at makes
            # fromisoformat("") raise ValueError, aborting the whole listing
            # into the empty-result fallback below — consider skipping such
            # sessions instead.
            if start_time:
                sessions = [s for s in sessions
                            if datetime.fromisoformat(s.get("created_at", "")) >= start_time]
            if end_time:
                sessions = [s for s in sessions
                            if datetime.fromisoformat(s.get("created_at", "")) <= end_time]
            # Newest first; ISO-8601 strings sort chronologically.
            sessions.sort(key=lambda x: x.get("created_at", ""), reverse=True)
            total = len(sessions)
            start_idx = (page - 1) * page_size
            end_idx = start_idx + page_size
            paginated = sessions[start_idx:end_idx]
            return {
                "total": total,
                "page": page,
                "page_size": page_size,
                "sessions": paginated,
            }
        except Exception as e:
            logging.exception(f"Failed to list traces: {e}")
            return {"total": 0, "page": page, "page_size": page_size, "sessions": []}
    @staticmethod
    def delete_trace(task_id: str) -> Tuple[bool, str]:
        """
        Delete a trace session and its stored event list.

        Args:
            task_id: ID of the task/trace session

        Returns:
            Tuple of (success, message)
        """
        try:
            session_key = f"{TRACE_KEY_PREFIX}session:{task_id}"
            events_key = f"{TRACE_KEY_PREFIX}events:{task_id}"
            REDIS_CONN.delete(session_key)
            REDIS_CONN.delete(events_key)
            return True, "Trace deleted successfully"
        except Exception as e:
            logging.exception(f"Failed to delete trace: {e}")
            return False, str(e)
    @staticmethod
    def analyze_trace(task_id: str) -> Optional[dict[str, Any]]:
        """
        Analyze a trace session and provide insights.

        Produces event-type distribution, per-component timing, slow-component
        bottlenecks, collected errors, and coarse recommendations.

        Args:
            task_id: ID of the task/trace session

        Returns:
            Analysis results or None if not found
        """
        try:
            session_data = TraceService.get_trace_session(task_id)
            if not session_data:
                return None
            # Analysis is capped at the first 1000 events.
            events = TraceService.get_trace_events(task_id, limit=1000)
            analysis = {
                "task_id": task_id,
                "total_events": len(events),
                "event_distribution": defaultdict(int),
                "component_execution_times": {},
                "bottlenecks": [],
                "errors": [],
                "recommendations": [],
            }
            for event in events:
                event_type = event.get("event_type", "unknown")
                analysis["event_distribution"][event_type] += 1
                if event.get("error"):
                    analysis["errors"].append({
                        "component": event.get("component_name"),
                        "error": event.get("error"),
                        "timestamp": event.get("timestamp"),
                    })
                # Truthiness check: an elapsed_time of exactly 0.0 is skipped.
                if event.get("elapsed_time"):
                    component = event.get("component_name", "unknown")
                    if component not in analysis["component_execution_times"]:
                        analysis["component_execution_times"][component] = []
                    analysis["component_execution_times"][component].append(
                        event.get("elapsed_time")
                    )
            for component, times in analysis["component_execution_times"].items():
                avg_time = sum(times) / len(times)
                # Components averaging over 5 seconds are flagged as bottlenecks.
                if avg_time > 5.0:
                    analysis["bottlenecks"].append({
                        "component": component,
                        "avg_execution_time": round(avg_time, 2),
                        "executions": len(times),
                    })
            if analysis["bottlenecks"]:
                analysis["recommendations"].append(
                    "Consider optimizing slow components or adding caching"
                )
            if len(analysis["errors"]) > 0:
                analysis["recommendations"].append(
                    "Review and fix error-prone components"
                )
            # Convert back to a plain dict so the result is JSON-serializable.
            analysis["event_distribution"] = dict(analysis["event_distribution"])
            return analysis
        except Exception as e:
            logging.exception(f"Failed to analyze trace: {e}")
            return None
    @staticmethod
    def cleanup_old_traces(days: int = 7) -> Tuple[int, str]:
        """
        Clean up trace sessions older than the specified number of days.

        Args:
            days: Number of days to keep traces

        Returns:
            Tuple of (deleted count, message)
        """
        try:
            cutoff = datetime.utcnow() - timedelta(days=days)
            # NOTE(review): KEYS blocks Redis; prefer SCAN for large keyspaces.
            pattern = f"{TRACE_KEY_PREFIX}session:*"
            keys = REDIS_CONN.keys(pattern)
            deleted = 0
            for key in keys:
                data = REDIS_CONN.get(key)
                if data:
                    session = json.loads(data)
                    created_at = session.get("created_at")
                    if created_at:
                        session_time = datetime.fromisoformat(created_at)
                        if session_time < cutoff:
                            # NOTE(review): assumes keys are str — if the Redis
                            # client returns bytes (decode_responses off),
                            # split(":") would fail; confirm REDIS_CONN config.
                            task_id = key.split(":")[-1]
                            TraceService.delete_trace(task_id)
                            deleted += 1
            return deleted, f"Deleted {deleted} old trace sessions"
        except Exception as e:
            logging.exception(f"Failed to cleanup traces: {e}")
            return 0, str(e)