ragflow/api/db/services/trace_service.py
0xsatoshi99 697f8138b6 feat: add trace logging for agent completions API (Issue #10081)
Implement comprehensive trace logging system for agent execution that
returns step-by-step execution traces in API responses.

New modules:
- agent/trace/trace_models.py: Data models for trace events, sessions,
  LLM calls, retrievals, and tool calls
- agent/trace/trace_collector.py: Real-time trace event collection with
  subscriber pattern for streaming
- agent/trace/trace_formatter.py: Multiple formatters (streaming, compact,
  detailed) for different output needs
- api/db/services/trace_service.py: Service layer for trace persistence,
  retrieval, and analysis
- api/apps/trace_app.py: REST API endpoints for trace management

Features:
- Real-time trace streaming via SSE
- Multiple trace verbosity levels (minimal, standard, detailed, debug)
- Component execution timing and bottleneck detection
- LLM call tracking with token counts
- Retrieval operation logging with chunk details
- Tool call tracing with arguments and results
- Trace session persistence in Redis
- Analysis and recommendations based on trace data

API Endpoints:
- GET /traces - List trace sessions
- GET /traces/<task_id> - Get trace session
- GET /traces/<task_id>/events - Get filtered events
- GET /traces/<task_id>/summary - Get trace summary
- GET /traces/<task_id>/analysis - Analyze trace
- GET /traces/<task_id>/stream - Stream trace events
- DELETE /traces/<task_id> - Delete trace
- POST /traces/cleanup - Cleanup old traces
- POST /agents/<agent_id>/completions/trace - Completion with trace

Closes #10081
2025-12-03 18:08:14 +01:00

482 lines
16 KiB
Python

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Trace Service for Agent Execution Logging
This module provides the TraceService class that manages trace data persistence,
retrieval, and analysis. It integrates with the trace collector and formatter
modules to provide a complete tracing solution.
"""
import json
import logging
import time
from datetime import datetime, timedelta
from typing import Any, Optional, Tuple
from collections import defaultdict
from agent.trace.trace_models import (
TraceEventType,
TraceLevel,
TraceSession,
TraceEvent,
TraceMetadata,
)
from agent.trace.trace_collector import (
TraceCollector,
get_trace_collector,
create_trace_collector,
)
from agent.trace.trace_formatter import (
TraceFormatterFactory,
format_trace_for_api,
)
from rag.utils.redis_conn import REDIS_CONN
TRACE_KEY_PREFIX = "agent_trace:"
TRACE_SESSION_TTL = 86400 * 7
TRACE_EVENT_TTL = 86400 * 3
class TraceService:
    """
    Service for managing agent execution traces.

    Provides static methods for creating, storing, retrieving, and
    analyzing trace data from agent executions.  Active traces live in
    in-memory collectors (``agent.trace.trace_collector``); persisted
    traces are stored in Redis under ``agent_trace:`` keys, with the
    session document (TRACE_SESSION_TTL) and the event list
    (TRACE_EVENT_TTL) expiring independently.

    Every public method swallows unexpected errors, logs them, and
    returns a benign fallback (None / [] / (False, msg)) so that tracing
    failures never break the agent request they accompany.
    """

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _session_key(task_id: str) -> str:
        """Redis key holding the serialized session document for *task_id*."""
        return f"{TRACE_KEY_PREFIX}session:{task_id}"

    @staticmethod
    def _events_key(task_id: str) -> str:
        """Redis key holding the serialized event list for *task_id*."""
        return f"{TRACE_KEY_PREFIX}events:{task_id}"

    @staticmethod
    def _parse_iso(value: Optional[str]) -> Optional[datetime]:
        """Parse an ISO-8601 timestamp, returning None for bad or missing input."""
        if not value:
            return None
        try:
            return datetime.fromisoformat(value)
        except (TypeError, ValueError):
            return None

    # ------------------------------------------------------------------
    # Session lifecycle
    # ------------------------------------------------------------------

    @staticmethod
    def create_trace_session(
        task_id: str,
        agent_id: str,
        session_id: str,
        user_id: str,
        tenant_id: str,
        trace_level: str = "standard",
    ) -> tuple[bool, str]:
        """
        Create a new trace session for an agent execution.

        Registers an in-memory trace collector for the task and writes an
        initial session document (status ``"running"``) to Redis.

        Args:
            task_id: Unique identifier for the task.
            agent_id: ID of the agent being executed.
            session_id: ID of the conversation session.
            user_id: ID of the user.
            tenant_id: ID of the tenant.
            trace_level: Verbosity level (minimal, standard, detailed, debug).
                Unknown values silently fall back to "standard".

        Returns:
            Tuple of (success, trace_id or error message).
        """
        try:
            level = TraceLevel(trace_level)
        except ValueError:
            # Tolerate unknown verbosity strings instead of failing the call.
            level = TraceLevel.STANDARD
        try:
            # Presumably registers itself in a task-indexed registry — it is
            # retrieved later via get_trace_collector(task_id), so the return
            # value is not needed here.
            create_trace_collector(
                task_id=task_id,
                agent_id=agent_id,
                session_id=session_id,
                user_id=user_id,
                tenant_id=tenant_id,
                trace_level=level,
            )
            session_data = {
                "task_id": task_id,
                "agent_id": agent_id,
                "session_id": session_id,
                "user_id": user_id,
                "tenant_id": tenant_id,
                "trace_level": trace_level,
                # utcnow() keeps stored timestamps naive-UTC, matching the
                # naive comparisons in list_traces()/cleanup_old_traces().
                # NOTE(review): deprecated since Python 3.12 — migrate to
                # timezone-aware datetimes together with those readers.
                "created_at": datetime.utcnow().isoformat(),
                "status": "running",
            }
            REDIS_CONN.setex(
                TraceService._session_key(task_id),
                TRACE_SESSION_TTL,
                json.dumps(session_data),
            )
            return True, task_id
        except Exception as e:
            logging.exception(f"Failed to create trace session: {e}")
            return False, str(e)

    @staticmethod
    def get_trace_collector(task_id: str) -> Optional["TraceCollector"]:
        """Return the active in-memory trace collector for *task_id*, if any."""
        return get_trace_collector(task_id)

    @staticmethod
    def save_trace_session(task_id: str) -> tuple[bool, str]:
        """
        Persist the current in-memory trace session to Redis.

        The session document and its event list are written under separate
        keys so that bulky event data can expire sooner (TRACE_EVENT_TTL)
        than the session summary (TRACE_SESSION_TTL).

        Args:
            task_id: ID of the task/trace session.

        Returns:
            Tuple of (success, message).
        """
        try:
            collector = get_trace_collector(task_id)
            if not collector:
                return False, "Trace collector not found"
            session = collector.get_session()
            REDIS_CONN.setex(
                TraceService._session_key(task_id),
                TRACE_SESSION_TTL,
                json.dumps(session.to_dict()),
            )
            REDIS_CONN.setex(
                TraceService._events_key(task_id),
                TRACE_EVENT_TTL,
                json.dumps([e.to_dict() for e in session.events]),
            )
            return True, "Trace session saved successfully"
        except Exception as e:
            logging.exception(f"Failed to save trace session: {e}")
            return False, str(e)

    # ------------------------------------------------------------------
    # Retrieval
    # ------------------------------------------------------------------

    @staticmethod
    def get_trace_session(task_id: str) -> Optional[dict[str, Any]]:
        """
        Retrieve a trace session by task ID.

        Prefers the live in-memory collector; falls back to the persisted
        Redis copy when no collector is registered.

        Args:
            task_id: ID of the task/trace session.

        Returns:
            Trace session data, or None if not found (or on error).
        """
        try:
            collector = get_trace_collector(task_id)
            if collector:
                return collector.get_session().to_dict()
            data = REDIS_CONN.get(TraceService._session_key(task_id))
            return json.loads(data) if data else None
        except Exception as e:
            logging.exception(f"Failed to get trace session: {e}")
            return None

    @staticmethod
    def get_trace_events(
        task_id: str,
        event_types: Optional[list[str]] = None,
        component_id: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """
        Retrieve trace events with optional filtering and pagination.

        Args:
            task_id: ID of the task/trace session.
            event_types: Keep only events whose type value is in this list.
            component_id: Keep only events from this component.
            limit: Maximum number of events to return.
            offset: Number of (filtered) events to skip.

        Returns:
            List of trace events as dicts; empty list when nothing is
            found or on error.
        """
        try:
            collector = get_trace_collector(task_id)
            if collector:
                events = collector.get_events()
            else:
                data = REDIS_CONN.get(TraceService._events_key(task_id))
                if not data:
                    return []
                events = [TraceEvent.from_dict(e) for e in json.loads(data)]
            if event_types:
                wanted = set(event_types)
                events = [e for e in events if e.event_type.value in wanted]
            if component_id:
                events = [e for e in events if e.component_id == component_id]
            # Pagination is applied after filtering, so offset/limit index
            # into the filtered sequence.
            return [e.to_dict() for e in events[offset:offset + limit]]
        except Exception as e:
            logging.exception(f"Failed to get trace events: {e}")
            return []

    @staticmethod
    def get_trace_summary(task_id: str) -> Optional[dict[str, Any]]:
        """
        Get a summary of the trace session.

        Uses the live collector's summary when available; otherwise reads
        the ``"summary"`` field from the persisted session document.

        Args:
            task_id: ID of the task/trace session.

        Returns:
            Summary data, or None if not found.
        """
        try:
            collector = get_trace_collector(task_id)
            if collector:
                return collector.get_summary()
            session_data = TraceService.get_trace_session(task_id)
            if session_data and "summary" in session_data:
                return session_data["summary"]
            return None
        except Exception as e:
            logging.exception(f"Failed to get trace summary: {e}")
            return None

    @staticmethod
    def format_trace(
        task_id: str,
        format_type: str = "streaming",
        **kwargs,
    ) -> Optional[dict[str, Any]]:
        """
        Format a trace session using the specified formatter.

        Args:
            task_id: ID of the task/trace session.
            format_type: Type of formatter (streaming, compact, detailed).
            **kwargs: Additional formatter options.

        Returns:
            Formatted trace data, or None if the trace does not exist.
            NOTE: when only the persisted copy exists (no live collector),
            the raw session dict is returned and *format_type* is ignored.
        """
        try:
            session_data = TraceService.get_trace_session(task_id)
            if not session_data:
                return None
            collector = get_trace_collector(task_id)
            if collector:
                return format_trace_for_api(collector.get_session(), format_type, **kwargs)
            return session_data
        except Exception as e:
            logging.exception(f"Failed to format trace: {e}")
            return None

    @staticmethod
    def list_traces(
        tenant_id: str,
        agent_id: Optional[str] = None,
        user_id: Optional[str] = None,
        status: Optional[str] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        page: int = 1,
        page_size: int = 20,
    ) -> dict[str, Any]:
        """
        List trace sessions with filtering and pagination.

        Sessions are sorted by ``created_at`` descending.  A session with
        a corrupt document or an unparseable timestamp is skipped instead
        of failing the whole listing (previously a single bad entry made
        every time-filtered listing come back empty).

        Args:
            tenant_id: ID of the tenant (mandatory filter).
            agent_id: Filter by agent ID.
            user_id: Filter by user ID.
            status: Filter by status (running, completed, failed).
            start_time: Keep sessions created at or after this time.
            end_time: Keep sessions created at or before this time.
            page: 1-based page number.
            page_size: Items per page.

        Returns:
            Dict with "total", "page", "page_size" and "sessions".
        """
        try:
            # NOTE(review): KEYS is O(N) over the whole keyspace and blocks
            # Redis — switch to a SCAN-based iteration if REDIS_CONN exposes one.
            sessions = []
            for key in REDIS_CONN.keys(f"{TRACE_KEY_PREFIX}session:*"):
                data = REDIS_CONN.get(key)
                if not data:
                    continue
                try:
                    session = json.loads(data)
                except (TypeError, ValueError):
                    # Skip corrupt entries instead of aborting the listing.
                    continue
                if session.get("tenant_id") != tenant_id:
                    continue
                if agent_id and session.get("agent_id") != agent_id:
                    continue
                if user_id and session.get("user_id") != user_id:
                    continue
                if status and session.get("status") != status:
                    continue
                if start_time or end_time:
                    created = TraceService._parse_iso(session.get("created_at"))
                    # A session without a usable timestamp cannot satisfy
                    # a time filter.
                    if created is None:
                        continue
                    if start_time and created < start_time:
                        continue
                    if end_time and created > end_time:
                        continue
                sessions.append(session)
            sessions.sort(key=lambda s: s.get("created_at", ""), reverse=True)
            start_idx = (page - 1) * page_size
            return {
                "total": len(sessions),
                "page": page,
                "page_size": page_size,
                "sessions": sessions[start_idx:start_idx + page_size],
            }
        except Exception as e:
            logging.exception(f"Failed to list traces: {e}")
            return {"total": 0, "page": page, "page_size": page_size, "sessions": []}

    # ------------------------------------------------------------------
    # Maintenance
    # ------------------------------------------------------------------

    @staticmethod
    def delete_trace(task_id: str) -> tuple[bool, str]:
        """
        Delete a trace session and its events from Redis.

        Args:
            task_id: ID of the task/trace session.

        Returns:
            Tuple of (success, message).
        """
        try:
            REDIS_CONN.delete(TraceService._session_key(task_id))
            REDIS_CONN.delete(TraceService._events_key(task_id))
            return True, "Trace deleted successfully"
        except Exception as e:
            logging.exception(f"Failed to delete trace: {e}")
            return False, str(e)

    @staticmethod
    def analyze_trace(task_id: str) -> Optional[dict[str, Any]]:
        """
        Analyze a trace session and provide insights.

        Builds an event-type histogram, collects errors, aggregates
        per-component elapsed times, flags components averaging above a
        threshold as bottlenecks, and derives simple recommendations.

        Args:
            task_id: ID of the task/trace session.

        Returns:
            Analysis dict, or None if the trace does not exist.
        """
        # Components whose average elapsed time exceeds this many seconds
        # are reported as bottlenecks.
        slow_threshold = 5.0
        try:
            if not TraceService.get_trace_session(task_id):
                return None
            events = TraceService.get_trace_events(task_id, limit=1000)
            distribution: dict[str, int] = defaultdict(int)
            exec_times: dict[str, list] = defaultdict(list)
            errors: list[dict[str, Any]] = []
            for event in events:
                distribution[event.get("event_type", "unknown")] += 1
                if event.get("error"):
                    errors.append({
                        "component": event.get("component_name"),
                        "error": event.get("error"),
                        "timestamp": event.get("timestamp"),
                    })
                elapsed = event.get("elapsed_time")
                # Falsy elapsed (None or 0) is not counted, matching the
                # original truthiness check.
                if elapsed:
                    exec_times[event.get("component_name", "unknown")].append(elapsed)
            bottlenecks = []
            for component, times in exec_times.items():
                avg_time = sum(times) / len(times)
                if avg_time > slow_threshold:
                    bottlenecks.append({
                        "component": component,
                        "avg_execution_time": round(avg_time, 2),
                        "executions": len(times),
                    })
            recommendations = []
            if bottlenecks:
                recommendations.append(
                    "Consider optimizing slow components or adding caching"
                )
            if errors:
                recommendations.append(
                    "Review and fix error-prone components"
                )
            return {
                "task_id": task_id,
                "total_events": len(events),
                "event_distribution": dict(distribution),
                "component_execution_times": dict(exec_times),
                "bottlenecks": bottlenecks,
                "errors": errors,
                "recommendations": recommendations,
            }
        except Exception as e:
            logging.exception(f"Failed to analyze trace: {e}")
            return None

    @staticmethod
    def cleanup_old_traces(days: int = 7) -> tuple[int, str]:
        """
        Clean up trace sessions older than the specified number of days.

        A corrupt session document or unparseable timestamp is skipped
        instead of aborting the sweep (previously one bad entry stopped
        the entire cleanup run).

        Args:
            days: Number of days to keep traces.

        Returns:
            Tuple of (deleted count, message).
        """
        try:
            # Naive-UTC cutoff, consistent with the naive timestamps written
            # by create_trace_session().
            cutoff = datetime.utcnow() - timedelta(days=days)
            deleted = 0
            # NOTE(review): KEYS blocks Redis; prefer SCAN if available.
            for key in REDIS_CONN.keys(f"{TRACE_KEY_PREFIX}session:*"):
                data = REDIS_CONN.get(key)
                if not data:
                    continue
                try:
                    session = json.loads(data)
                except (TypeError, ValueError):
                    continue
                session_time = TraceService._parse_iso(session.get("created_at"))
                if session_time is None or session_time >= cutoff:
                    continue
                # NOTE(review): assumes keys come back as str; if the client
                # returns bytes, split() would fail — confirm REDIS_CONN config.
                task_id = key.split(":")[-1]
                TraceService.delete_trace(task_id)
                deleted += 1
            return deleted, f"Deleted {deleted} old trace sessions"
        except Exception as e:
            logging.exception(f"Failed to cleanup traces: {e}")
            return 0, str(e)