ragflow/agent/trace/trace_models.py
0xsatoshi99 697f8138b6 feat: add trace logging for agent completions API (Issue #10081)
Implement a comprehensive trace logging system for agent execution that
returns step-by-step execution traces in API responses.

New modules:
- agent/trace/trace_models.py: Data models for trace events, sessions,
  LLM calls, retrievals, and tool calls
- agent/trace/trace_collector.py: Real-time trace event collection with
  subscriber pattern for streaming
- agent/trace/trace_formatter.py: Multiple formatters (streaming, compact,
  detailed) for different output needs
- api/db/services/trace_service.py: Service layer for trace persistence,
  retrieval, and analysis
- api/apps/trace_app.py: REST API endpoints for trace management

Features:
- Real-time trace streaming via SSE
- Multiple trace verbosity levels (minimal, standard, detailed, debug)
- Component execution timing and bottleneck detection
- LLM call tracking with token counts
- Retrieval operation logging with chunk details
- Tool call tracing with arguments and results
- Trace session persistence in Redis
- Analysis and recommendations based on trace data

API Endpoints:
- GET /traces - List trace sessions
- GET /traces/<task_id> - Get trace session
- GET /traces/<task_id>/events - Get filtered events
- GET /traces/<task_id>/summary - Get trace summary
- GET /traces/<task_id>/analysis - Analyze trace
- GET /traces/<task_id>/stream - Stream trace events
- DELETE /traces/<task_id> - Delete trace
- POST /traces/cleanup - Cleanup old traces
- POST /agents/<agent_id>/completions/trace - Completion with trace

Closes #10081
2025-12-03 18:08:14 +01:00

375 lines
14 KiB
Python

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Trace Models for Agent Execution Logging
This module provides data models for capturing and representing trace information
during agent execution. It includes models for individual trace events, component
execution details, and complete trace sessions.
"""
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum
from typing import Any, Optional, Union
import json
import uuid
class TraceEventType(Enum):
    """Enumeration of trace event types.

    Each member's value is the lowercase string that ``TraceEvent.to_dict``
    writes under ``"event_type"`` and that ``TraceEvent.from_dict`` parses back.
    """
    # Whole-workflow lifecycle.
    WORKFLOW_STARTED = "workflow_started"
    WORKFLOW_COMPLETED = "workflow_completed"
    WORKFLOW_FAILED = "workflow_failed"
    # Per-node (component) lifecycle.
    NODE_STARTED = "node_started"
    NODE_FINISHED = "node_finished"
    NODE_FAILED = "node_failed"
    # Retrieval operations.
    RETRIEVAL_STARTED = "retrieval_started"
    RETRIEVAL_COMPLETED = "retrieval_completed"
    # LLM API calls.
    LLM_CALL_STARTED = "llm_call_started"
    LLM_CALL_COMPLETED = "llm_call_completed"
    # Tool / function calls.
    TOOL_CALL_STARTED = "tool_call_started"
    TOOL_CALL_COMPLETED = "tool_call_completed"
    # Miscellaneous events.
    MESSAGE_GENERATED = "message_generated"
    ERROR_OCCURRED = "error_occurred"
    THINKING_STARTED = "thinking_started"
    THINKING_COMPLETED = "thinking_completed"
class TraceLevel(Enum):
    """Trace verbosity levels.

    The member's value is what ``TraceMetadata.to_dict`` stores under
    ``"trace_level"``; ``TraceMetadata`` defaults to ``STANDARD``.
    """
    # NOTE(review): presumably listed from least to most verbose -- confirm
    # against the formatters/collector that consume this setting.
    MINIMAL = "minimal"
    STANDARD = "standard"
    DETAILED = "detailed"
    DEBUG = "debug"
@dataclass
class TraceMetadata:
    """Metadata associated with a trace session.

    Attributes:
        agent_id: Identifier of the traced agent.
        session_id: Identifier of the traced session.
        user_id: Identifier of the user the trace belongs to.
        tenant_id: Identifier of the tenant the trace belongs to.
        trace_level: Requested verbosity (defaults to ``TraceLevel.STANDARD``).
        created_at: Creation time; by default a naive datetime from
            ``datetime.utcnow``.
        tags: Free-form labels attached to the trace.
        custom_data: Arbitrary extra data. It is emitted as-is by
            :meth:`to_dict`, so it must be JSON-serializable for JSON output.
    """

    agent_id: str
    session_id: str
    user_id: str
    tenant_id: str
    trace_level: TraceLevel = TraceLevel.STANDARD
    created_at: datetime = field(default_factory=datetime.utcnow)
    tags: list[str] = field(default_factory=list)
    custom_data: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Convert metadata to a JSON-friendly dictionary.

        ``trace_level`` becomes its string value and ``created_at`` an ISO-8601
        string. ``tags`` and ``custom_data`` are returned by reference, not
        copied.
        """
        return {
            "agent_id": self.agent_id,
            "session_id": self.session_id,
            "user_id": self.user_id,
            "tenant_id": self.tenant_id,
            "trace_level": self.trace_level.value,
            "created_at": self.created_at.isoformat(),
            "tags": self.tags,
            "custom_data": self.custom_data,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TraceMetadata":
        """Create TraceMetadata from a dictionary such as :meth:`to_dict` output.

        A key that is missing, ``None`` (e.g. a JSON ``null``) or empty falls
        back to the field's default instead of raising. ``created_at`` may be
        an ISO-8601 string or an existing ``datetime``.

        Raises:
            ValueError: If ``trace_level`` is not a valid :class:`TraceLevel`
                value, or ``created_at`` is a malformed ISO-8601 string.
        """
        raw_created = data.get("created_at")
        if isinstance(raw_created, datetime):
            created_at = raw_created
        elif raw_created:
            created_at = datetime.fromisoformat(raw_created)
        else:
            # Missing / None / empty: behave like a freshly created trace.
            created_at = datetime.utcnow()
        return cls(
            agent_id=data.get("agent_id") or "",
            session_id=data.get("session_id") or "",
            user_id=data.get("user_id") or "",
            tenant_id=data.get("tenant_id") or "",
            trace_level=TraceLevel(data.get("trace_level") or TraceLevel.STANDARD.value),
            created_at=created_at,
            tags=data.get("tags") or [],
            custom_data=data.get("custom_data") or {},
        )
@dataclass
class ComponentInfo:
    """Static description of one component in the agent workflow.

    ``upstream`` and ``downstream`` are lists of strings -- presumably the ids
    of neighbouring components in the workflow graph (TODO confirm with the
    code that builds these objects).
    """

    component_id: str
    component_name: str
    component_type: str
    params: dict[str, Any] = field(default_factory=dict)
    upstream: list[str] = field(default_factory=list)
    downstream: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return this component as a plain dictionary.

        Containers (``params``, ``upstream``, ``downstream``) are returned by
        reference rather than copied.
        """
        ordered_keys = (
            "component_id",
            "component_name",
            "component_type",
            "params",
            "upstream",
            "downstream",
        )
        return {key: getattr(self, key) for key in ordered_keys}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ComponentInfo":
        """Build a ComponentInfo from a dictionary.

        Absent identity keys default to ``""``; absent containers default to
        a fresh empty dict/list.
        """
        identity = {
            key: data.get(key, "")
            for key in ("component_id", "component_name", "component_type")
        }
        return cls(
            **identity,
            params=data.get("params", {}),
            upstream=data.get("upstream", []),
            downstream=data.get("downstream", []),
        )
@dataclass
class TraceEvent:
    """Represents a single trace event during agent execution.

    Attributes:
        event_id: Unique event id. An empty value is replaced by a random
            UUID4 string in ``__post_init__``.
        event_type: Kind of event.
        timestamp: When the event happened.
        component_id: Optional id of the component the event refers to.
        component_name: Optional name of that component.
        component_type: Optional type of that component.
        inputs: Optional component inputs.
        outputs: Optional component outputs.
        error: Optional error message.
        elapsed_time: Optional duration (presumably seconds -- confirm with
            the producer).
        thoughts: Optional reasoning text.
        metadata: Extra free-form data.
    """

    event_id: str
    event_type: TraceEventType
    timestamp: datetime
    component_id: Optional[str] = None
    component_name: Optional[str] = None
    component_type: Optional[str] = None
    inputs: Optional[dict[str, Any]] = None
    outputs: Optional[dict[str, Any]] = None
    error: Optional[str] = None
    elapsed_time: Optional[float] = None
    thoughts: Optional[str] = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Assign a random UUID4 string when ``event_id`` is empty."""
        if not self.event_id:
            self.event_id = str(uuid.uuid4())

    def to_dict(self) -> dict[str, Any]:
        """Convert the event to a compact dictionary.

        ``event_id``, ``event_type`` (string value) and ``timestamp``
        (ISO-8601) are always present. ``inputs``, ``outputs`` and
        ``elapsed_time`` are included whenever they are not ``None`` (so an
        empty dict or ``0.0`` is kept); the other optional fields are included
        only when truthy.
        """
        result: dict[str, Any] = {
            "event_id": self.event_id,
            "event_type": self.event_type.value,
            "timestamp": self.timestamp.isoformat(),
        }
        if self.component_id:
            result["component_id"] = self.component_id
        if self.component_name:
            result["component_name"] = self.component_name
        if self.component_type:
            result["component_type"] = self.component_type
        if self.inputs is not None:
            result["inputs"] = self.inputs
        if self.outputs is not None:
            result["outputs"] = self.outputs
        if self.error:
            result["error"] = self.error
        if self.elapsed_time is not None:
            result["elapsed_time"] = self.elapsed_time
        if self.thoughts:
            result["thoughts"] = self.thoughts
        if self.metadata:
            result["metadata"] = self.metadata
        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TraceEvent":
        """Create a TraceEvent from a dictionary such as :meth:`to_dict` output.

        Missing, ``None`` or empty values fall back to defaults instead of
        raising: ``event_type`` to ``NODE_STARTED``, ``timestamp`` to the
        current naive UTC time, ``metadata`` to ``{}`` and ``event_id`` to a
        fresh UUID (via ``__post_init__``). ``timestamp`` may also be an
        existing ``datetime``.

        Raises:
            ValueError: If ``event_type`` is not a valid TraceEventType value,
                or ``timestamp`` is a malformed ISO-8601 string.
        """
        raw_timestamp = data.get("timestamp")
        if isinstance(raw_timestamp, datetime):
            timestamp = raw_timestamp
        elif raw_timestamp:
            timestamp = datetime.fromisoformat(raw_timestamp)
        else:
            timestamp = datetime.utcnow()
        return cls(
            event_id=data.get("event_id") or "",
            event_type=TraceEventType(data.get("event_type") or TraceEventType.NODE_STARTED.value),
            timestamp=timestamp,
            component_id=data.get("component_id"),
            component_name=data.get("component_name"),
            component_type=data.get("component_type"),
            inputs=data.get("inputs"),
            outputs=data.get("outputs"),
            error=data.get("error"),
            elapsed_time=data.get("elapsed_time"),
            thoughts=data.get("thoughts"),
            metadata=data.get("metadata") or {},
        )

    def to_json(self) -> str:
        """Serialize :meth:`to_dict` output as a single-line JSON string.

        Non-ASCII text is kept verbatim (``ensure_ascii=False``). ``json.dumps``
        raises ``TypeError`` if ``inputs``/``outputs``/``metadata`` hold values
        it cannot serialize.
        """
        return json.dumps(self.to_dict(), ensure_ascii=False)
@dataclass
class LLMCallTrace:
    """Trace record for a single LLM API call.

    Attributes:
        call_id: Identifier of this call.
        model_name: Name of the model that was called.
        prompt: Prompt text sent to the model.
        response: Response text, or ``None`` if none was recorded.
        prompt_tokens: Token count for the prompt.
        completion_tokens: Token count for the completion.
        total_tokens: Total token count.
        latency_ms: Call latency in milliseconds.
        temperature: Sampling temperature used.
        max_tokens: Token limit requested, if any.
        error: Error message if the call failed.
        started_at: Start time (naive UTC from ``datetime.utcnow`` by default).
        completed_at: Completion time, or ``None`` while still running.
    """

    call_id: str
    model_name: str
    prompt: str
    response: Optional[str] = None
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    latency_ms: float = 0.0
    temperature: float = 0.0
    max_tokens: Optional[int] = None
    error: Optional[str] = None
    started_at: datetime = field(default_factory=datetime.utcnow)
    completed_at: Optional[datetime] = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize the call to a JSON-friendly dictionary.

        ``prompt`` and ``response`` longer than 500 characters are cut to
        their first 500 characters followed by ``"..."``. Timestamps become
        ISO-8601 strings; ``completed_at`` stays ``None`` until it is set.
        """
        limit = 500

        def abbreviate(text):
            # Keep a fixed-size preview and mark the cut explicitly.
            if len(text) > limit:
                return text[:limit] + "..."
            return text

        prompt_preview = abbreviate(self.prompt)
        response_preview = abbreviate(self.response) if self.response else self.response
        finished = self.completed_at
        return {
            "call_id": self.call_id,
            "model_name": self.model_name,
            "prompt": prompt_preview,
            "response": response_preview,
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.completion_tokens,
            "total_tokens": self.total_tokens,
            "latency_ms": self.latency_ms,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "error": self.error,
            "started_at": self.started_at.isoformat(),
            "completed_at": finished.isoformat() if finished else None,
        }
@dataclass
class RetrievalTrace:
    """Trace record for a single retrieval operation.

    Attributes:
        retrieval_id: Identifier of this retrieval.
        query: Query text that was searched for.
        knowledge_bases: Knowledge bases that were searched.
        top_k: Requested number of results.
        similarity_threshold: Minimum similarity requested.
        chunks_retrieved: Reported number of retrieved chunks.
        chunks: Retrieved chunk payloads.
        latency_ms: Retrieval latency in milliseconds.
        rerank_enabled: Whether reranking was enabled.
        error: Error message if the retrieval failed.
        started_at: Start time (naive UTC from ``datetime.utcnow`` by default).
        completed_at: Completion time, or ``None`` while still running.
    """

    retrieval_id: str
    query: str
    knowledge_bases: list[str] = field(default_factory=list)
    top_k: int = 10
    similarity_threshold: float = 0.0
    chunks_retrieved: int = 0
    chunks: list[dict[str, Any]] = field(default_factory=list)
    latency_ms: float = 0.0
    rerank_enabled: bool = False
    error: Optional[str] = None
    started_at: datetime = field(default_factory=datetime.utcnow)
    completed_at: Optional[datetime] = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize the retrieval to a JSON-friendly dictionary.

        Only the first five chunks are included; ``chunks_retrieved`` is
        reported separately. Timestamps become ISO-8601 strings.
        """
        payload: dict[str, Any] = {
            "retrieval_id": self.retrieval_id,
            "query": self.query,
            "knowledge_bases": self.knowledge_bases,
            "top_k": self.top_k,
            "similarity_threshold": self.similarity_threshold,
            "chunks_retrieved": self.chunks_retrieved,
        }
        # A short preview keeps serialized traces small.
        payload["chunks"] = self.chunks[:5]
        payload["latency_ms"] = self.latency_ms
        payload["rerank_enabled"] = self.rerank_enabled
        payload["error"] = self.error
        payload["started_at"] = self.started_at.isoformat()
        finished = self.completed_at
        payload["completed_at"] = finished.isoformat() if finished else None
        return payload
@dataclass
class ToolCallTrace:
    """Trace record for a single tool/function call.

    Attributes:
        call_id: Identifier of this call.
        tool_name: Name of the invoked tool.
        tool_type: Tool type label supplied by the caller.
        arguments: Arguments passed to the tool.
        result: Raw value returned by the tool, or ``None``.
        latency_ms: Call latency in milliseconds.
        error: Error message if the call failed.
        started_at: Start time (naive UTC from ``datetime.utcnow`` by default).
        completed_at: Completion time, or ``None`` while still running.
    """

    call_id: str
    tool_name: str
    tool_type: str
    arguments: dict[str, Any] = field(default_factory=dict)
    result: Optional[Any] = None
    latency_ms: float = 0.0
    error: Optional[str] = None
    started_at: datetime = field(default_factory=datetime.utcnow)
    completed_at: Optional[datetime] = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize the tool call to a JSON-friendly dictionary.

        ``result`` is rendered with ``str()``. A rendering longer than 500
        characters is cut to 500 characters plus ``"..."``, the same marker
        ``LLMCallTrace.to_dict`` uses. Only a ``None`` result is emitted as
        ``None``; falsy results such as ``0``, ``False`` or ``""`` keep their
        string form.
        """
        # `is None` rather than truthiness: 0 / False / "" are real tool results.
        if self.result is None:
            result_str = None
        else:
            result_str = str(self.result)
            if len(result_str) > 500:
                result_str = result_str[:500] + "..."
        return {
            "call_id": self.call_id,
            "tool_name": self.tool_name,
            "tool_type": self.tool_type,
            "arguments": self.arguments,
            "result": result_str,
            "latency_ms": self.latency_ms,
            "error": self.error,
            "started_at": self.started_at.isoformat(),
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
        }
@dataclass
class TraceSession:
    """Complete trace of one agent execution.

    Collects events, LLM calls, retrievals and tool calls. ``status`` starts
    as ``"running"`` and is set to ``"completed"`` or ``"failed"`` by
    :meth:`complete`, which also records ``total_elapsed_time`` in seconds.
    """

    session_id: str
    metadata: TraceMetadata
    events: list[TraceEvent] = field(default_factory=list)
    llm_calls: list[LLMCallTrace] = field(default_factory=list)
    retrievals: list[RetrievalTrace] = field(default_factory=list)
    tool_calls: list[ToolCallTrace] = field(default_factory=list)
    started_at: datetime = field(default_factory=datetime.utcnow)
    completed_at: Optional[datetime] = None
    total_elapsed_time: float = 0.0
    status: str = "running"
    error: Optional[str] = None

    def add_event(self, event: TraceEvent) -> None:
        """Append ``event`` to ``events``."""
        self.events.append(event)

    def add_llm_call(self, llm_call: LLMCallTrace) -> None:
        """Append ``llm_call`` to ``llm_calls``."""
        self.llm_calls.append(llm_call)

    def add_retrieval(self, retrieval: RetrievalTrace) -> None:
        """Append ``retrieval`` to ``retrievals``."""
        self.retrievals.append(retrieval)

    def add_tool_call(self, tool_call: ToolCallTrace) -> None:
        """Append ``tool_call`` to ``tool_calls``."""
        self.tool_calls.append(tool_call)

    def complete(self, error: Optional[str] = None) -> None:
        """Stamp the end time and final status of the session.

        Args:
            error: Failure message. Any truthy value marks the session as
                ``"failed"``; otherwise it becomes ``"completed"``.
        """
        finished_at = datetime.utcnow()
        self.completed_at = finished_at
        self.total_elapsed_time = (finished_at - self.started_at).total_seconds()
        self.error = error
        self.status = "completed" if not error else "failed"

    def to_dict(self) -> dict[str, Any]:
        """Serialize the whole session, including a ``"summary"`` section."""
        finished = self.completed_at
        return {
            "session_id": self.session_id,
            "metadata": self.metadata.to_dict(),
            "events": [event.to_dict() for event in self.events],
            "llm_calls": [call.to_dict() for call in self.llm_calls],
            "retrievals": [retrieval.to_dict() for retrieval in self.retrievals],
            "tool_calls": [tool.to_dict() for tool in self.tool_calls],
            "started_at": self.started_at.isoformat(),
            "completed_at": finished.isoformat() if finished else None,
            "total_elapsed_time": self.total_elapsed_time,
            "status": self.status,
            "error": self.error,
            "summary": self.get_summary(),
        }

    def get_summary(self) -> dict[str, Any]:
        """Return aggregate counts for the session.

        ``nodes_executed`` counts distinct non-empty ``component_id`` values
        across events; ``errors_count`` counts events whose ``error`` is set.
        """
        events = self.events
        return {
            "total_events": len(events),
            "total_llm_calls": len(self.llm_calls),
            "total_retrievals": len(self.retrievals),
            "total_tool_calls": len(self.tool_calls),
            "total_tokens": sum(call.total_tokens for call in self.llm_calls),
            "total_chunks_retrieved": sum(r.chunks_retrieved for r in self.retrievals),
            "nodes_executed": len({ev.component_id for ev in events if ev.component_id}),
            "errors_count": sum(1 for ev in events if ev.error),
        }

    def to_json(self) -> str:
        """Serialize :meth:`to_dict` output as pretty-printed (indent=2) JSON."""
        return json.dumps(self.to_dict(), ensure_ascii=False, indent=2)