Implement comprehensive trace logging system for agent execution that returns step-by-step execution traces in API responses.

New modules:
- agent/trace/trace_models.py: Data models for trace events, sessions, LLM calls, retrievals, and tool calls
- agent/trace/trace_collector.py: Real-time trace event collection with subscriber pattern for streaming
- agent/trace/trace_formatter.py: Multiple formatters (streaming, compact, detailed) for different output needs
- api/db/services/trace_service.py: Service layer for trace persistence, retrieval, and analysis
- api/apps/trace_app.py: REST API endpoints for trace management

Features:
- Real-time trace streaming via SSE
- Multiple trace verbosity levels (minimal, standard, detailed, debug)
- Component execution timing and bottleneck detection
- LLM call tracking with token counts
- Retrieval operation logging with chunk details
- Tool call tracing with arguments and results
- Trace session persistence in Redis
- Analysis and recommendations based on trace data

API Endpoints:
- GET /traces - List trace sessions
- GET /traces/<task_id> - Get trace session
- GET /traces/<task_id>/events - Get filtered events
- GET /traces/<task_id>/summary - Get trace summary
- GET /traces/<task_id>/analysis - Analyze trace
- GET /traces/<task_id>/stream - Stream trace events
- DELETE /traces/<task_id> - Delete trace
- POST /traces/cleanup - Cleanup old traces
- POST /agents/<agent_id>/completions/trace - Completion with trace

Closes #10081
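For reviewers, a rough client-side sketch of how these endpoints could be exercised end to end. The base URL and port, the bearer-token header, the trace_level and event_types parameter names, the "llm_call" event type value, and the response shape are assumptions to be checked against api/apps/trace_app.py, not guarantees:

import requests

BASE = "http://localhost:9380/api/v1"            # assumed host and API prefix
HEADERS = {"Authorization": "Bearer <API_KEY>"}  # assumed auth scheme

# Run an agent completion and ask for a step-by-step trace.
resp = requests.post(
    f"{BASE}/agents/<agent_id>/completions/trace",
    headers=HEADERS,
    json={"question": "What is RAGFlow?", "trace_level": "detailed"},  # parameter name assumed
)
task_id = resp.json().get("data", {}).get("task_id")  # response shape assumed

# Fetch the summary and the bottleneck/error analysis for that run.
summary = requests.get(f"{BASE}/traces/{task_id}/summary", headers=HEADERS).json()
analysis = requests.get(f"{BASE}/traces/{task_id}/analysis", headers=HEADERS).json()

# Page through LLM-call events only.
events = requests.get(
    f"{BASE}/traces/{task_id}/events",
    headers=HEADERS,
    params={"event_types": "llm_call", "limit": 50, "offset": 0},  # filter names assumed
).json()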
api/db/services/trace_service.py (482 lines, 16 KiB, Python)
#
#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

"""
Trace Service for Agent Execution Logging

This module provides the TraceService class that manages trace data persistence,
retrieval, and analysis. It integrates with the trace collector and formatter
modules to provide a complete tracing solution.
"""

import json
import logging
import time
from datetime import datetime, timedelta
from typing import Any, Optional, Tuple
from collections import defaultdict

from agent.trace.trace_models import (
    TraceEventType,
    TraceLevel,
    TraceSession,
    TraceEvent,
    TraceMetadata,
)
from agent.trace.trace_collector import (
    TraceCollector,
    get_trace_collector,
    create_trace_collector,
)
from agent.trace.trace_formatter import (
    TraceFormatterFactory,
    format_trace_for_api,
)
from rag.utils.redis_conn import REDIS_CONN

# Redis key prefix and TTLs (in seconds) for persisted trace data:
# session metadata is kept for 7 days, full event lists for 3 days.
TRACE_KEY_PREFIX = "agent_trace:"
TRACE_SESSION_TTL = 86400 * 7
TRACE_EVENT_TTL = 86400 * 3

class TraceService:
    """
    Service for managing agent execution traces.

    Provides methods for creating, storing, retrieving, and analyzing
    trace data from agent executions.
    """

    @staticmethod
    def create_trace_session(
        task_id: str,
        agent_id: str,
        session_id: str,
        user_id: str,
        tenant_id: str,
        trace_level: str = "standard",
    ) -> Tuple[bool, str]:
        """
        Create a new trace session for an agent execution.

        Args:
            task_id: Unique identifier for the task
            agent_id: ID of the agent being executed
            session_id: ID of the conversation session
            user_id: ID of the user
            tenant_id: ID of the tenant
            trace_level: Verbosity level (minimal, standard, detailed, debug)

        Returns:
            Tuple of (success, trace_id or error message)
        """
        try:
            level = TraceLevel(trace_level)
        except ValueError:
            level = TraceLevel.STANDARD

        try:
            collector = create_trace_collector(
                task_id=task_id,
                agent_id=agent_id,
                session_id=session_id,
                user_id=user_id,
                tenant_id=tenant_id,
                trace_level=level,
            )

            session_data = {
                "task_id": task_id,
                "agent_id": agent_id,
                "session_id": session_id,
                "user_id": user_id,
                "tenant_id": tenant_id,
                "trace_level": trace_level,
                "created_at": datetime.utcnow().isoformat(),
                "status": "running",
            }

            key = f"{TRACE_KEY_PREFIX}session:{task_id}"
            REDIS_CONN.setex(key, TRACE_SESSION_TTL, json.dumps(session_data))

            return True, task_id
        except Exception as e:
            logging.exception(f"Failed to create trace session: {e}")
            return False, str(e)

    @staticmethod
    def get_trace_collector(task_id: str) -> Optional[TraceCollector]:
        """Get an active trace collector by task ID."""
        return get_trace_collector(task_id)

    @staticmethod
    def save_trace_session(task_id: str) -> Tuple[bool, str]:
        """
        Save the current trace session to persistent storage.

        Args:
            task_id: ID of the task/trace session

        Returns:
            Tuple of (success, message)
        """
        try:
            collector = get_trace_collector(task_id)
            if not collector:
                return False, "Trace collector not found"

            session = collector.get_session()
            session_dict = session.to_dict()

            key = f"{TRACE_KEY_PREFIX}session:{task_id}"
            REDIS_CONN.setex(key, TRACE_SESSION_TTL, json.dumps(session_dict))

            events_key = f"{TRACE_KEY_PREFIX}events:{task_id}"
            events_data = json.dumps([e.to_dict() for e in session.events])
            REDIS_CONN.setex(events_key, TRACE_EVENT_TTL, events_data)

            return True, "Trace session saved successfully"
        except Exception as e:
            logging.exception(f"Failed to save trace session: {e}")
            return False, str(e)

    @staticmethod
    def get_trace_session(task_id: str) -> Optional[dict[str, Any]]:
        """
        Retrieve a trace session by task ID.

        Args:
            task_id: ID of the task/trace session

        Returns:
            Trace session data or None if not found
        """
        try:
            collector = get_trace_collector(task_id)
            if collector:
                return collector.get_session().to_dict()

            key = f"{TRACE_KEY_PREFIX}session:{task_id}"
            data = REDIS_CONN.get(key)
            if data:
                return json.loads(data)

            return None
        except Exception as e:
            logging.exception(f"Failed to get trace session: {e}")
            return None

    @staticmethod
    def get_trace_events(
        task_id: str,
        event_types: Optional[list[str]] = None,
        component_id: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """
        Retrieve trace events with optional filtering.

        Args:
            task_id: ID of the task/trace session
            event_types: Filter by event types
            component_id: Filter by component ID
            limit: Maximum number of events to return
            offset: Number of events to skip

        Returns:
            List of trace events
        """
        try:
            collector = get_trace_collector(task_id)
            if collector:
                events = collector.get_events()
            else:
                events_key = f"{TRACE_KEY_PREFIX}events:{task_id}"
                data = REDIS_CONN.get(events_key)
                if not data:
                    return []
                events = [TraceEvent.from_dict(e) for e in json.loads(data)]

            if event_types:
                type_set = set(event_types)
                events = [e for e in events if e.event_type.value in type_set]

            if component_id:
                events = [e for e in events if e.component_id == component_id]

            events = events[offset:offset + limit]
            return [e.to_dict() for e in events]
        except Exception as e:
            logging.exception(f"Failed to get trace events: {e}")
            return []

    @staticmethod
    def get_trace_summary(task_id: str) -> Optional[dict[str, Any]]:
        """
        Get a summary of the trace session.

        Args:
            task_id: ID of the task/trace session

        Returns:
            Summary data or None if not found
        """
        try:
            collector = get_trace_collector(task_id)
            if collector:
                return collector.get_summary()

            session_data = TraceService.get_trace_session(task_id)
            if session_data and "summary" in session_data:
                return session_data["summary"]

            return None
        except Exception as e:
            logging.exception(f"Failed to get trace summary: {e}")
            return None

    @staticmethod
    def format_trace(
        task_id: str,
        format_type: str = "streaming",
        **kwargs
    ) -> Optional[dict[str, Any]]:
        """
        Format a trace session using the specified formatter.

        Args:
            task_id: ID of the task/trace session
            format_type: Type of formatter (streaming, compact, detailed)
            **kwargs: Additional formatter options

        Returns:
            Formatted trace data or None if not found
        """
        try:
            session_data = TraceService.get_trace_session(task_id)
            if not session_data:
                return None

            collector = get_trace_collector(task_id)
            if collector:
                session = collector.get_session()
                return format_trace_for_api(session, format_type, **kwargs)

            return session_data
        except Exception as e:
            logging.exception(f"Failed to format trace: {e}")
            return None

    @staticmethod
    def list_traces(
        tenant_id: str,
        agent_id: Optional[str] = None,
        user_id: Optional[str] = None,
        status: Optional[str] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        page: int = 1,
        page_size: int = 20,
    ) -> dict[str, Any]:
        """
        List trace sessions with filtering and pagination.

        Args:
            tenant_id: ID of the tenant
            agent_id: Filter by agent ID
            user_id: Filter by user ID
            status: Filter by status (running, completed, failed)
            start_time: Filter by start time
            end_time: Filter by end time
            page: Page number
            page_size: Items per page

        Returns:
            Paginated list of trace sessions
        """
        try:
            # NOTE: scans every stored trace-session key and filters in memory,
            # so cost grows with the total number of persisted traces.
            pattern = f"{TRACE_KEY_PREFIX}session:*"
            keys = REDIS_CONN.keys(pattern)

            sessions = []
            for key in keys:
                data = REDIS_CONN.get(key)
                if data:
                    session = json.loads(data)
                    if session.get("tenant_id") == tenant_id:
                        sessions.append(session)

            if agent_id:
                sessions = [s for s in sessions if s.get("agent_id") == agent_id]
            if user_id:
                sessions = [s for s in sessions if s.get("user_id") == user_id]
            if status:
                sessions = [s for s in sessions if s.get("status") == status]

            if start_time:
                sessions = [s for s in sessions
                            if datetime.fromisoformat(s.get("created_at", "")) >= start_time]
            if end_time:
                sessions = [s for s in sessions
                            if datetime.fromisoformat(s.get("created_at", "")) <= end_time]

            sessions.sort(key=lambda x: x.get("created_at", ""), reverse=True)

            total = len(sessions)
            start_idx = (page - 1) * page_size
            end_idx = start_idx + page_size
            paginated = sessions[start_idx:end_idx]

            return {
                "total": total,
                "page": page,
                "page_size": page_size,
                "sessions": paginated,
            }
        except Exception as e:
            logging.exception(f"Failed to list traces: {e}")
            return {"total": 0, "page": page, "page_size": page_size, "sessions": []}

    @staticmethod
    def delete_trace(task_id: str) -> Tuple[bool, str]:
        """
        Delete a trace session.

        Args:
            task_id: ID of the task/trace session

        Returns:
            Tuple of (success, message)
        """
        try:
            session_key = f"{TRACE_KEY_PREFIX}session:{task_id}"
            events_key = f"{TRACE_KEY_PREFIX}events:{task_id}"

            REDIS_CONN.delete(session_key)
            REDIS_CONN.delete(events_key)

            return True, "Trace deleted successfully"
        except Exception as e:
            logging.exception(f"Failed to delete trace: {e}")
            return False, str(e)

    @staticmethod
    def analyze_trace(task_id: str) -> Optional[dict[str, Any]]:
        """
        Analyze a trace session and provide insights.

        Args:
            task_id: ID of the task/trace session

        Returns:
            Analysis results or None if not found
        """
        try:
            session_data = TraceService.get_trace_session(task_id)
            if not session_data:
                return None

            events = TraceService.get_trace_events(task_id, limit=1000)

            analysis = {
                "task_id": task_id,
                "total_events": len(events),
                "event_distribution": defaultdict(int),
                "component_execution_times": {},
                "bottlenecks": [],
                "errors": [],
                "recommendations": [],
            }

            for event in events:
                event_type = event.get("event_type", "unknown")
                analysis["event_distribution"][event_type] += 1

                if event.get("error"):
                    analysis["errors"].append({
                        "component": event.get("component_name"),
                        "error": event.get("error"),
                        "timestamp": event.get("timestamp"),
                    })

                if event.get("elapsed_time"):
                    component = event.get("component_name", "unknown")
                    if component not in analysis["component_execution_times"]:
                        analysis["component_execution_times"][component] = []
                    analysis["component_execution_times"][component].append(
                        event.get("elapsed_time")
                    )

            # Flag components whose average elapsed_time exceeds 5.0 as bottlenecks.
            for component, times in analysis["component_execution_times"].items():
                avg_time = sum(times) / len(times)
                if avg_time > 5.0:
                    analysis["bottlenecks"].append({
                        "component": component,
                        "avg_execution_time": round(avg_time, 2),
                        "executions": len(times),
                    })

            if analysis["bottlenecks"]:
                analysis["recommendations"].append(
                    "Consider optimizing slow components or adding caching"
                )
            if len(analysis["errors"]) > 0:
                analysis["recommendations"].append(
                    "Review and fix error-prone components"
                )

            analysis["event_distribution"] = dict(analysis["event_distribution"])

            return analysis
        except Exception as e:
            logging.exception(f"Failed to analyze trace: {e}")
            return None

    @staticmethod
    def cleanup_old_traces(days: int = 7) -> Tuple[int, str]:
        """
        Clean up trace sessions older than specified days.

        Args:
            days: Number of days to keep traces

        Returns:
            Tuple of (deleted count, message)
        """
        try:
            cutoff = datetime.utcnow() - timedelta(days=days)
            pattern = f"{TRACE_KEY_PREFIX}session:*"
            keys = REDIS_CONN.keys(pattern)

            deleted = 0
            for key in keys:
                data = REDIS_CONN.get(key)
                if data:
                    session = json.loads(data)
                    created_at = session.get("created_at")
                    if created_at:
                        session_time = datetime.fromisoformat(created_at)
                        if session_time < cutoff:
                            task_id = key.split(":")[-1]
                            TraceService.delete_trace(task_id)
                            deleted += 1

            return deleted, f"Deleted {deleted} old trace sessions"
        except Exception as e:
            logging.exception(f"Failed to cleanup traces: {e}")
            return 0, str(e)
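For context, a hedged server-side sketch (not part of the file above) of how these service methods are likely composed around a single agent run; the identifiers are placeholders, and the agent-execution step is elided because it lives in the agent runtime rather than in this service module:

from api.db.services.trace_service import TraceService

# Hypothetical identifiers; in practice these come from the completion request.
ok, task_id = TraceService.create_trace_session(
    task_id="task-123",
    agent_id="agent-abc",
    session_id="sess-1",
    user_id="user-1",
    tenant_id="tenant-1",
    trace_level="detailed",
)

collector = TraceService.get_trace_collector(task_id)
# ... the agent runtime executes here, reporting events through `collector` ...

TraceService.save_trace_session(task_id)           # persist session + events to Redis
summary = TraceService.get_trace_summary(task_id)  # high-level summary of the run
report = TraceService.analyze_trace(task_id)       # bottlenecks, errors, recommendations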