feat: add trace logging for agent completions API (Issue #10081)
Implement comprehensive trace logging system for agent execution that returns step-by-step execution traces in API responses.

New modules:
- agent/trace/trace_models.py: Data models for trace events, sessions, LLM calls, retrievals, and tool calls
- agent/trace/trace_collector.py: Real-time trace event collection with subscriber pattern for streaming
- agent/trace/trace_formatter.py: Multiple formatters (streaming, compact, detailed) for different output needs
- api/db/services/trace_service.py: Service layer for trace persistence, retrieval, and analysis
- api/apps/trace_app.py: REST API endpoints for trace management

Features:
- Real-time trace streaming via SSE
- Multiple trace verbosity levels (minimal, standard, detailed, debug)
- Component execution timing and bottleneck detection
- LLM call tracking with token counts
- Retrieval operation logging with chunk details
- Tool call tracing with arguments and results
- Trace session persistence in Redis
- Analysis and recommendations based on trace data

API Endpoints:
- GET /traces - List trace sessions
- GET /traces/<task_id> - Get trace session
- GET /traces/<task_id>/events - Get filtered events
- GET /traces/<task_id>/summary - Get trace summary
- GET /traces/<task_id>/analysis - Analyze trace
- GET /traces/<task_id>/stream - Stream trace events
- DELETE /traces/<task_id> - Delete trace
- POST /traces/cleanup - Cleanup old traces
- POST /agents/<agent_id>/completions/trace - Completion with trace

Closes #10081
parent 3c224c817b
commit 697f8138b6
6 changed files with 2327 additions and 0 deletions
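
A client-side sketch of the new completion-with-trace endpoint described above (hypothetical host, URL prefix, API key, and body fields; the exact request shape would follow the existing agent completions API):

import requests

resp = requests.post(
    "http://localhost:9380/v1/agents/agent-1/completions/trace",
    headers={"Authorization": "Bearer <api-key>"},
    json={"question": "What is in my knowledge base?", "stream": False},  # assumed fields
)
print(resp.json())  # expected: the completion plus a step-by-step execution trace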
67  agent/trace/__init__.py  Normal file
@@ -0,0 +1,67 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Agent Trace Module

This module provides comprehensive tracing capabilities for agent execution,
including trace models, collectors, formatters, and services.
"""

from agent.trace.trace_models import (
    TraceEventType,
    TraceLevel,
    TraceMetadata,
    ComponentInfo,
    TraceEvent,
    LLMCallTrace,
    RetrievalTrace,
    ToolCallTrace,
    TraceSession,
)
from agent.trace.trace_collector import (
    TraceCollector,
    get_trace_collector,
    create_trace_collector,
)
from agent.trace.trace_formatter import (
    TraceFormatter,
    StreamingTraceFormatter,
    CompactTraceFormatter,
    DetailedTraceFormatter,
)

__all__ = [
    # Models
    "TraceEventType",
    "TraceLevel",
    "TraceMetadata",
    "ComponentInfo",
    "TraceEvent",
    "LLMCallTrace",
    "RetrievalTrace",
    "ToolCallTrace",
    "TraceSession",
    # Collector
    "TraceCollector",
    "get_trace_collector",
    "create_trace_collector",
    # Formatters
    "TraceFormatter",
    "StreamingTraceFormatter",
    "CompactTraceFormatter",
    "DetailedTraceFormatter",
]
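
A usage sketch of these exports (not part of the diff; IDs are illustrative, and trace_component is the collector's context manager defined in trace_collector.py below):

from agent.trace import TraceLevel, create_trace_collector

collector = create_trace_collector(
    task_id="task-123", agent_id="agent-1", session_id="sess-1",
    user_id="user-1", tenant_id="tenant-1", trace_level=TraceLevel.DETAILED,
)
collector.workflow_started(inputs={"question": "hello"})
with collector.trace_component("node-1", "Retrieval", "retrieval"):
    pass  # the component body would execute here
collector.workflow_completed(outputs={"answer": "..."})
print(collector.get_summary())  # event/LLM/retrieval/tool-call counts, token totals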
510  agent/trace/trace_collector.py  Normal file
@@ -0,0 +1,510 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Trace Collector for Agent Execution

This module provides the TraceCollector class that captures and manages
trace events during agent workflow execution. It supports real-time
event streaming and session management.
"""

import asyncio
import logging
import threading
import time
import uuid
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Callable, Optional, Generator, AsyncGenerator

from agent.trace.trace_models import (
    TraceEventType,
    TraceLevel,
    TraceMetadata,
    TraceEvent,
    TraceSession,
    LLMCallTrace,
    RetrievalTrace,
    ToolCallTrace,
    ComponentInfo,
)


_trace_collectors: dict[str, "TraceCollector"] = {}
_collectors_lock = threading.Lock()


def get_trace_collector(task_id: str) -> Optional["TraceCollector"]:
    """Get an existing trace collector by task ID."""
    with _collectors_lock:
        return _trace_collectors.get(task_id)


def create_trace_collector(
    task_id: str,
    agent_id: str,
    session_id: str,
    user_id: str,
    tenant_id: str,
    trace_level: TraceLevel = TraceLevel.STANDARD,
) -> "TraceCollector":
    """Create a new trace collector for an agent execution."""
    with _collectors_lock:
        if task_id in _trace_collectors:
            return _trace_collectors[task_id]

        collector = TraceCollector(
            task_id=task_id,
            agent_id=agent_id,
            session_id=session_id,
            user_id=user_id,
            tenant_id=tenant_id,
            trace_level=trace_level,
        )
        _trace_collectors[task_id] = collector
        return collector


def remove_trace_collector(task_id: str) -> None:
    """Remove a trace collector from the registry."""
    with _collectors_lock:
        if task_id in _trace_collectors:
            del _trace_collectors[task_id]


class TraceCollector:
    """
    Collects and manages trace events during agent execution.

    This class provides methods to record various types of trace events,
    manage trace sessions, and stream events in real-time.
    """

    def __init__(
        self,
        task_id: str,
        agent_id: str,
        session_id: str,
        user_id: str,
        tenant_id: str,
        trace_level: TraceLevel = TraceLevel.STANDARD,
    ):
        """Initialize the trace collector."""
        self.task_id = task_id
        self.trace_level = trace_level
        self._lock = threading.Lock()
        self._event_queue: asyncio.Queue[TraceEvent] = asyncio.Queue()
        self._subscribers: list[Callable[[TraceEvent], None]] = []
        self._is_active = True

        metadata = TraceMetadata(
            agent_id=agent_id,
            session_id=session_id,
            user_id=user_id,
            tenant_id=tenant_id,
            trace_level=trace_level,
        )

        self.session = TraceSession(
            session_id=session_id,
            metadata=metadata,
        )

        self._component_start_times: dict[str, float] = {}
        self._llm_call_starts: dict[str, tuple[float, LLMCallTrace]] = {}
        self._retrieval_starts: dict[str, tuple[float, RetrievalTrace]] = {}
        self._tool_call_starts: dict[str, tuple[float, ToolCallTrace]] = {}

    def _should_trace(self, required_level: TraceLevel) -> bool:
        """Check if the current trace level allows this event."""
        level_order = [TraceLevel.MINIMAL, TraceLevel.STANDARD, TraceLevel.DETAILED, TraceLevel.DEBUG]
        current_idx = level_order.index(self.trace_level)
        required_idx = level_order.index(required_level)
        return current_idx >= required_idx

    def _create_event(
        self,
        event_type: TraceEventType,
        component_id: Optional[str] = None,
        component_name: Optional[str] = None,
        component_type: Optional[str] = None,
        inputs: Optional[dict[str, Any]] = None,
        outputs: Optional[dict[str, Any]] = None,
        error: Optional[str] = None,
        elapsed_time: Optional[float] = None,
        thoughts: Optional[str] = None,
        metadata: Optional[dict[str, Any]] = None,
    ) -> TraceEvent:
        """Create a new trace event."""
        return TraceEvent(
            event_id=str(uuid.uuid4()),
            event_type=event_type,
            timestamp=datetime.utcnow(),
            component_id=component_id,
            component_name=component_name,
            component_type=component_type,
            inputs=inputs,
            outputs=outputs,
            error=error,
            elapsed_time=elapsed_time,
            thoughts=thoughts,
            metadata=metadata or {},
        )

    def _record_event(self, event: TraceEvent) -> None:
        """Record an event and notify subscribers."""
        with self._lock:
            self.session.add_event(event)
            for subscriber in self._subscribers:
                try:
                    subscriber(event)
                except Exception as e:
                    logging.warning(f"Trace subscriber error: {e}")

    def subscribe(self, callback: Callable[[TraceEvent], None]) -> None:
        """Subscribe to trace events."""
        with self._lock:
            self._subscribers.append(callback)

    def unsubscribe(self, callback: Callable[[TraceEvent], None]) -> None:
        """Unsubscribe from trace events."""
        with self._lock:
            if callback in self._subscribers:
                self._subscribers.remove(callback)

    def workflow_started(self, inputs: Optional[dict[str, Any]] = None) -> None:
        """Record workflow start event."""
        event = self._create_event(
            event_type=TraceEventType.WORKFLOW_STARTED,
            inputs=inputs,
        )
        self._record_event(event)

    def workflow_completed(self, outputs: Optional[dict[str, Any]] = None) -> None:
        """Record workflow completion event."""
        event = self._create_event(
            event_type=TraceEventType.WORKFLOW_COMPLETED,
            outputs=outputs,
            elapsed_time=(datetime.utcnow() - self.session.started_at).total_seconds(),
        )
        self._record_event(event)
        self.session.complete()

    def workflow_failed(self, error: str) -> None:
        """Record workflow failure event."""
        event = self._create_event(
            event_type=TraceEventType.WORKFLOW_FAILED,
            error=error,
            elapsed_time=(datetime.utcnow() - self.session.started_at).total_seconds(),
        )
        self._record_event(event)
        self.session.complete(error=error)

    def node_started(
        self,
        component_id: str,
        component_name: str,
        component_type: str,
        inputs: Optional[dict[str, Any]] = None,
        thoughts: Optional[str] = None,
    ) -> None:
        """Record node start event."""
        self._component_start_times[component_id] = time.perf_counter()

        event = self._create_event(
            event_type=TraceEventType.NODE_STARTED,
            component_id=component_id,
            component_name=component_name,
            component_type=component_type,
            inputs=inputs,
            thoughts=thoughts,
        )
        self._record_event(event)

    def node_finished(
        self,
        component_id: str,
        component_name: str,
        component_type: str,
        inputs: Optional[dict[str, Any]] = None,
        outputs: Optional[dict[str, Any]] = None,
        error: Optional[str] = None,
    ) -> None:
        """Record node completion event."""
        start_time = self._component_start_times.pop(component_id, None)
        elapsed_time = time.perf_counter() - start_time if start_time else None

        event_type = TraceEventType.NODE_FAILED if error else TraceEventType.NODE_FINISHED

        event = self._create_event(
            event_type=event_type,
            component_id=component_id,
            component_name=component_name,
            component_type=component_type,
            inputs=inputs,
            outputs=outputs,
            error=error,
            elapsed_time=elapsed_time,
        )
        self._record_event(event)

    def llm_call_started(
        self,
        call_id: str,
        model_name: str,
        prompt: str,
        temperature: float = 0.0,
        max_tokens: Optional[int] = None,
    ) -> None:
        """Record LLM call start."""
        if not self._should_trace(TraceLevel.DETAILED):
            return

        llm_trace = LLMCallTrace(
            call_id=call_id,
            model_name=model_name,
            prompt=prompt,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        self._llm_call_starts[call_id] = (time.perf_counter(), llm_trace)

        event = self._create_event(
            event_type=TraceEventType.LLM_CALL_STARTED,
            metadata={"call_id": call_id, "model_name": model_name},
        )
        self._record_event(event)

    def llm_call_completed(
        self,
        call_id: str,
        response: str,
        prompt_tokens: int = 0,
        completion_tokens: int = 0,
        error: Optional[str] = None,
    ) -> None:
        """Record LLM call completion."""
        if not self._should_trace(TraceLevel.DETAILED):
            return

        start_data = self._llm_call_starts.pop(call_id, None)
        if start_data:
            start_time, llm_trace = start_data
            llm_trace.response = response
            llm_trace.prompt_tokens = prompt_tokens
            llm_trace.completion_tokens = completion_tokens
            llm_trace.total_tokens = prompt_tokens + completion_tokens
            llm_trace.latency_ms = (time.perf_counter() - start_time) * 1000
            llm_trace.completed_at = datetime.utcnow()
            llm_trace.error = error
            self.session.add_llm_call(llm_trace)

        event = self._create_event(
            event_type=TraceEventType.LLM_CALL_COMPLETED,
            error=error,
            metadata={"call_id": call_id, "tokens": prompt_tokens + completion_tokens},
        )
        self._record_event(event)

    def retrieval_started(
        self,
        retrieval_id: str,
        query: str,
        knowledge_bases: list[str],
        top_k: int = 10,
        similarity_threshold: float = 0.0,
        rerank_enabled: bool = False,
    ) -> None:
        """Record retrieval operation start."""
        retrieval_trace = RetrievalTrace(
            retrieval_id=retrieval_id,
            query=query,
            knowledge_bases=knowledge_bases,
            top_k=top_k,
            similarity_threshold=similarity_threshold,
            rerank_enabled=rerank_enabled,
        )
        self._retrieval_starts[retrieval_id] = (time.perf_counter(), retrieval_trace)

        event = self._create_event(
            event_type=TraceEventType.RETRIEVAL_STARTED,
            metadata={"retrieval_id": retrieval_id, "query": query[:100]},
        )
        self._record_event(event)

    def retrieval_completed(
        self,
        retrieval_id: str,
        chunks: list[dict[str, Any]],
        error: Optional[str] = None,
    ) -> None:
        """Record retrieval operation completion."""
        start_data = self._retrieval_starts.pop(retrieval_id, None)
        if start_data:
            start_time, retrieval_trace = start_data
            retrieval_trace.chunks = chunks
            retrieval_trace.chunks_retrieved = len(chunks)
            retrieval_trace.latency_ms = (time.perf_counter() - start_time) * 1000
            retrieval_trace.completed_at = datetime.utcnow()
            retrieval_trace.error = error
            self.session.add_retrieval(retrieval_trace)

        event = self._create_event(
            event_type=TraceEventType.RETRIEVAL_COMPLETED,
            error=error,
            metadata={"retrieval_id": retrieval_id, "chunks_count": len(chunks)},
        )
        self._record_event(event)

    def tool_call_started(
        self,
        call_id: str,
        tool_name: str,
        tool_type: str,
        arguments: dict[str, Any],
    ) -> None:
        """Record tool call start."""
        tool_trace = ToolCallTrace(
            call_id=call_id,
            tool_name=tool_name,
            tool_type=tool_type,
            arguments=arguments,
        )
        self._tool_call_starts[call_id] = (time.perf_counter(), tool_trace)

        event = self._create_event(
            event_type=TraceEventType.TOOL_CALL_STARTED,
            metadata={"call_id": call_id, "tool_name": tool_name},
        )
        self._record_event(event)

    def tool_call_completed(
        self,
        call_id: str,
        result: Any,
        error: Optional[str] = None,
    ) -> None:
        """Record tool call completion."""
        start_data = self._tool_call_starts.pop(call_id, None)
        if start_data:
            start_time, tool_trace = start_data
            tool_trace.result = result
            tool_trace.latency_ms = (time.perf_counter() - start_time) * 1000
            tool_trace.completed_at = datetime.utcnow()
            tool_trace.error = error
            self.session.add_tool_call(tool_trace)

        event = self._create_event(
            event_type=TraceEventType.TOOL_CALL_COMPLETED,
            error=error,
            metadata={"call_id": call_id},
        )
        self._record_event(event)

    def message_generated(
        self,
        content: str,
        component_id: Optional[str] = None,
        component_name: Optional[str] = None,
    ) -> None:
        """Record message generation event."""
        event = self._create_event(
            event_type=TraceEventType.MESSAGE_GENERATED,
            component_id=component_id,
            component_name=component_name,
            outputs={"content": content[:500] if len(content) > 500 else content},
        )
        self._record_event(event)

    def thinking_started(self, component_id: str, thoughts: str) -> None:
        """Record thinking/reasoning start."""
        if not self._should_trace(TraceLevel.DETAILED):
            return

        event = self._create_event(
            event_type=TraceEventType.THINKING_STARTED,
            component_id=component_id,
            thoughts=thoughts,
        )
        self._record_event(event)

    def thinking_completed(self, component_id: str, thoughts: str) -> None:
        """Record thinking/reasoning completion."""
        if not self._should_trace(TraceLevel.DETAILED):
            return

        event = self._create_event(
            event_type=TraceEventType.THINKING_COMPLETED,
            component_id=component_id,
            thoughts=thoughts,
        )
        self._record_event(event)

    def error_occurred(
        self,
        error: str,
        component_id: Optional[str] = None,
        component_name: Optional[str] = None,
    ) -> None:
        """Record an error event."""
        event = self._create_event(
            event_type=TraceEventType.ERROR_OCCURRED,
            component_id=component_id,
            component_name=component_name,
            error=error,
        )
        self._record_event(event)

    def get_session(self) -> TraceSession:
        """Get the current trace session."""
        return self.session

    def get_events(self) -> list[TraceEvent]:
        """Get all recorded events."""
        return self.session.events

    def get_summary(self) -> dict[str, Any]:
        """Get a summary of the trace session."""
        return self.session.get_summary()

    def close(self) -> None:
        """Close the trace collector and cleanup resources."""
        self._is_active = False
        remove_trace_collector(self.task_id)

    @contextmanager
    def trace_component(
        self,
        component_id: str,
        component_name: str,
        component_type: str,
        inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[None, None, None]:
        """Context manager for tracing component execution."""
        self.node_started(component_id, component_name, component_type, inputs)
        error = None
        outputs = None
        try:
            yield
        except Exception as e:
            error = str(e)
            raise
        finally:
            self.node_finished(
                component_id, component_name, component_type,
                inputs=inputs, outputs=outputs, error=error
            )
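
The subscriber hook above is what the streaming endpoint builds on. A self-contained sketch (illustrative IDs; note that LLM-call events are only recorded at DETAILED or DEBUG level):

from agent.trace import TraceLevel, create_trace_collector

c = create_trace_collector("task-9", "agent-1", "sess-9", "user-1", "tenant-1",
                           trace_level=TraceLevel.DETAILED)
c.subscribe(lambda e: print(e.event_type.value, e.metadata))  # fires on every recorded event
c.llm_call_started("call-1", "gpt-4o", "Say hi", temperature=0.2)
c.llm_call_completed("call-1", "Hi!", prompt_tokens=5, completion_tokens=2)
c.close()  # deactivates the collector and drops it from the module registry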
416  agent/trace/trace_formatter.py  Normal file
@@ -0,0 +1,416 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Trace Formatter for Agent Execution Logs

This module provides various formatters for converting trace events and sessions
into different output formats suitable for API responses, logging, and debugging.
"""

import json
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Optional, Generator

from agent.trace.trace_models import (
    TraceEventType,
    TraceLevel,
    TraceEvent,
    TraceSession,
    LLMCallTrace,
    RetrievalTrace,
    ToolCallTrace,
)


class TraceFormatter(ABC):
    """Abstract base class for trace formatters."""

    @abstractmethod
    def format_event(self, event: TraceEvent) -> dict[str, Any]:
        """Format a single trace event."""
        pass

    @abstractmethod
    def format_session(self, session: TraceSession) -> dict[str, Any]:
        """Format a complete trace session."""
        pass

    @abstractmethod
    def format_for_stream(self, event: TraceEvent) -> str:
        """Format an event for SSE streaming."""
        pass


class StreamingTraceFormatter(TraceFormatter):
    """Formatter optimized for real-time SSE streaming."""

    def __init__(self, include_inputs: bool = True, include_outputs: bool = True):
        """Initialize the streaming formatter."""
        self.include_inputs = include_inputs
        self.include_outputs = include_outputs

    def format_event(self, event: TraceEvent) -> dict[str, Any]:
        """Format a trace event for streaming."""
        result = {
            "event_id": event.event_id,
            "event_type": event.event_type.value,
            "timestamp": event.timestamp.isoformat(),
        }

        if event.component_id:
            result["component_id"] = event.component_id
        if event.component_name:
            result["component_name"] = event.component_name
        if event.component_type:
            result["component_type"] = event.component_type

        if self.include_inputs and event.inputs:
            result["inputs"] = self._truncate_dict(event.inputs, max_length=200)

        if self.include_outputs and event.outputs:
            result["outputs"] = self._truncate_dict(event.outputs, max_length=500)

        if event.error:
            result["error"] = event.error
        if event.elapsed_time is not None:
            result["elapsed_time"] = round(event.elapsed_time, 3)
        if event.thoughts:
            result["thoughts"] = event.thoughts[:300] if len(event.thoughts) > 300 else event.thoughts

        return result

    def format_session(self, session: TraceSession) -> dict[str, Any]:
        """Format a trace session for streaming response."""
        return {
            "session_id": session.session_id,
            "status": session.status,
            "started_at": session.started_at.isoformat(),
            "completed_at": session.completed_at.isoformat() if session.completed_at else None,
            "total_elapsed_time": round(session.total_elapsed_time, 3),
            "summary": session.get_summary(),
            "events": [self.format_event(e) for e in session.events],
        }

    def format_for_stream(self, event: TraceEvent) -> str:
        """Format an event as SSE data."""
        data = self.format_event(event)
        return f"data:{json.dumps({'event': 'trace', 'data': data}, ensure_ascii=False)}\n\n"

    def _truncate_dict(self, d: dict[str, Any], max_length: int = 200) -> dict[str, Any]:
        """Truncate string values in a dictionary."""
        result = {}
        for key, value in d.items():
            if isinstance(value, str) and len(value) > max_length:
                result[key] = value[:max_length] + "..."
            elif isinstance(value, dict):
                result[key] = self._truncate_dict(value, max_length)
            elif isinstance(value, list) and len(value) > 10:
                result[key] = value[:10] + ["..."]
            else:
                result[key] = value
        return result
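
A quick check of the truncation rules above (a sketch that pokes a private helper, so it is illustration only):

from agent.trace.trace_formatter import StreamingTraceFormatter

f = StreamingTraceFormatter()
out = f._truncate_dict({"text": "x" * 300, "items": list(range(20))})
assert out["text"].endswith("...") and len(out["text"]) == 203  # 200 chars + "..."
assert out["items"][-1] == "..." and len(out["items"]) == 11    # 10 items + marker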
class CompactTraceFormatter(TraceFormatter):
    """Compact formatter for minimal trace output."""

    def __init__(self):
        """Initialize the compact formatter."""
        self._event_icons = {
            TraceEventType.WORKFLOW_STARTED: "🚀",
            TraceEventType.WORKFLOW_COMPLETED: "✅",
            TraceEventType.WORKFLOW_FAILED: "❌",
            TraceEventType.NODE_STARTED: "▶️",
            TraceEventType.NODE_FINISHED: "✔️",
            TraceEventType.NODE_FAILED: "❌",
            TraceEventType.RETRIEVAL_STARTED: "🔍",
            TraceEventType.RETRIEVAL_COMPLETED: "📚",
            TraceEventType.LLM_CALL_STARTED: "🤖",
            TraceEventType.LLM_CALL_COMPLETED: "💬",
            TraceEventType.TOOL_CALL_STARTED: "🔧",
            TraceEventType.TOOL_CALL_COMPLETED: "⚙️",
            TraceEventType.MESSAGE_GENERATED: "📝",
            TraceEventType.ERROR_OCCURRED: "⚠️",
            TraceEventType.THINKING_STARTED: "💭",
            TraceEventType.THINKING_COMPLETED: "💡",
        }

    def format_event(self, event: TraceEvent) -> dict[str, Any]:
        """Format a trace event in compact form."""
        icon = self._event_icons.get(event.event_type, "•")

        result = {
            "type": event.event_type.value,
            "icon": icon,
            "time": event.timestamp.strftime("%H:%M:%S.%f")[:-3],
        }

        if event.component_name:
            result["component"] = event.component_name
        if event.elapsed_time is not None:
            result["duration_ms"] = round(event.elapsed_time * 1000, 1)
        if event.error:
            result["error"] = event.error[:100]

        return result

    def format_session(self, session: TraceSession) -> dict[str, Any]:
        """Format a trace session in compact form."""
        summary = session.get_summary()

        return {
            "session_id": session.session_id,
            "status": session.status,
            "duration_s": round(session.total_elapsed_time, 2),
            "nodes": summary["nodes_executed"],
            "llm_calls": summary["total_llm_calls"],
            "retrievals": summary["total_retrievals"],
            "tool_calls": summary["total_tool_calls"],
            "tokens": summary["total_tokens"],
            "errors": summary["errors_count"],
            "timeline": [self.format_event(e) for e in session.events],
        }

    def format_for_stream(self, event: TraceEvent) -> str:
        """Format an event for SSE streaming in compact form."""
        data = self.format_event(event)
        return f"data:{json.dumps({'event': 'trace_compact', 'data': data}, ensure_ascii=False)}\n\n"

    def format_timeline(self, session: TraceSession) -> list[str]:
        """Format session as a text timeline."""
        lines = []
        for event in session.events:
            icon = self._event_icons.get(event.event_type, "•")
            time_str = event.timestamp.strftime("%H:%M:%S")
            component = event.component_name or ""
            duration = f" ({event.elapsed_time*1000:.0f}ms)" if event.elapsed_time else ""

            line = f"{time_str} {icon} {event.event_type.value}"
            if component:
                line += f" [{component}]"
            line += duration
            if event.error:
                line += f" ERROR: {event.error[:50]}"

            lines.append(line)

        return lines


class DetailedTraceFormatter(TraceFormatter):
    """Detailed formatter for comprehensive trace output."""

    def __init__(self, include_raw_data: bool = False):
        """Initialize the detailed formatter."""
        self.include_raw_data = include_raw_data

    def format_event(self, event: TraceEvent) -> dict[str, Any]:
        """Format a trace event with full details."""
        result = {
            "event_id": event.event_id,
            "event_type": event.event_type.value,
            "timestamp": event.timestamp.isoformat(),
            "timestamp_unix": event.timestamp.timestamp(),
        }

        if event.component_id:
            result["component"] = {
                "id": event.component_id,
                "name": event.component_name,
                "type": event.component_type,
            }

        if event.inputs is not None:
            result["inputs"] = event.inputs
        if event.outputs is not None:
            result["outputs"] = event.outputs
        if event.error:
            result["error"] = {
                "message": event.error,
                "occurred_at": event.timestamp.isoformat(),
            }
        if event.elapsed_time is not None:
            result["timing"] = {
                "elapsed_seconds": round(event.elapsed_time, 4),
                "elapsed_ms": round(event.elapsed_time * 1000, 2),
            }
        if event.thoughts:
            result["thoughts"] = event.thoughts
        if event.metadata:
            result["metadata"] = event.metadata

        return result

    def format_session(self, session: TraceSession) -> dict[str, Any]:
        """Format a complete trace session with all details."""
        result = {
            "session_id": session.session_id,
            "metadata": session.metadata.to_dict(),
            "status": session.status,
            "timing": {
                "started_at": session.started_at.isoformat(),
                "completed_at": session.completed_at.isoformat() if session.completed_at else None,
                "total_elapsed_seconds": round(session.total_elapsed_time, 4),
            },
            "summary": session.get_summary(),
            "events": [self.format_event(e) for e in session.events],
            "llm_calls": [self._format_llm_call(c) for c in session.llm_calls],
            "retrievals": [self._format_retrieval(r) for r in session.retrievals],
            "tool_calls": [self._format_tool_call(t) for t in session.tool_calls],
        }

        if session.error:
            result["error"] = session.error

        return result

    def format_for_stream(self, event: TraceEvent) -> str:
        """Format an event for SSE streaming with full details."""
        data = self.format_event(event)
        return f"data:{json.dumps({'event': 'trace_detailed', 'data': data}, ensure_ascii=False)}\n\n"

    def _format_llm_call(self, call: LLMCallTrace) -> dict[str, Any]:
        """Format an LLM call trace."""
        result = {
            "call_id": call.call_id,
            "model_name": call.model_name,
            "tokens": {
                "prompt": call.prompt_tokens,
                "completion": call.completion_tokens,
                "total": call.total_tokens,
            },
            "latency_ms": round(call.latency_ms, 2),
            "temperature": call.temperature,
            "started_at": call.started_at.isoformat(),
            "completed_at": call.completed_at.isoformat() if call.completed_at else None,
        }

        if self.include_raw_data:
            result["prompt"] = call.prompt
            result["response"] = call.response
        else:
            result["prompt_preview"] = call.prompt[:200] + "..." if len(call.prompt) > 200 else call.prompt
            result["response_preview"] = call.response[:200] + "..." if call.response and len(call.response) > 200 else call.response

        if call.max_tokens:
            result["max_tokens"] = call.max_tokens
        if call.error:
            result["error"] = call.error

        return result

    def _format_retrieval(self, retrieval: RetrievalTrace) -> dict[str, Any]:
        """Format a retrieval trace."""
        result = {
            "retrieval_id": retrieval.retrieval_id,
            "query": retrieval.query,
            "knowledge_bases": retrieval.knowledge_bases,
            "config": {
                "top_k": retrieval.top_k,
                "similarity_threshold": retrieval.similarity_threshold,
                "rerank_enabled": retrieval.rerank_enabled,
            },
            "results": {
                "chunks_retrieved": retrieval.chunks_retrieved,
                "chunks_preview": retrieval.chunks[:3] if self.include_raw_data else [
                    {"id": c.get("id"), "score": c.get("score")} for c in retrieval.chunks[:3]
                ],
            },
            "latency_ms": round(retrieval.latency_ms, 2),
            "started_at": retrieval.started_at.isoformat(),
            "completed_at": retrieval.completed_at.isoformat() if retrieval.completed_at else None,
        }

        if retrieval.error:
            result["error"] = retrieval.error

        return result

    def _format_tool_call(self, tool: ToolCallTrace) -> dict[str, Any]:
        """Format a tool call trace."""
        result = {
            "call_id": tool.call_id,
            "tool_name": tool.tool_name,
            "tool_type": tool.tool_type,
            "arguments": tool.arguments,
            "latency_ms": round(tool.latency_ms, 2),
            "started_at": tool.started_at.isoformat(),
            "completed_at": tool.completed_at.isoformat() if tool.completed_at else None,
        }

        if self.include_raw_data:
            result["result"] = tool.result
        else:
            result_str = str(tool.result) if tool.result else None
            result["result_preview"] = result_str[:200] + "..." if result_str and len(result_str) > 200 else result_str

        if tool.error:
            result["error"] = tool.error

        return result


class TraceFormatterFactory:
    """Factory for creating trace formatters."""

    _formatters = {
        "streaming": StreamingTraceFormatter,
        "compact": CompactTraceFormatter,
        "detailed": DetailedTraceFormatter,
    }

    @classmethod
    def create(cls, format_type: str = "streaming", **kwargs) -> TraceFormatter:
        """Create a trace formatter by type."""
        formatter_class = cls._formatters.get(format_type)
        if not formatter_class:
            raise ValueError(f"Unknown formatter type: {format_type}. Available: {list(cls._formatters.keys())}")
        return formatter_class(**kwargs)

    @classmethod
    def register(cls, name: str, formatter_class: type) -> None:
        """Register a custom formatter."""
        if not issubclass(formatter_class, TraceFormatter):
            raise TypeError("Formatter must be a subclass of TraceFormatter")
        cls._formatters[name] = formatter_class

    @classmethod
    def available_formatters(cls) -> list[str]:
        """Get list of available formatter types."""
        return list(cls._formatters.keys())


def format_trace_for_api(
    session: TraceSession,
    format_type: str = "streaming",
    **kwargs
) -> dict[str, Any]:
    """Convenience function to format a trace session for API response."""
    formatter = TraceFormatterFactory.create(format_type, **kwargs)
    return formatter.format_session(session)


def generate_trace_stream(
    events: Generator[TraceEvent, None, None],
    format_type: str = "streaming",
    **kwargs
) -> Generator[str, None, None]:
    """Generate SSE stream from trace events."""
    formatter = TraceFormatterFactory.create(format_type, **kwargs)
    for event in events:
        yield formatter.format_for_stream(event)
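
Putting the factory and helpers to work (a sketch; the empty session exists only to show the call shapes):

from agent.trace.trace_formatter import TraceFormatterFactory, format_trace_for_api
from agent.trace.trace_models import TraceMetadata, TraceSession

meta = TraceMetadata(agent_id="a1", session_id="s1", user_id="u1", tenant_id="t1")
session = TraceSession(session_id="s1", metadata=meta)
session.complete()  # status becomes "completed", total_elapsed_time is fixed

compact = TraceFormatterFactory.create("compact").format_session(session)
detailed = format_trace_for_api(session, format_type="detailed", include_raw_data=False)
print(compact["status"], detailed["timing"]["total_elapsed_seconds"])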
375  agent/trace/trace_models.py  Normal file
@@ -0,0 +1,375 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Trace Models for Agent Execution Logging

This module provides data models for capturing and representing trace information
during agent execution. It includes models for individual trace events, component
execution details, and complete trace sessions.
"""

from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum
from typing import Any, Optional, Union
import json
import uuid


class TraceEventType(Enum):
    """Enumeration of trace event types."""
    WORKFLOW_STARTED = "workflow_started"
    WORKFLOW_COMPLETED = "workflow_completed"
    WORKFLOW_FAILED = "workflow_failed"
    NODE_STARTED = "node_started"
    NODE_FINISHED = "node_finished"
    NODE_FAILED = "node_failed"
    RETRIEVAL_STARTED = "retrieval_started"
    RETRIEVAL_COMPLETED = "retrieval_completed"
    LLM_CALL_STARTED = "llm_call_started"
    LLM_CALL_COMPLETED = "llm_call_completed"
    TOOL_CALL_STARTED = "tool_call_started"
    TOOL_CALL_COMPLETED = "tool_call_completed"
    MESSAGE_GENERATED = "message_generated"
    ERROR_OCCURRED = "error_occurred"
    THINKING_STARTED = "thinking_started"
    THINKING_COMPLETED = "thinking_completed"


class TraceLevel(Enum):
    """Trace verbosity levels."""
    MINIMAL = "minimal"
    STANDARD = "standard"
    DETAILED = "detailed"
    DEBUG = "debug"


@dataclass
class TraceMetadata:
    """Metadata associated with a trace session."""
    agent_id: str
    session_id: str
    user_id: str
    tenant_id: str
    trace_level: TraceLevel = TraceLevel.STANDARD
    created_at: datetime = field(default_factory=datetime.utcnow)
    tags: list[str] = field(default_factory=list)
    custom_data: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Convert metadata to dictionary representation."""
        return {
            "agent_id": self.agent_id,
            "session_id": self.session_id,
            "user_id": self.user_id,
            "tenant_id": self.tenant_id,
            "trace_level": self.trace_level.value,
            "created_at": self.created_at.isoformat(),
            "tags": self.tags,
            "custom_data": self.custom_data
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TraceMetadata":
        """Create TraceMetadata from dictionary."""
        return cls(
            agent_id=data.get("agent_id", ""),
            session_id=data.get("session_id", ""),
            user_id=data.get("user_id", ""),
            tenant_id=data.get("tenant_id", ""),
            trace_level=TraceLevel(data.get("trace_level", "standard")),
            created_at=datetime.fromisoformat(data["created_at"]) if "created_at" in data else datetime.utcnow(),
            tags=data.get("tags", []),
            custom_data=data.get("custom_data", {})
        )


@dataclass
class ComponentInfo:
    """Information about a component in the agent workflow."""
    component_id: str
    component_name: str
    component_type: str
    params: dict[str, Any] = field(default_factory=dict)
    upstream: list[str] = field(default_factory=list)
    downstream: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Convert component info to dictionary."""
        return {
            "component_id": self.component_id,
            "component_name": self.component_name,
            "component_type": self.component_type,
            "params": self.params,
            "upstream": self.upstream,
            "downstream": self.downstream
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ComponentInfo":
        """Create ComponentInfo from dictionary."""
        return cls(
            component_id=data.get("component_id", ""),
            component_name=data.get("component_name", ""),
            component_type=data.get("component_type", ""),
            params=data.get("params", {}),
            upstream=data.get("upstream", []),
            downstream=data.get("downstream", [])
        )


@dataclass
class TraceEvent:
    """Represents a single trace event during agent execution."""
    event_id: str
    event_type: TraceEventType
    timestamp: datetime
    component_id: Optional[str] = None
    component_name: Optional[str] = None
    component_type: Optional[str] = None
    inputs: Optional[dict[str, Any]] = None
    outputs: Optional[dict[str, Any]] = None
    error: Optional[str] = None
    elapsed_time: Optional[float] = None
    thoughts: Optional[str] = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Initialize event_id if not provided."""
        if not self.event_id:
            self.event_id = str(uuid.uuid4())

    def to_dict(self) -> dict[str, Any]:
        """Convert trace event to dictionary representation."""
        result = {
            "event_id": self.event_id,
            "event_type": self.event_type.value,
            "timestamp": self.timestamp.isoformat(),
        }
        if self.component_id:
            result["component_id"] = self.component_id
        if self.component_name:
            result["component_name"] = self.component_name
        if self.component_type:
            result["component_type"] = self.component_type
        if self.inputs is not None:
            result["inputs"] = self.inputs
        if self.outputs is not None:
            result["outputs"] = self.outputs
        if self.error:
            result["error"] = self.error
        if self.elapsed_time is not None:
            result["elapsed_time"] = self.elapsed_time
        if self.thoughts:
            result["thoughts"] = self.thoughts
        if self.metadata:
            result["metadata"] = self.metadata
        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TraceEvent":
        """Create TraceEvent from dictionary."""
        return cls(
            event_id=data.get("event_id", ""),
            event_type=TraceEventType(data.get("event_type", "node_started")),
            timestamp=datetime.fromisoformat(data["timestamp"]) if "timestamp" in data else datetime.utcnow(),
            component_id=data.get("component_id"),
            component_name=data.get("component_name"),
            component_type=data.get("component_type"),
            inputs=data.get("inputs"),
            outputs=data.get("outputs"),
            error=data.get("error"),
            elapsed_time=data.get("elapsed_time"),
            thoughts=data.get("thoughts"),
            metadata=data.get("metadata", {})
        )

    def to_json(self) -> str:
        """Convert trace event to JSON string."""
        return json.dumps(self.to_dict(), ensure_ascii=False)


@dataclass
class LLMCallTrace:
    """Trace information for LLM API calls."""
    call_id: str
    model_name: str
    prompt: str
    response: Optional[str] = None
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    latency_ms: float = 0.0
    temperature: float = 0.0
    max_tokens: Optional[int] = None
    error: Optional[str] = None
    started_at: datetime = field(default_factory=datetime.utcnow)
    completed_at: Optional[datetime] = None

    def to_dict(self) -> dict[str, Any]:
        """Convert LLM call trace to dictionary."""
        return {
            "call_id": self.call_id,
            "model_name": self.model_name,
            "prompt": self.prompt[:500] + "..." if len(self.prompt) > 500 else self.prompt,
            "response": self.response[:500] + "..." if self.response and len(self.response) > 500 else self.response,
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.completion_tokens,
            "total_tokens": self.total_tokens,
            "latency_ms": self.latency_ms,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "error": self.error,
            "started_at": self.started_at.isoformat(),
            "completed_at": self.completed_at.isoformat() if self.completed_at else None
        }


@dataclass
class RetrievalTrace:
    """Trace information for retrieval operations."""
    retrieval_id: str
    query: str
    knowledge_bases: list[str] = field(default_factory=list)
    top_k: int = 10
    similarity_threshold: float = 0.0
    chunks_retrieved: int = 0
    chunks: list[dict[str, Any]] = field(default_factory=list)
    latency_ms: float = 0.0
    rerank_enabled: bool = False
    error: Optional[str] = None
    started_at: datetime = field(default_factory=datetime.utcnow)
    completed_at: Optional[datetime] = None

    def to_dict(self) -> dict[str, Any]:
        """Convert retrieval trace to dictionary."""
        return {
            "retrieval_id": self.retrieval_id,
            "query": self.query,
            "knowledge_bases": self.knowledge_bases,
            "top_k": self.top_k,
            "similarity_threshold": self.similarity_threshold,
            "chunks_retrieved": self.chunks_retrieved,
            "chunks": self.chunks[:5],
            "latency_ms": self.latency_ms,
            "rerank_enabled": self.rerank_enabled,
            "error": self.error,
            "started_at": self.started_at.isoformat(),
            "completed_at": self.completed_at.isoformat() if self.completed_at else None
        }


@dataclass
class ToolCallTrace:
    """Trace information for tool/function calls."""
    call_id: str
    tool_name: str
    tool_type: str
    arguments: dict[str, Any] = field(default_factory=dict)
    result: Optional[Any] = None
    latency_ms: float = 0.0
    error: Optional[str] = None
    started_at: datetime = field(default_factory=datetime.utcnow)
    completed_at: Optional[datetime] = None

    def to_dict(self) -> dict[str, Any]:
        """Convert tool call trace to dictionary."""
        result_str = str(self.result)[:500] if self.result else None
        return {
            "call_id": self.call_id,
            "tool_name": self.tool_name,
            "tool_type": self.tool_type,
            "arguments": self.arguments,
            "result": result_str,
            "latency_ms": self.latency_ms,
            "error": self.error,
            "started_at": self.started_at.isoformat(),
            "completed_at": self.completed_at.isoformat() if self.completed_at else None
        }


@dataclass
class TraceSession:
    """Complete trace session for an agent execution."""
    session_id: str
    metadata: TraceMetadata
    events: list[TraceEvent] = field(default_factory=list)
    llm_calls: list[LLMCallTrace] = field(default_factory=list)
    retrievals: list[RetrievalTrace] = field(default_factory=list)
    tool_calls: list[ToolCallTrace] = field(default_factory=list)
    started_at: datetime = field(default_factory=datetime.utcnow)
    completed_at: Optional[datetime] = None
    total_elapsed_time: float = 0.0
    status: str = "running"
    error: Optional[str] = None

    def add_event(self, event: TraceEvent) -> None:
        """Add a trace event to the session."""
        self.events.append(event)

    def add_llm_call(self, llm_call: LLMCallTrace) -> None:
        """Add an LLM call trace to the session."""
        self.llm_calls.append(llm_call)

    def add_retrieval(self, retrieval: RetrievalTrace) -> None:
        """Add a retrieval trace to the session."""
        self.retrievals.append(retrieval)

    def add_tool_call(self, tool_call: ToolCallTrace) -> None:
        """Add a tool call trace to the session."""
        self.tool_calls.append(tool_call)

    def complete(self, error: Optional[str] = None) -> None:
        """Mark the session as completed."""
        self.completed_at = datetime.utcnow()
        self.total_elapsed_time = (self.completed_at - self.started_at).total_seconds()
        self.status = "failed" if error else "completed"
        self.error = error

    def to_dict(self) -> dict[str, Any]:
        """Convert trace session to dictionary representation."""
        return {
            "session_id": self.session_id,
            "metadata": self.metadata.to_dict(),
            "events": [e.to_dict() for e in self.events],
            "llm_calls": [c.to_dict() for c in self.llm_calls],
            "retrievals": [r.to_dict() for r in self.retrievals],
            "tool_calls": [t.to_dict() for t in self.tool_calls],
            "started_at": self.started_at.isoformat(),
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
            "total_elapsed_time": self.total_elapsed_time,
            "status": self.status,
            "error": self.error,
            "summary": self.get_summary()
        }

    def get_summary(self) -> dict[str, Any]:
        """Get a summary of the trace session."""
        return {
            "total_events": len(self.events),
            "total_llm_calls": len(self.llm_calls),
            "total_retrievals": len(self.retrievals),
            "total_tool_calls": len(self.tool_calls),
            "total_tokens": sum(c.total_tokens for c in self.llm_calls),
            "total_chunks_retrieved": sum(r.chunks_retrieved for r in self.retrievals),
            "nodes_executed": len(set(e.component_id for e in self.events if e.component_id)),
            "errors_count": len([e for e in self.events if e.error])
        }

    def to_json(self) -> str:
        """Convert trace session to JSON string."""
        return json.dumps(self.to_dict(), ensure_ascii=False, indent=2)
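
These models round-trip through plain dicts, which is what allows sessions to be persisted (e.g. in Redis) and reloaded. A sketch:

from datetime import datetime
from agent.trace.trace_models import TraceEvent, TraceEventType

ev = TraceEvent(event_id="", event_type=TraceEventType.NODE_STARTED,
                timestamp=datetime.utcnow(), component_name="Retrieval")
assert ev.event_id  # __post_init__ backfills a UUID when event_id is empty
clone = TraceEvent.from_dict(ev.to_dict())
assert clone.component_name == "Retrieval"
assert clone.event_type is TraceEventType.NODE_STARTED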
477  api/apps/trace_app.py  Normal file
@@ -0,0 +1,477 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Agent Trace API Endpoints

Provides REST API endpoints for accessing agent execution traces,
including trace retrieval, filtering, analysis, and management.
This addresses Issue #10081: Add Trace Logging for Agent Completions API.
"""

from datetime import datetime
from quart import request, Response
import json

from api.apps import login_required, current_user
from api.db.services.trace_service import TraceService
from api.utils.api_utils import (
    get_data_error_result,
    get_json_result,
    get_request_json,
    server_error_response,
    validate_request,
)
from common.constants import RetCode


@manager.route('/traces', methods=['GET'])  # noqa: F821
@login_required
async def list_traces():
    """
    List trace sessions for the current tenant.

    Query parameters:
    - agent_id: Filter by agent ID (optional)
    - user_id: Filter by user ID (optional)
    - status: Filter by status (running, completed, failed) (optional)
    - start_time: Filter by start time, ISO format (optional)
    - end_time: Filter by end time, ISO format (optional)
    - page: Page number (default: 1)
    - page_size: Items per page (default: 20)

    Returns:
        Paginated list of trace sessions
    """
    try:
        agent_id = request.args.get("agent_id")
        user_id = request.args.get("user_id")
        status = request.args.get("status")
        start_time_str = request.args.get("start_time")
        end_time_str = request.args.get("end_time")
        page = int(request.args.get("page", 1))
        page_size = int(request.args.get("page_size", 20))

        start_time = None
        end_time = None
        if start_time_str:
            start_time = datetime.fromisoformat(start_time_str)
        if end_time_str:
            end_time = datetime.fromisoformat(end_time_str)

        result = TraceService.list_traces(
            tenant_id=current_user.id,
            agent_id=agent_id,
            user_id=user_id,
            status=status,
            start_time=start_time,
            end_time=end_time,
            page=page,
            page_size=page_size,
        )

        return get_json_result(data=result)
    except Exception as e:
        return server_error_response(e)
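
A client-side sketch for the listing endpoint above (hypothetical host, URL prefix, and auth header):

import requests

resp = requests.get(
    "http://localhost:9380/v1/traces",  # mount point of this blueprint is an assumption
    params={"status": "completed", "start_time": "2025-01-01T00:00:00",
            "page": 1, "page_size": 20},
    headers={"Authorization": "Bearer <api-key>"},  # auth scheme is an assumption
)
print(resp.json())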
@manager.route('/traces/<task_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def get_trace(task_id):
|
||||
"""
|
||||
Get a specific trace session by task ID.
|
||||
|
||||
Path parameters:
|
||||
- task_id: The task/trace session ID
|
||||
|
||||
Query parameters:
|
||||
- format: Output format (streaming, compact, detailed) (default: streaming)
|
||||
|
||||
Returns:
|
||||
Trace session data
|
||||
"""
|
||||
try:
|
||||
format_type = request.args.get("format", "streaming")
|
||||
|
||||
result = TraceService.format_trace(task_id, format_type)
|
||||
|
||||
if not result:
|
||||
return get_data_error_result(
|
||||
message="Trace session not found",
|
||||
code=RetCode.DATA_ERROR
|
||||
)
|
||||
|
||||
return get_json_result(data=result)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/traces/<task_id>/events', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
async def get_trace_events(task_id):
|
||||
"""
|
||||
Get trace events for a specific session.
|
||||
|
||||
Path parameters:
|
||||
- task_id: The task/trace session ID
|
||||
|
||||
Query parameters:
|
||||
- event_types: Comma-separated list of event types to filter (optional)
|
||||
- component_id: Filter by component ID (optional)
|
||||
- limit: Maximum number of events (default: 100)
|
||||
- offset: Number of events to skip (default: 0)
|
||||
|
||||
Returns:
|
||||
List of trace events
|
||||
"""
|
||||
try:
|
||||
event_types_str = request.args.get("event_types")
|
||||
event_types = event_types_str.split(",") if event_types_str else None
|
||||
component_id = request.args.get("component_id")
|
||||
limit = int(request.args.get("limit", 100))
|
||||
offset = int(request.args.get("offset", 0))
|
||||
|
||||
events = TraceService.get_trace_events(
|
||||
task_id=task_id,
|
||||
event_types=event_types,
|
||||
component_id=component_id,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return get_json_result(data={"events": events, "count": len(events)})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/traces/<task_id>/summary', methods=['GET'])  # noqa: F821
@login_required
async def get_trace_summary(task_id):
    """
    Get a summary of a trace session.

    Path parameters:
    - task_id: The task/trace session ID

    Returns:
        Trace session summary
    """
    try:
        summary = TraceService.get_trace_summary(task_id)

        if not summary:
            return get_data_error_result(
                message="Trace session not found",
                code=RetCode.DATA_ERROR
            )

        return get_json_result(data=summary)
    except Exception as e:
        return server_error_response(e)

@manager.route('/traces/<task_id>/analysis', methods=['GET'])  # noqa: F821
@login_required
async def analyze_trace(task_id):
    """
    Analyze a trace session and get insights.

    Path parameters:
    - task_id: The task/trace session ID

    Returns:
        Analysis results including bottlenecks, errors, and recommendations
    """
    try:
        analysis = TraceService.analyze_trace(task_id)

        if not analysis:
            return get_data_error_result(
                message="Trace session not found or analysis failed",
                code=RetCode.DATA_ERROR
            )

        return get_json_result(data=analysis)
    except Exception as e:
        return server_error_response(e)

@manager.route('/traces/<task_id>', methods=['DELETE'])  # noqa: F821
@login_required
async def delete_trace(task_id):
    """
    Delete a trace session.

    Path parameters:
    - task_id: The task/trace session ID

    Returns:
        Success status
    """
    try:
        success, message = TraceService.delete_trace(task_id)

        if not success:
            return get_data_error_result(message=message)

        return get_json_result(data={"task_id": task_id, "message": message})
    except Exception as e:
        return server_error_response(e)

@manager.route('/traces/cleanup', methods=['POST'])  # noqa: F821
@login_required
async def cleanup_traces():
    """
    Clean up old trace sessions.

    Request body:
    {
        "days": 7  // Number of days to keep traces (default: 7)
    }

    Returns:
        Number of deleted traces
    """
    try:
        req = await get_request_json()
        days = req.get("days", 7)

        deleted, message = TraceService.cleanup_old_traces(days)

        return get_json_result(data={"deleted": deleted, "message": message})
    except Exception as e:
        return server_error_response(e)

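# Illustrative request: POST /traces/cleanup with body {"days": 3} deletes
# sessions older than three days; omitting "days" keeps the 7-day default.
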
@manager.route('/traces/<task_id>/stream', methods=['GET'])  # noqa: F821
@login_required
async def stream_trace(task_id):
    """
    Stream trace events in real-time using Server-Sent Events.

    Path parameters:
    - task_id: The task/trace session ID

    Query parameters:
    - format: Output format (streaming, compact, detailed) (default: streaming)

    Returns:
        SSE stream of trace events
    """
    try:
        format_type = request.args.get("format", "streaming")

        from agent.trace.trace_collector import get_trace_collector
        from agent.trace.trace_formatter import TraceFormatterFactory

        collector = get_trace_collector(task_id)

        if not collector:
            return get_data_error_result(
                message="Active trace session not found",
                code=RetCode.DATA_ERROR
            )

        formatter = TraceFormatterFactory.create(format_type)

        async def generate():
            import asyncio

            # Replay events collected so far before switching to live mode.
            for event in collector.get_events():
                yield formatter.format_for_stream(event)

            # Live mode: the subscriber callback may fire from synchronous
            # code, so buffer events in a plain list and drain it on a short
            # poll. Note that events emitted between the replay above and
            # subscribe() below can be missed; see the note after this
            # endpoint.
            event_queue = []

            def on_event(event):
                event_queue.append(event)

            collector.subscribe(on_event)

            try:
                while collector._is_active:
                    while event_queue:
                        event = event_queue.pop(0)
                        yield formatter.format_for_stream(event)
                    await asyncio.sleep(0.1)
            finally:
                collector.unsubscribe(on_event)

            yield "data:[DONE]\n\n"

        resp = Response(generate(), mimetype="text/event-stream")
        resp.headers.add_header("Cache-control", "no-cache")
        resp.headers.add_header("Connection", "keep-alive")
        resp.headers.add_header("X-Accel-Buffering", "no")
        resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
        return resp
    except Exception as e:
        return server_error_response(e)

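# Design note: the generator above drains a plain list on a 100 ms poll because
# collector callbacks may fire from synchronous code. A sketch of a push-based
# alternative, assuming callbacks run on the serving event loop (hypothetical,
# not part of this diff):
#
#     queue: asyncio.Queue = asyncio.Queue()
#     collector.subscribe(queue.put_nowait)
#     try:
#         while collector._is_active:
#             try:
#                 event = await asyncio.wait_for(queue.get(), timeout=0.5)
#             except asyncio.TimeoutError:
#                 continue
#             yield formatter.format_for_stream(event)
#     finally:
#         collector.unsubscribe(queue.put_nowait)
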
@manager.route('/agents/<agent_id>/traces', methods=['GET'])  # noqa: F821
@login_required
async def list_agent_traces(agent_id):
    """
    List trace sessions for a specific agent.

    Path parameters:
    - agent_id: The agent ID

    Query parameters:
    - status: Filter by status (optional)
    - page: Page number (default: 1)
    - page_size: Items per page (default: 20)

    Returns:
        Paginated list of trace sessions for the agent
    """
    try:
        status = request.args.get("status")
        page = int(request.args.get("page", 1))
        page_size = int(request.args.get("page_size", 20))

        result = TraceService.list_traces(
            tenant_id=current_user.id,
            agent_id=agent_id,
            status=status,
            page=page,
            page_size=page_size,
        )

        return get_json_result(data=result)
    except Exception as e:
        return server_error_response(e)

@manager.route('/agents/<agent_id>/completions/trace', methods=['POST'])  # noqa: F821
@login_required
@validate_request("question")
async def agent_completion_with_trace(agent_id):
    """
    Execute agent completion with trace logging enabled.

    This endpoint is similar to /agents/<agent_id>/completions but includes
    trace information in the response, addressing Issue #10081.

    Path parameters:
    - agent_id: The agent ID

    Request body:
    {
        "question": "User question",
        "session_id": "Optional session ID",
        "stream": true,
        "trace_level": "standard",  // minimal, standard, detailed, debug
        "include_trace": true
    }

    Returns:
        Agent response with trace information
    """
    try:
        from api.db.services.canvas_service import completion as agent_completion

        req = await get_request_json()
        stream = req.get("stream", True)
        trace_level = req.get("trace_level", "standard")
        include_trace = req.get("include_trace", True)

        from common.misc_utils import get_uuid
        task_id = get_uuid()

        if include_trace:
            TraceService.create_trace_session(
                task_id=task_id,
                agent_id=agent_id,
                session_id=req.get("session_id", ""),
                user_id=current_user.id,
                tenant_id=current_user.id,
                trace_level=trace_level,
            )

        if stream:
            async def generate():
                full_content = ""
                reference = {}

                async for answer in agent_completion(
                    tenant_id=current_user.id,
                    agent_id=agent_id,
                    **req
                ):
                    try:
                        # Each SSE frame starts with "data:"; strip it to
                        # inspect the JSON payload.
                        ans = json.loads(answer[5:])

                        if ans["event"] == "message":
                            full_content += ans["data"]["content"]

                        if ans.get("data", {}).get("reference"):
                            reference.update(ans["data"]["reference"])

                        yield answer
                    except Exception:
                        continue

                if include_trace:
                    TraceService.save_trace_session(task_id)
                    trace_data = TraceService.format_trace(task_id, "compact")
                    yield f"data:{json.dumps({'event': 'trace', 'data': trace_data}, ensure_ascii=False)}\n\n"

                yield "data:[DONE]\n\n"

            resp = Response(generate(), mimetype="text/event-stream")
            resp.headers.add_header("Cache-control", "no-cache")
            resp.headers.add_header("Connection", "keep-alive")
            resp.headers.add_header("X-Accel-Buffering", "no")
            resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
            return resp

        full_content = ""
        reference = {}
        final_ans = None

        async for answer in agent_completion(
            tenant_id=current_user.id,
            agent_id=agent_id,
            **req
        ):
            try:
                ans = json.loads(answer[5:])

                if ans["event"] == "message":
                    full_content += ans["data"]["content"]

                if ans.get("data", {}).get("reference"):
                    reference.update(ans["data"]["reference"])

                final_ans = ans
            except Exception:
                continue

        if final_ans:
            final_ans["data"]["content"] = full_content
            final_ans["data"]["reference"] = reference

        if include_trace:
            TraceService.save_trace_session(task_id)
            trace_data = TraceService.format_trace(task_id, "compact")
            if final_ans:
                final_ans["trace"] = trace_data

        return get_json_result(data=final_ans)
    except Exception as e:
        return server_error_response(e)

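# Illustrative end-to-end call (placeholder agent ID): with "stream": false the
# trace rides along in the final JSON payload instead of a trailing SSE frame.
#
#   POST /agents/<agent_id>/completions/trace
#   {"question": "What is RAGFlow?", "stream": false,
#    "trace_level": "detailed", "include_trace": true}
#
# The response mirrors the plain completions endpoint, plus a "trace" field
# produced by the compact formatter.
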
482
api/db/services/trace_service.py
Normal file

@ -0,0 +1,482 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Trace Service for Agent Execution Logging

This module provides the TraceService class that manages trace data persistence,
retrieval, and analysis. It integrates with the trace collector and formatter
modules to provide a complete tracing solution.
"""

import json
import logging
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Optional, Tuple

from agent.trace.trace_collector import (
    TraceCollector,
    create_trace_collector,
    get_trace_collector,
)
from agent.trace.trace_formatter import format_trace_for_api
from agent.trace.trace_models import TraceEvent, TraceLevel
from rag.utils.redis_conn import REDIS_CONN


TRACE_KEY_PREFIX = "agent_trace:"
TRACE_SESSION_TTL = 86400 * 7  # session documents are kept for 7 days
TRACE_EVENT_TTL = 86400 * 3  # raw event lists are kept for 3 days

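# Redis key layout (derived from the constants above):
#   agent_trace:session:<task_id>  -> JSON session document, kept 7 days
#   agent_trace:events:<task_id>   -> JSON array of raw events, kept 3 days
# Raw events expire before their session, so a saved summary can outlive the
# event list it was computed from.
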
class TraceService:
    """
    Service for managing agent execution traces.

    Provides methods for creating, storing, retrieving, and analyzing
    trace data from agent executions.
    """

    @staticmethod
    def create_trace_session(
        task_id: str,
        agent_id: str,
        session_id: str,
        user_id: str,
        tenant_id: str,
        trace_level: str = "standard",
    ) -> Tuple[bool, str]:
        """
        Create a new trace session for an agent execution.

        Args:
            task_id: Unique identifier for the task
            agent_id: ID of the agent being executed
            session_id: ID of the conversation session
            user_id: ID of the user
            tenant_id: ID of the tenant
            trace_level: Verbosity level (minimal, standard, detailed, debug)

        Returns:
            Tuple of (success, trace_id or error message)
        """
        try:
            level = TraceLevel(trace_level)
        except ValueError:
            level = TraceLevel.STANDARD

        try:
            # Registers an in-memory collector for this task; it is looked up
            # later via get_trace_collector(task_id).
            create_trace_collector(
                task_id=task_id,
                agent_id=agent_id,
                session_id=session_id,
                user_id=user_id,
                tenant_id=tenant_id,
                trace_level=level,
            )

            session_data = {
                "task_id": task_id,
                "agent_id": agent_id,
                "session_id": session_id,
                "user_id": user_id,
                "tenant_id": tenant_id,
                "trace_level": trace_level,
                "created_at": datetime.utcnow().isoformat(),
                "status": "running",
            }

            key = f"{TRACE_KEY_PREFIX}session:{task_id}"
            REDIS_CONN.setex(key, TRACE_SESSION_TTL, json.dumps(session_data))

            return True, task_id
        except Exception as e:
            logging.exception(f"Failed to create trace session: {e}")
            return False, str(e)

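    # Illustrative call (hypothetical IDs):
    #   ok, trace_id = TraceService.create_trace_session(
    #       task_id="t-123", agent_id="a-1", session_id="s-9",
    #       user_id="u-7", tenant_id="u-7", trace_level="debug",
    #   )
    # Unknown trace_level strings degrade silently to TraceLevel.STANDARD.
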
    @staticmethod
    def get_trace_collector(task_id: str) -> Optional[TraceCollector]:
        """Get an active trace collector by task ID."""
        return get_trace_collector(task_id)

    @staticmethod
    def save_trace_session(task_id: str) -> Tuple[bool, str]:
        """
        Save the current trace session to persistent storage.

        Args:
            task_id: ID of the task/trace session

        Returns:
            Tuple of (success, message)
        """
        try:
            collector = get_trace_collector(task_id)
            if not collector:
                return False, "Trace collector not found"

            session = collector.get_session()
            session_dict = session.to_dict()

            key = f"{TRACE_KEY_PREFIX}session:{task_id}"
            REDIS_CONN.setex(key, TRACE_SESSION_TTL, json.dumps(session_dict))

            events_key = f"{TRACE_KEY_PREFIX}events:{task_id}"
            events_data = json.dumps([e.to_dict() for e in session.events])
            REDIS_CONN.setex(events_key, TRACE_EVENT_TTL, events_data)

            return True, "Trace session saved successfully"
        except Exception as e:
            logging.exception(f"Failed to save trace session: {e}")
            return False, str(e)

    @staticmethod
    def get_trace_session(task_id: str) -> Optional[dict[str, Any]]:
        """
        Retrieve a trace session by task ID.

        Args:
            task_id: ID of the task/trace session

        Returns:
            Trace session data or None if not found
        """
        try:
            collector = get_trace_collector(task_id)
            if collector:
                return collector.get_session().to_dict()

            key = f"{TRACE_KEY_PREFIX}session:{task_id}"
            data = REDIS_CONN.get(key)
            if data:
                return json.loads(data)

            return None
        except Exception as e:
            logging.exception(f"Failed to get trace session: {e}")
            return None

    @staticmethod
    def get_trace_events(
        task_id: str,
        event_types: Optional[list[str]] = None,
        component_id: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """
        Retrieve trace events with optional filtering.

        Args:
            task_id: ID of the task/trace session
            event_types: Filter by event types
            component_id: Filter by component ID
            limit: Maximum number of events to return
            offset: Number of events to skip

        Returns:
            List of trace events
        """
        try:
            collector = get_trace_collector(task_id)
            if collector:
                events = collector.get_events()
            else:
                events_key = f"{TRACE_KEY_PREFIX}events:{task_id}"
                data = REDIS_CONN.get(events_key)
                if not data:
                    return []
                events = [TraceEvent.from_dict(e) for e in json.loads(data)]

            if event_types:
                type_set = set(event_types)
                events = [e for e in events if e.event_type.value in type_set]

            if component_id:
                events = [e for e in events if e.component_id == component_id]

            events = events[offset:offset + limit]
            return [e.to_dict() for e in events]
        except Exception as e:
            logging.exception(f"Failed to get trace events: {e}")
            return []

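    # Illustrative call (the event type string is assumed to match a
    # TraceEventType.value): fetch the second page of 50 retrieval events.
    #   events = TraceService.get_trace_events(
    #       "t-123", event_types=["retrieval"], limit=50, offset=50)
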
    @staticmethod
    def get_trace_summary(task_id: str) -> Optional[dict[str, Any]]:
        """
        Get a summary of the trace session.

        Args:
            task_id: ID of the task/trace session

        Returns:
            Summary data or None if not found
        """
        try:
            collector = get_trace_collector(task_id)
            if collector:
                return collector.get_summary()

            session_data = TraceService.get_trace_session(task_id)
            if session_data and "summary" in session_data:
                return session_data["summary"]

            return None
        except Exception as e:
            logging.exception(f"Failed to get trace summary: {e}")
            return None

    @staticmethod
    def format_trace(
        task_id: str,
        format_type: str = "streaming",
        **kwargs
    ) -> Optional[dict[str, Any]]:
        """
        Format a trace session using the specified formatter.

        Args:
            task_id: ID of the task/trace session
            format_type: Type of formatter (streaming, compact, detailed)
            **kwargs: Additional formatter options

        Returns:
            Formatted trace data or None if not found
        """
        try:
            session_data = TraceService.get_trace_session(task_id)
            if not session_data:
                return None

            collector = get_trace_collector(task_id)
            if collector:
                session = collector.get_session()
                return format_trace_for_api(session, format_type, **kwargs)

            # No live collector: return the persisted session document as-is
            # (it is already serialized, so format_type has no effect here).
            return session_data
        except Exception as e:
            logging.exception(f"Failed to format trace: {e}")
            return None

    @staticmethod
    def list_traces(
        tenant_id: str,
        agent_id: Optional[str] = None,
        user_id: Optional[str] = None,
        status: Optional[str] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        page: int = 1,
        page_size: int = 20,
    ) -> dict[str, Any]:
        """
        List trace sessions with filtering and pagination.

        Args:
            tenant_id: ID of the tenant
            agent_id: Filter by agent ID
            user_id: Filter by user ID
            status: Filter by status (running, completed, failed)
            start_time: Filter by start time
            end_time: Filter by end time
            page: Page number
            page_size: Items per page

        Returns:
            Paginated list of trace sessions
        """
        try:
            pattern = f"{TRACE_KEY_PREFIX}session:*"
            keys = REDIS_CONN.keys(pattern)

            sessions = []
            for key in keys:
                data = REDIS_CONN.get(key)
                if data:
                    session = json.loads(data)
                    if session.get("tenant_id") == tenant_id:
                        sessions.append(session)

            if agent_id:
                sessions = [s for s in sessions if s.get("agent_id") == agent_id]
            if user_id:
                sessions = [s for s in sessions if s.get("user_id") == user_id]
            if status:
                sessions = [s for s in sessions if s.get("status") == status]

            # Skip sessions without created_at: fromisoformat("") would raise
            # and abort the whole listing.
            if start_time:
                sessions = [s for s in sessions if s.get("created_at")
                            and datetime.fromisoformat(s["created_at"]) >= start_time]
            if end_time:
                sessions = [s for s in sessions if s.get("created_at")
                            and datetime.fromisoformat(s["created_at"]) <= end_time]

            sessions.sort(key=lambda x: x.get("created_at", ""), reverse=True)

            total = len(sessions)
            start_idx = (page - 1) * page_size
            end_idx = start_idx + page_size
            paginated = sessions[start_idx:end_idx]

            return {
                "total": total,
                "page": page,
                "page_size": page_size,
                "sessions": paginated,
            }
        except Exception as e:
            logging.exception(f"Failed to list traces: {e}")
            return {"total": 0, "page": page, "page_size": page_size, "sessions": []}

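    # Design note: REDIS_CONN.keys() scans the whole keyspace in one blocking
    # call. If the wrapper exposes redis-py's scan_iter (an assumption, not
    # verified here), a non-blocking variant would be:
    #   keys = list(REDIS_CONN.scan_iter(f"{TRACE_KEY_PREFIX}session:*"))
    # with the tenant/status/time filtering unchanged.
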
    @staticmethod
    def delete_trace(task_id: str) -> Tuple[bool, str]:
        """
        Delete a trace session.

        Args:
            task_id: ID of the task/trace session

        Returns:
            Tuple of (success, message)
        """
        try:
            session_key = f"{TRACE_KEY_PREFIX}session:{task_id}"
            events_key = f"{TRACE_KEY_PREFIX}events:{task_id}"

            REDIS_CONN.delete(session_key)
            REDIS_CONN.delete(events_key)

            return True, "Trace deleted successfully"
        except Exception as e:
            logging.exception(f"Failed to delete trace: {e}")
            return False, str(e)

    @staticmethod
    def analyze_trace(task_id: str) -> Optional[dict[str, Any]]:
        """
        Analyze a trace session and provide insights.

        Args:
            task_id: ID of the task/trace session

        Returns:
            Analysis results or None if not found
        """
        try:
            session_data = TraceService.get_trace_session(task_id)
            if not session_data:
                return None

            events = TraceService.get_trace_events(task_id, limit=1000)

            analysis = {
                "task_id": task_id,
                "total_events": len(events),
                "event_distribution": defaultdict(int),
                "component_execution_times": {},
                "bottlenecks": [],
                "errors": [],
                "recommendations": [],
            }

            for event in events:
                event_type = event.get("event_type", "unknown")
                analysis["event_distribution"][event_type] += 1

                if event.get("error"):
                    analysis["errors"].append({
                        "component": event.get("component_name"),
                        "error": event.get("error"),
                        "timestamp": event.get("timestamp"),
                    })

                if event.get("elapsed_time"):
                    component = event.get("component_name", "unknown")
                    if component not in analysis["component_execution_times"]:
                        analysis["component_execution_times"][component] = []
                    analysis["component_execution_times"][component].append(
                        event.get("elapsed_time")
                    )

            # Flag components whose average execution time exceeds 5 seconds.
            for component, times in analysis["component_execution_times"].items():
                avg_time = sum(times) / len(times)
                if avg_time > 5.0:
                    analysis["bottlenecks"].append({
                        "component": component,
                        "avg_execution_time": round(avg_time, 2),
                        "executions": len(times),
                    })

            if analysis["bottlenecks"]:
                analysis["recommendations"].append(
                    "Consider optimizing slow components or adding caching"
                )
            if analysis["errors"]:
                analysis["recommendations"].append(
                    "Review and fix error-prone components"
                )

            # Normalize the defaultdict back to a plain dict for the response.
            analysis["event_distribution"] = dict(analysis["event_distribution"])

            return analysis
        except Exception as e:
            logging.exception(f"Failed to analyze trace: {e}")
            return None

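    # Illustrative output shape (values invented): a slow retriever surfaces as
    #   {"bottlenecks": [{"component": "Retrieval:0",
    #                     "avg_execution_time": 7.41, "executions": 3}],
    #    "recommendations": ["Consider optimizing slow components or adding caching"],
    #    ...}
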
    @staticmethod
    def cleanup_old_traces(days: int = 7) -> Tuple[int, str]:
        """
        Clean up trace sessions older than specified days.

        Args:
            days: Number of days to keep traces

        Returns:
            Tuple of (deleted count, message)
        """
        try:
            cutoff = datetime.utcnow() - timedelta(days=days)
            pattern = f"{TRACE_KEY_PREFIX}session:*"
            keys = REDIS_CONN.keys(pattern)

            deleted = 0
            for key in keys:
                # Some Redis clients return keys as bytes; normalize to str so
                # the task_id can be split off the key below.
                if isinstance(key, bytes):
                    key = key.decode()
                data = REDIS_CONN.get(key)
                if data:
                    session = json.loads(data)
                    created_at = session.get("created_at")
                    if created_at:
                        session_time = datetime.fromisoformat(created_at)
                        if session_time < cutoff:
                            task_id = key.split(":")[-1]
                            TraceService.delete_trace(task_id)
                            deleted += 1

            return deleted, f"Deleted {deleted} old trace sessions"
        except Exception as e:
            logging.exception(f"Failed to cleanup traces: {e}")
            return 0, str(e)