diff --git a/agent/trace/__init__.py b/agent/trace/__init__.py new file mode 100644 index 000000000..98012e67b --- /dev/null +++ b/agent/trace/__init__.py @@ -0,0 +1,67 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Agent Trace Module + +This module provides comprehensive tracing capabilities for agent execution, +including trace models, collectors, formatters, and services. +""" + +from agent.trace.trace_models import ( + TraceEventType, + TraceLevel, + TraceMetadata, + ComponentInfo, + TraceEvent, + LLMCallTrace, + RetrievalTrace, + ToolCallTrace, + TraceSession, +) +from agent.trace.trace_collector import ( + TraceCollector, + get_trace_collector, + create_trace_collector, +) +from agent.trace.trace_formatter import ( + TraceFormatter, + StreamingTraceFormatter, + CompactTraceFormatter, + DetailedTraceFormatter, +) + +__all__ = [ + # Models + "TraceEventType", + "TraceLevel", + "TraceMetadata", + "ComponentInfo", + "TraceEvent", + "LLMCallTrace", + "RetrievalTrace", + "ToolCallTrace", + "TraceSession", + # Collector + "TraceCollector", + "get_trace_collector", + "create_trace_collector", + # Formatters + "TraceFormatter", + "StreamingTraceFormatter", + "CompactTraceFormatter", + "DetailedTraceFormatter", +] diff --git a/agent/trace/trace_collector.py b/agent/trace/trace_collector.py new file mode 100644 index 000000000..cb9e5b633 --- /dev/null +++ b/agent/trace/trace_collector.py @@ 
def create_trace_collector(
    task_id: str,
    agent_id: str,
    session_id: str,
    user_id: str,
    tenant_id: str,
    trace_level: TraceLevel = TraceLevel.STANDARD,
) -> "TraceCollector":
    """Return the collector registered for ``task_id``, creating it on first use.

    The lookup and the registration both happen while ``_collectors_lock`` is
    held. If a collector is already registered under ``task_id`` it is
    returned as-is and the remaining arguments are ignored.

    Args:
        task_id: Registry key of the agent execution.
        agent_id: Forwarded to ``TraceCollector``.
        session_id: Forwarded to ``TraceCollector``.
        user_id: Forwarded to ``TraceCollector``.
        tenant_id: Forwarded to ``TraceCollector``.
        trace_level: Verbosity for a newly created collector.

    Returns:
        The existing or newly created ``TraceCollector``.
    """
    with _collectors_lock:
        try:
            return _trace_collectors[task_id]
        except KeyError:
            pass

        fresh = TraceCollector(
            task_id=task_id,
            agent_id=agent_id,
            session_id=session_id,
            user_id=user_id,
            tenant_id=tenant_id,
            trace_level=trace_level,
        )
        _trace_collectors[task_id] = fresh
        return fresh
def _record_event(self, event: TraceEvent) -> None:
    """Store ``event`` in the session and deliver it to every subscriber.

    The event is appended to ``self.session`` and the subscriber list is
    copied while ``self._lock`` is held; the callbacks run only after the
    lock has been released. ``self._lock`` is a plain ``threading.Lock``
    (not re-entrant), so invoking callbacks under it would block forever as
    soon as a callback called ``subscribe``, ``unsubscribe`` or any method
    that records another event.

    Args:
        event: The trace event to store and broadcast.

    A callback that raises is reported through ``logging.warning`` and
    skipped; the remaining callbacks still run.
    """
    with self._lock:
        self.session.add_event(event)
        # Snapshot: the live list may be changed by (un)subscribe calls made
        # from other threads or from the callbacks themselves.
        subscribers = list(self._subscribers)

    for subscriber in subscribers:
        try:
            subscriber(event)
        except Exception as e:
            logging.warning(f"Trace subscriber error: {e}")
def node_finished(
    self,
    component_id: str,
    component_name: str,
    component_type: str,
    inputs: Optional[dict[str, Any]] = None,
    outputs: Optional[dict[str, Any]] = None,
    error: Optional[str] = None,
) -> None:
    """Record that a node finished, as a success or as a failure.

    The start time stored by ``node_started`` for ``component_id`` is removed
    and used to compute the elapsed time in seconds. When no start time was
    stored, the event carries ``elapsed_time=None``.

    Args:
        component_id: Identifier of the component; key of the stored start time.
        component_name: Name copied into the event.
        component_type: Type copied into the event.
        inputs: Optional inputs copied into the event.
        outputs: Optional outputs copied into the event.
        error: Error text. A truthy value produces ``NODE_FAILED``, anything
            else ``NODE_FINISHED``.
    """
    start_time = self._component_start_times.pop(component_id, None)
    # Compare against None rather than relying on truthiness: the reference
    # point of perf_counter() is undefined, so 0.0 is a legitimate start.
    if start_time is None:
        elapsed_time = None
    else:
        elapsed_time = time.perf_counter() - start_time

    event_type = TraceEventType.NODE_FAILED if error else TraceEventType.NODE_FINISHED

    event = self._create_event(
        event_type=event_type,
        component_id=component_id,
        component_name=component_name,
        component_type=component_type,
        inputs=inputs,
        outputs=outputs,
        error=error,
        elapsed_time=elapsed_time,
    )
    self._record_event(event)
self._llm_call_starts[call_id] = (time.perf_counter(), llm_trace) + + event = self._create_event( + event_type=TraceEventType.LLM_CALL_STARTED, + metadata={"call_id": call_id, "model_name": model_name}, + ) + self._record_event(event) + + def llm_call_completed( + self, + call_id: str, + response: str, + prompt_tokens: int = 0, + completion_tokens: int = 0, + error: Optional[str] = None, + ) -> None: + """Record LLM call completion.""" + if not self._should_trace(TraceLevel.DETAILED): + return + + start_data = self._llm_call_starts.pop(call_id, None) + if start_data: + start_time, llm_trace = start_data + llm_trace.response = response + llm_trace.prompt_tokens = prompt_tokens + llm_trace.completion_tokens = completion_tokens + llm_trace.total_tokens = prompt_tokens + completion_tokens + llm_trace.latency_ms = (time.perf_counter() - start_time) * 1000 + llm_trace.completed_at = datetime.utcnow() + llm_trace.error = error + self.session.add_llm_call(llm_trace) + + event = self._create_event( + event_type=TraceEventType.LLM_CALL_COMPLETED, + error=error, + metadata={"call_id": call_id, "tokens": prompt_tokens + completion_tokens}, + ) + self._record_event(event) + + def retrieval_started( + self, + retrieval_id: str, + query: str, + knowledge_bases: list[str], + top_k: int = 10, + similarity_threshold: float = 0.0, + rerank_enabled: bool = False, + ) -> None: + """Record retrieval operation start.""" + retrieval_trace = RetrievalTrace( + retrieval_id=retrieval_id, + query=query, + knowledge_bases=knowledge_bases, + top_k=top_k, + similarity_threshold=similarity_threshold, + rerank_enabled=rerank_enabled, + ) + self._retrieval_starts[retrieval_id] = (time.perf_counter(), retrieval_trace) + + event = self._create_event( + event_type=TraceEventType.RETRIEVAL_STARTED, + metadata={"retrieval_id": retrieval_id, "query": query[:100]}, + ) + self._record_event(event) + + def retrieval_completed( + self, + retrieval_id: str, + chunks: list[dict[str, Any]], + error: 
def tool_call_completed(
    self,
    call_id: str,
    result: Any,
    error: Optional[str] = None,
) -> None:
    """Close the tool call ``call_id`` and emit a TOOL_CALL_COMPLETED event.

    If ``tool_call_started`` registered a pending entry for ``call_id``, that
    entry is removed, filled in with the result, latency (milliseconds),
    completion time and error, and added to the session. The completion
    event is recorded whether or not a pending entry was found.

    Args:
        call_id: Identifier given to ``tool_call_started``.
        result: Value returned by the tool.
        error: Optional error text, stored on the trace and on the event.
    """
    pending = self._tool_call_starts.pop(call_id, None)
    if pending:
        began_at, trace = pending
        trace.result = result
        trace.latency_ms = (time.perf_counter() - began_at) * 1000
        trace.completed_at = datetime.utcnow()
        trace.error = error
        self.session.add_tool_call(trace)

    self._record_event(
        self._create_event(
            event_type=TraceEventType.TOOL_CALL_COMPLETED,
            error=error,
            metadata={"call_id": call_id},
        )
    )
@contextmanager
def trace_component(
    self,
    component_id: str,
    component_name: str,
    component_type: str,
    inputs: Optional[dict[str, Any]] = None,
) -> Generator[None, None, None]:
    """Trace one component's execution around a ``with`` block.

    ``node_started`` is called on entry and ``node_finished`` on exit,
    whether the body completes or raises. An ``Exception`` raised by the
    body is re-raised unchanged after its text has been captured.

    Args:
        component_id: Identifier of the component being executed.
        component_name: Name passed to both node events.
        component_type: Type passed to both node events.
        inputs: Optional inputs passed to both node events.

    Yields:
        None. Outputs are not captured, so ``node_finished`` always
        receives ``outputs=None``.
    """
    self.node_started(component_id, component_name, component_type, inputs)
    error = None
    try:
        yield
    except Exception as e:
        # str(e) is "" for an exception raised without a message, and
        # node_finished treats a falsy error as success. Fall back to
        # repr(e) so the node is always recorded as failed.
        error = str(e) or repr(e)
        raise
    finally:
        self.node_finished(
            component_id, component_name, component_type,
            inputs=inputs, outputs=None, error=error
        )
+""" + +import json +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any, Optional, Generator + +from agent.trace.trace_models import ( + TraceEventType, + TraceLevel, + TraceEvent, + TraceSession, + LLMCallTrace, + RetrievalTrace, + ToolCallTrace, +) + + +class TraceFormatter(ABC): + """Abstract base class for trace formatters.""" + + @abstractmethod + def format_event(self, event: TraceEvent) -> dict[str, Any]: + """Format a single trace event.""" + pass + + @abstractmethod + def format_session(self, session: TraceSession) -> dict[str, Any]: + """Format a complete trace session.""" + pass + + @abstractmethod + def format_for_stream(self, event: TraceEvent) -> str: + """Format an event for SSE streaming.""" + pass + + +class StreamingTraceFormatter(TraceFormatter): + """Formatter optimized for real-time SSE streaming.""" + + def __init__(self, include_inputs: bool = True, include_outputs: bool = True): + """Initialize the streaming formatter.""" + self.include_inputs = include_inputs + self.include_outputs = include_outputs + + def format_event(self, event: TraceEvent) -> dict[str, Any]: + """Format a trace event for streaming.""" + result = { + "event_id": event.event_id, + "event_type": event.event_type.value, + "timestamp": event.timestamp.isoformat(), + } + + if event.component_id: + result["component_id"] = event.component_id + if event.component_name: + result["component_name"] = event.component_name + if event.component_type: + result["component_type"] = event.component_type + + if self.include_inputs and event.inputs: + result["inputs"] = self._truncate_dict(event.inputs, max_length=200) + + if self.include_outputs and event.outputs: + result["outputs"] = self._truncate_dict(event.outputs, max_length=500) + + if event.error: + result["error"] = event.error + if event.elapsed_time is not None: + result["elapsed_time"] = round(event.elapsed_time, 3) + if event.thoughts: + result["thoughts"] = event.thoughts[:300] if 
def format_session(self, session: TraceSession) -> dict[str, Any]:
    """Build the streaming-response dictionary for a whole session.

    Timestamps are rendered with ``isoformat()``; ``completed_at`` is
    ``None`` while the session has no completion time. The total elapsed
    time is rounded to three decimals and every event is passed through
    ``self.format_event``.

    Args:
        session: The trace session to serialise.

    Returns:
        A dict with the keys ``session_id``, ``status``, ``started_at``,
        ``completed_at``, ``total_elapsed_time``, ``summary`` and ``events``.
    """
    payload: dict[str, Any] = {
        "session_id": session.session_id,
        "status": session.status,
        "started_at": session.started_at.isoformat(),
    }

    finished = session.completed_at
    payload["completed_at"] = finished.isoformat() if finished else None

    payload["total_elapsed_time"] = round(session.total_elapsed_time, 3)
    payload["summary"] = session.get_summary()
    payload["events"] = list(map(self.format_event, session.events))
    return payload
def format_event(self, event: TraceEvent) -> dict[str, Any]:
    """Reduce an event to the few fields shown in compact output.

    ``type``, ``icon`` and ``time`` (``HH:MM:SS.mmm``) are always present.
    ``component`` is added for a truthy component name, ``duration_ms``
    (one decimal) whenever an elapsed time is set, and ``error`` (at most
    100 characters) for a truthy error.

    Args:
        event: The trace event to format.

    Returns:
        The compact dictionary for ``event``.
    """
    kind = event.event_type
    stamp = event.timestamp.strftime("%H:%M:%S.%f")

    compact = {
        "type": kind.value,
        "icon": self._event_icons.get(kind, "•"),
        # %f yields microseconds; dropping three digits leaves milliseconds.
        "time": stamp[:-3],
    }

    name = event.component_name
    if name:
        compact["component"] = name

    seconds = event.elapsed_time
    if seconds is not None:
        compact["duration_ms"] = round(seconds * 1000, 1)

    message = event.error
    if message:
        compact["error"] = message[:100]

    return compact
def format_timeline(self, session: TraceSession) -> list[str]:
    """Render every event of ``session`` as one line of plain text.

    Each line has the layout
    ``HH:MM:SS <icon> <event type>[ [component]][ (N ms)][ ERROR: text]``.
    The duration is shown in whole milliseconds whenever the event carries
    an elapsed time, including ``0.0``; this matches ``format_event``,
    which also tests ``elapsed_time is not None``. Error text is cut to 50
    characters.

    Args:
        session: The trace session whose events are listed.

    Returns:
        One string per event, in the order of ``session.events``.
    """
    lines = []
    for event in session.events:
        icon = self._event_icons.get(event.event_type, "•")
        time_str = event.timestamp.strftime("%H:%M:%S")

        line = f"{time_str} {icon} {event.event_type.value}"
        if event.component_name:
            line += f" [{event.component_name}]"
        # Explicit None check: a truthiness test would hide a measured
        # duration of exactly 0.0 seconds.
        if event.elapsed_time is not None:
            line += f" ({event.elapsed_time * 1000:.0f}ms)"
        if event.error:
            line += f" ERROR: {event.error[:50]}"

        lines.append(line)

    return lines
def format_for_stream(self, event: TraceEvent) -> str:
    """Serialise ``event`` as one SSE ``data:`` frame with full details.

    The frame wraps the output of ``self.format_event`` as
    ``{"event": "trace_detailed", "data": ...}`` and ends with a blank
    line.

    ``format_event`` copies ``inputs``, ``outputs`` and ``metadata``
    through unchanged, and those dictionaries may hold arbitrary objects.
    Values that JSON cannot encode are therefore written as ``str(value)``
    instead of raising ``TypeError`` and aborting the stream.

    Args:
        event: The trace event to serialise.

    Returns:
        The SSE frame as a string.
    """
    data = self.format_event(event)
    payload = json.dumps(
        {"event": "trace_detailed", "data": data},
        ensure_ascii=False,
        # Deliberate fallback for diagnostic payloads: stringify anything
        # that is not natively JSON-serialisable.
        default=str,
    )
    return f"data:{payload}\n\n"
def _format_tool_call(self, tool: ToolCallTrace) -> dict[str, Any]:
    """Format one tool call trace for the detailed session output.

    With ``self.include_raw_data`` the raw result is returned under
    ``result``. Otherwise ``result_preview`` holds ``str(result)`` cut to
    200 characters plus ``"..."``. Only a result of ``None`` is previewed as
    ``None``; falsy results such as ``0``, ``False``, ``""`` or ``[]`` are
    real tool outputs and are shown as their string form.

    Args:
        tool: The tool call trace to format.

    Returns:
        A dict with the call identity, arguments, latency rounded to two
        decimals, ISO timestamps, the result or its preview, and ``error``
        when the trace has a truthy error.
    """
    result = {
        "call_id": tool.call_id,
        "tool_name": tool.tool_name,
        "tool_type": tool.tool_type,
        "arguments": tool.arguments,
        "latency_ms": round(tool.latency_ms, 2),
        "started_at": tool.started_at.isoformat(),
        "completed_at": tool.completed_at.isoformat() if tool.completed_at else None,
    }

    if self.include_raw_data:
        result["result"] = tool.result
    else:
        # Distinguish "no result" (None) from a falsy result value.
        result_str = None if tool.result is None else str(tool.result)
        if result_str is not None and len(result_str) > 200:
            result_str = result_str[:200] + "..."
        result["result_preview"] = result_str

    if tool.error:
        result["error"] = tool.error

    return result
class TraceLevel(Enum):
    """Trace verbosity levels.

    The members are declared in the same order that
    ``TraceCollector._should_trace`` uses to compare levels:
    MINIMAL < STANDARD < DETAILED < DEBUG. An event that requires a given
    level is recorded only when the collector's own level is at or above it.
    """
    # Lowest level in the collector's ordering.
    MINIMAL = "minimal"
    # Default level of TraceMetadata and of create_trace_collector().
    STANDARD = "standard"
    # Required by the collector for LLM-call and thinking events.
    DETAILED = "detailed"
    # Highest level in the collector's ordering.
    DEBUG = "debug"
session_id: str + user_id: str + tenant_id: str + trace_level: TraceLevel = TraceLevel.STANDARD + created_at: datetime = field(default_factory=datetime.utcnow) + tags: list[str] = field(default_factory=list) + custom_data: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + """Convert metadata to dictionary representation.""" + return { + "agent_id": self.agent_id, + "session_id": self.session_id, + "user_id": self.user_id, + "tenant_id": self.tenant_id, + "trace_level": self.trace_level.value, + "created_at": self.created_at.isoformat(), + "tags": self.tags, + "custom_data": self.custom_data + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "TraceMetadata": + """Create TraceMetadata from dictionary.""" + return cls( + agent_id=data.get("agent_id", ""), + session_id=data.get("session_id", ""), + user_id=data.get("user_id", ""), + tenant_id=data.get("tenant_id", ""), + trace_level=TraceLevel(data.get("trace_level", "standard")), + created_at=datetime.fromisoformat(data["created_at"]) if "created_at" in data else datetime.utcnow(), + tags=data.get("tags", []), + custom_data=data.get("custom_data", {}) + ) + + +@dataclass +class ComponentInfo: + """Information about a component in the agent workflow.""" + component_id: str + component_name: str + component_type: str + params: dict[str, Any] = field(default_factory=dict) + upstream: list[str] = field(default_factory=list) + downstream: list[str] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: + """Convert component info to dictionary.""" + return { + "component_id": self.component_id, + "component_name": self.component_name, + "component_type": self.component_type, + "params": self.params, + "upstream": self.upstream, + "downstream": self.downstream + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ComponentInfo": + """Create ComponentInfo from dictionary.""" + return cls( + component_id=data.get("component_id", ""), + 
component_name=data.get("component_name", ""), + component_type=data.get("component_type", ""), + params=data.get("params", {}), + upstream=data.get("upstream", []), + downstream=data.get("downstream", []) + ) + + +@dataclass +class TraceEvent: + """Represents a single trace event during agent execution.""" + event_id: str + event_type: TraceEventType + timestamp: datetime + component_id: Optional[str] = None + component_name: Optional[str] = None + component_type: Optional[str] = None + inputs: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + error: Optional[str] = None + elapsed_time: Optional[float] = None + thoughts: Optional[str] = None + metadata: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + """Initialize event_id if not provided.""" + if not self.event_id: + self.event_id = str(uuid.uuid4()) + + def to_dict(self) -> dict[str, Any]: + """Convert trace event to dictionary representation.""" + result = { + "event_id": self.event_id, + "event_type": self.event_type.value, + "timestamp": self.timestamp.isoformat(), + } + if self.component_id: + result["component_id"] = self.component_id + if self.component_name: + result["component_name"] = self.component_name + if self.component_type: + result["component_type"] = self.component_type + if self.inputs is not None: + result["inputs"] = self.inputs + if self.outputs is not None: + result["outputs"] = self.outputs + if self.error: + result["error"] = self.error + if self.elapsed_time is not None: + result["elapsed_time"] = self.elapsed_time + if self.thoughts: + result["thoughts"] = self.thoughts + if self.metadata: + result["metadata"] = self.metadata + return result + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "TraceEvent": + """Create TraceEvent from dictionary.""" + return cls( + event_id=data.get("event_id", ""), + event_type=TraceEventType(data.get("event_type", "node_started")), + timestamp=datetime.fromisoformat(data["timestamp"]) if 
"timestamp" in data else datetime.utcnow(), + component_id=data.get("component_id"), + component_name=data.get("component_name"), + component_type=data.get("component_type"), + inputs=data.get("inputs"), + outputs=data.get("outputs"), + error=data.get("error"), + elapsed_time=data.get("elapsed_time"), + thoughts=data.get("thoughts"), + metadata=data.get("metadata", {}) + ) + + def to_json(self) -> str: + """Convert trace event to JSON string.""" + return json.dumps(self.to_dict(), ensure_ascii=False) + + +@dataclass +class LLMCallTrace: + """Trace information for LLM API calls.""" + call_id: str + model_name: str + prompt: str + response: Optional[str] = None + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + latency_ms: float = 0.0 + temperature: float = 0.0 + max_tokens: Optional[int] = None + error: Optional[str] = None + started_at: datetime = field(default_factory=datetime.utcnow) + completed_at: Optional[datetime] = None + + def to_dict(self) -> dict[str, Any]: + """Convert LLM call trace to dictionary.""" + return { + "call_id": self.call_id, + "model_name": self.model_name, + "prompt": self.prompt[:500] + "..." if len(self.prompt) > 500 else self.prompt, + "response": self.response[:500] + "..." 
if self.response and len(self.response) > 500 else self.response, + "prompt_tokens": self.prompt_tokens, + "completion_tokens": self.completion_tokens, + "total_tokens": self.total_tokens, + "latency_ms": self.latency_ms, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "error": self.error, + "started_at": self.started_at.isoformat(), + "completed_at": self.completed_at.isoformat() if self.completed_at else None + } + + +@dataclass +class RetrievalTrace: + """Trace information for retrieval operations.""" + retrieval_id: str + query: str + knowledge_bases: list[str] = field(default_factory=list) + top_k: int = 10 + similarity_threshold: float = 0.0 + chunks_retrieved: int = 0 + chunks: list[dict[str, Any]] = field(default_factory=list) + latency_ms: float = 0.0 + rerank_enabled: bool = False + error: Optional[str] = None + started_at: datetime = field(default_factory=datetime.utcnow) + completed_at: Optional[datetime] = None + + def to_dict(self) -> dict[str, Any]: + """Convert retrieval trace to dictionary.""" + return { + "retrieval_id": self.retrieval_id, + "query": self.query, + "knowledge_bases": self.knowledge_bases, + "top_k": self.top_k, + "similarity_threshold": self.similarity_threshold, + "chunks_retrieved": self.chunks_retrieved, + "chunks": self.chunks[:5], + "latency_ms": self.latency_ms, + "rerank_enabled": self.rerank_enabled, + "error": self.error, + "started_at": self.started_at.isoformat(), + "completed_at": self.completed_at.isoformat() if self.completed_at else None + } + + +@dataclass +class ToolCallTrace: + """Trace information for tool/function calls.""" + call_id: str + tool_name: str + tool_type: str + arguments: dict[str, Any] = field(default_factory=dict) + result: Optional[Any] = None + latency_ms: float = 0.0 + error: Optional[str] = None + started_at: datetime = field(default_factory=datetime.utcnow) + completed_at: Optional[datetime] = None + + def to_dict(self) -> dict[str, Any]: + """Convert tool call trace to 
dictionary.""" + result_str = str(self.result)[:500] if self.result else None + return { + "call_id": self.call_id, + "tool_name": self.tool_name, + "tool_type": self.tool_type, + "arguments": self.arguments, + "result": result_str, + "latency_ms": self.latency_ms, + "error": self.error, + "started_at": self.started_at.isoformat(), + "completed_at": self.completed_at.isoformat() if self.completed_at else None + } + + +@dataclass +class TraceSession: + """Complete trace session for an agent execution.""" + session_id: str + metadata: TraceMetadata + events: list[TraceEvent] = field(default_factory=list) + llm_calls: list[LLMCallTrace] = field(default_factory=list) + retrievals: list[RetrievalTrace] = field(default_factory=list) + tool_calls: list[ToolCallTrace] = field(default_factory=list) + started_at: datetime = field(default_factory=datetime.utcnow) + completed_at: Optional[datetime] = None + total_elapsed_time: float = 0.0 + status: str = "running" + error: Optional[str] = None + + def add_event(self, event: TraceEvent) -> None: + """Add a trace event to the session.""" + self.events.append(event) + + def add_llm_call(self, llm_call: LLMCallTrace) -> None: + """Add an LLM call trace to the session.""" + self.llm_calls.append(llm_call) + + def add_retrieval(self, retrieval: RetrievalTrace) -> None: + """Add a retrieval trace to the session.""" + self.retrievals.append(retrieval) + + def add_tool_call(self, tool_call: ToolCallTrace) -> None: + """Add a tool call trace to the session.""" + self.tool_calls.append(tool_call) + + def complete(self, error: Optional[str] = None) -> None: + """Mark the session as completed.""" + self.completed_at = datetime.utcnow() + self.total_elapsed_time = (self.completed_at - self.started_at).total_seconds() + self.status = "failed" if error else "completed" + self.error = error + + def to_dict(self) -> dict[str, Any]: + """Convert trace session to dictionary representation.""" + return { + "session_id": self.session_id, + 
"metadata": self.metadata.to_dict(), + "events": [e.to_dict() for e in self.events], + "llm_calls": [c.to_dict() for c in self.llm_calls], + "retrievals": [r.to_dict() for r in self.retrievals], + "tool_calls": [t.to_dict() for t in self.tool_calls], + "started_at": self.started_at.isoformat(), + "completed_at": self.completed_at.isoformat() if self.completed_at else None, + "total_elapsed_time": self.total_elapsed_time, + "status": self.status, + "error": self.error, + "summary": self.get_summary() + } + + def get_summary(self) -> dict[str, Any]: + """Get a summary of the trace session.""" + return { + "total_events": len(self.events), + "total_llm_calls": len(self.llm_calls), + "total_retrievals": len(self.retrievals), + "total_tool_calls": len(self.tool_calls), + "total_tokens": sum(c.total_tokens for c in self.llm_calls), + "total_chunks_retrieved": sum(r.chunks_retrieved for r in self.retrievals), + "nodes_executed": len(set(e.component_id for e in self.events if e.component_id)), + "errors_count": len([e for e in self.events if e.error]) + } + + def to_json(self) -> str: + """Convert trace session to JSON string.""" + return json.dumps(self.to_dict(), ensure_ascii=False, indent=2) diff --git a/api/apps/trace_app.py b/api/apps/trace_app.py new file mode 100644 index 000000000..6e2b13ce1 --- /dev/null +++ b/api/apps/trace_app.py @@ -0,0 +1,477 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""
Agent Trace API Endpoints

Provides REST API endpoints for accessing agent execution traces,
including trace retrieval, filtering, analysis, and management.
This addresses Issue #10081: Add Trace Logging for Agent Completions API.
"""

from datetime import datetime
from quart import request, Response
import json

from api.apps import login_required, current_user
from api.db.services.trace_service import TraceService
from api.utils.api_utils import (
    get_data_error_result,
    get_json_result,
    get_request_json,
    server_error_response,
    validate_request,
)
from common.constants import RetCode


_TRACE_NOT_FOUND = "Trace session not found"


def _int_arg(name, default, minimum):
    """Read query parameter *name* as an int that is at least *minimum*.

    Returns *default* when the parameter is absent.

    Raises:
        ValueError: With a client-facing message when the value is not an
            integer or is below *minimum*.
    """
    raw = request.args.get(name)
    if raw is None:
        return default
    try:
        value = int(raw)
    except (TypeError, ValueError):
        raise ValueError(f"Query parameter '{name}' must be an integer, got {raw!r}") from None
    if value < minimum:
        raise ValueError(f"Query parameter '{name}' must be >= {minimum}, got {value}")
    return value


def _datetime_arg(name):
    """Read query parameter *name* as an ISO-8601 datetime, or None when absent/empty.

    Raises:
        ValueError: With a client-facing message when the value cannot be parsed.
    """
    raw = request.args.get(name)
    if not raw:
        return None
    try:
        return datetime.fromisoformat(raw)
    except ValueError:
        raise ValueError(f"Query parameter '{name}' must be an ISO-8601 datetime, got {raw!r}") from None


def _trace_tenant_id(session_data):
    """Return the tenant recorded for a trace.

    Freshly created records keep ``tenant_id`` at the top level, while records
    serialized from a trace session keep it under ``metadata``.
    """
    tenant_id = session_data.get("tenant_id")
    if tenant_id is None:
        tenant_id = (session_data.get("metadata") or {}).get("tenant_id")
    return tenant_id


def _owned_by_current_user(task_id):
    """Return True when the trace exists and was recorded for the current user.

    This file always records and lists traces with ``tenant_id=current_user.id``,
    so the same identity is used for the ownership check.
    """
    session_data = TraceService.get_trace_session(task_id)
    return bool(session_data) and _trace_tenant_id(session_data) == current_user.id


def _not_found(message=_TRACE_NOT_FOUND):
    """Build the error response used for missing (or foreign) traces."""
    return get_data_error_result(message=message, code=RetCode.DATA_ERROR)


def _consume_chunk(answer, state):
    """Parse one ``data:<json>`` completion chunk and fold it into *state*.

    ``state["content"]`` accumulates message text and ``state["reference"]``
    accumulates references.

    Returns:
        The parsed payload, or None when the chunk is malformed (such chunks
        are skipped by the callers, as before).
    """
    try:
        # Skip the 5-character "data:" prefix of the SSE line.
        ans = json.loads(answer[5:])
        if ans["event"] == "message":
            state["content"] += ans["data"]["content"]
        if ans.get("data", {}).get("reference"):
            state["reference"].update(ans["data"]["reference"])
    except (ValueError, KeyError, TypeError, AttributeError, IndexError):
        return None
    return ans


def _sse_response(body):
    """Wrap an async generator in a Server-Sent-Events response."""
    resp = Response(body, mimetype="text/event-stream")
    resp.headers.add_header("Cache-control", "no-cache")
    resp.headers.add_header("Connection", "keep-alive")
    resp.headers.add_header("X-Accel-Buffering", "no")
    resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
    return resp


@manager.route('/traces', methods=['GET'])  # noqa: F821
@login_required
async def list_traces():
    """
    List trace sessions for the current tenant.

    Query parameters:
        - agent_id: Filter by agent ID (optional)
        - user_id: Filter by user ID (optional)
        - status: Filter by status (running, completed, failed) (optional)
        - start_time: Filter by start time ISO format (optional)
        - end_time: Filter by end time ISO format (optional)
        - page: Page number, >= 1 (default: 1)
        - page_size: Items per page, >= 1 (default: 20)

    Returns:
        Paginated list of trace sessions, or a data error for malformed
        query parameters.
    """
    try:
        page = _int_arg("page", 1, minimum=1)
        page_size = _int_arg("page_size", 20, minimum=1)
        start_time = _datetime_arg("start_time")
        end_time = _datetime_arg("end_time")
    except ValueError as e:
        return get_data_error_result(message=str(e))

    try:
        result = TraceService.list_traces(
            tenant_id=current_user.id,
            agent_id=request.args.get("agent_id"),
            user_id=request.args.get("user_id"),
            status=request.args.get("status"),
            start_time=start_time,
            end_time=end_time,
            page=page,
            page_size=page_size,
        )
        return get_json_result(data=result)
    except Exception as e:
        return server_error_response(e)


@manager.route('/traces/<task_id>', methods=['GET'])  # noqa: F821
@login_required
async def get_trace(task_id):
    """
    Get a specific trace session by task ID.

    Path parameters:
        - task_id: The task/trace session ID

    Query parameters:
        - format: Output format (streaming, compact, detailed) (default: streaming)

    Returns:
        Trace session data; "not found" when the trace is missing or belongs
        to another tenant.
    """
    try:
        if not _owned_by_current_user(task_id):
            return _not_found()

        format_type = request.args.get("format", "streaming")
        result = TraceService.format_trace(task_id, format_type)
        if not result:
            return _not_found()

        return get_json_result(data=result)
    except Exception as e:
        return server_error_response(e)


@manager.route('/traces/<task_id>/events', methods=['GET'])  # noqa: F821
@login_required
async def get_trace_events(task_id):
    """
    Get trace events for a specific session.

    Path parameters:
        - task_id: The task/trace session ID

    Query parameters:
        - event_types: Comma-separated list of event types to filter (optional)
        - component_id: Filter by component ID (optional)
        - limit: Maximum number of events, >= 0 (default: 100)
        - offset: Number of events to skip, >= 0 (default: 0)

    Returns:
        List of trace events and their count.
    """
    try:
        limit = _int_arg("limit", 100, minimum=0)
        offset = _int_arg("offset", 0, minimum=0)
    except ValueError as e:
        return get_data_error_result(message=str(e))

    try:
        if not _owned_by_current_user(task_id):
            return _not_found()

        event_types_str = request.args.get("event_types")
        events = TraceService.get_trace_events(
            task_id=task_id,
            event_types=event_types_str.split(",") if event_types_str else None,
            component_id=request.args.get("component_id"),
            limit=limit,
            offset=offset,
        )
        return get_json_result(data={"events": events, "count": len(events)})
    except Exception as e:
        return server_error_response(e)


@manager.route('/traces/<task_id>/summary', methods=['GET'])  # noqa: F821
@login_required
async def get_trace_summary(task_id):
    """
    Get a summary of a trace session.

    Path parameters:
        - task_id: The task/trace session ID

    Returns:
        Trace session summary
    """
    try:
        if not _owned_by_current_user(task_id):
            return _not_found()

        summary = TraceService.get_trace_summary(task_id)
        if not summary:
            return _not_found()

        return get_json_result(data=summary)
    except Exception as e:
        return server_error_response(e)


@manager.route('/traces/<task_id>/analysis', methods=['GET'])  # noqa: F821
@login_required
async def analyze_trace(task_id):
    """
    Analyze a trace session and get insights.

    Path parameters:
        - task_id: The task/trace session ID

    Returns:
        Analysis results including bottlenecks, errors, and recommendations
    """
    try:
        if not _owned_by_current_user(task_id):
            return _not_found("Trace session not found or analysis failed")

        analysis = TraceService.analyze_trace(task_id)
        if not analysis:
            return _not_found("Trace session not found or analysis failed")

        return get_json_result(data=analysis)
    except Exception as e:
        return server_error_response(e)


@manager.route('/traces/<task_id>', methods=['DELETE'])  # noqa: F821
@login_required
async def delete_trace(task_id):
    """
    Delete a trace session.

    Path parameters:
        - task_id: The task/trace session ID

    Returns:
        Success status. Deleting a trace that no longer exists still succeeds;
        deleting another tenant's trace is reported as "not found".
    """
    try:
        session_data = TraceService.get_trace_session(task_id)
        if session_data and _trace_tenant_id(session_data) != current_user.id:
            return _not_found()

        success, message = TraceService.delete_trace(task_id)
        if not success:
            return get_data_error_result(message=message)

        return get_json_result(data={"task_id": task_id, "message": message})
    except Exception as e:
        return server_error_response(e)


@manager.route('/traces/cleanup', methods=['POST'])  # noqa: F821
@login_required
async def cleanup_traces():
    """
    Clean up old trace sessions.

    Request body:
        {
            "days": 7  // Number of days to keep traces (default: 7)
        }

    Returns:
        Number of deleted traces
    """
    # NOTE(review): TraceService.cleanup_old_traces takes no tenant argument, so
    # this call is not scoped to the caller's tenant -- confirm whether the
    # endpoint should be restricted to administrators.
    try:
        req = await get_request_json()
        days = req.get("days", 7)
        # bool is an int subclass; reject it explicitly together with non-numbers.
        if isinstance(days, bool) or not isinstance(days, (int, float)) or days <= 0:
            return get_data_error_result(message="'days' must be a positive number")

        deleted, message = TraceService.cleanup_old_traces(days)

        return get_json_result(data={"deleted": deleted, "message": message})
    except Exception as e:
        return server_error_response(e)


@manager.route('/traces/<task_id>/stream', methods=['GET'])  # noqa: F821
@login_required
async def stream_trace(task_id):
    """
    Stream trace events in real-time using Server-Sent Events.

    Already recorded events are replayed first, then live events are forwarded
    until the collector becomes inactive; the stream ends with ``data:[DONE]``.

    Path parameters:
        - task_id: The task/trace session ID

    Query parameters:
        - format: Output format (streaming, compact, detailed) (default: streaming)

    Returns:
        SSE stream of trace events
    """
    try:
        format_type = request.args.get("format", "streaming")

        from agent.trace.trace_collector import get_trace_collector
        from agent.trace.trace_formatter import TraceFormatterFactory

        collector = get_trace_collector(task_id)

        if not collector or not _owned_by_current_user(task_id):
            return _not_found("Active trace session not found")

        formatter = TraceFormatterFactory.create(format_type)

        async def generate():
            import asyncio
            from collections import deque

            pending = deque()
            on_event = pending.append

            # Subscribe BEFORE taking the history snapshot so that no event
            # emitted in between is lost; duplicates are filtered by event_id.
            collector.subscribe(on_event)
            try:
                replayed = set()
                for event in collector.get_events():
                    replayed.add(event.event_id)
                    yield formatter.format_for_stream(event)

                while True:
                    # Read the flag first: once it is False, everything emitted
                    # before deactivation is already queued and gets flushed
                    # by the drain below before we stop.
                    active = collector._is_active
                    while pending:
                        event = pending.popleft()
                        if event.event_id not in replayed:
                            yield formatter.format_for_stream(event)
                    if not active:
                        break
                    await asyncio.sleep(0.1)
            finally:
                collector.unsubscribe(on_event)

            yield "data:[DONE]\n\n"

        return _sse_response(generate())
    except Exception as e:
        return server_error_response(e)


@manager.route('/agents/<agent_id>/traces', methods=['GET'])  # noqa: F821
@login_required
async def list_agent_traces(agent_id):
    """
    List trace sessions for a specific agent.

    Path parameters:
        - agent_id: The agent ID

    Query parameters:
        - status: Filter by status (optional)
        - page: Page number, >= 1 (default: 1)
        - page_size: Items per page, >= 1 (default: 20)

    Returns:
        Paginated list of trace sessions for the agent
    """
    try:
        page = _int_arg("page", 1, minimum=1)
        page_size = _int_arg("page_size", 20, minimum=1)
    except ValueError as e:
        return get_data_error_result(message=str(e))

    try:
        result = TraceService.list_traces(
            tenant_id=current_user.id,
            agent_id=agent_id,
            status=request.args.get("status"),
            page=page,
            page_size=page_size,
        )
        return get_json_result(data=result)
    except Exception as e:
        return server_error_response(e)


@manager.route('/agents/<agent_id>/completions/trace', methods=['POST'])  # noqa: F821
@login_required
@validate_request("question")
async def agent_completion_with_trace(agent_id):
    """
    Execute agent completion with trace logging enabled.

    This endpoint is similar to /agents/<agent_id>/completions but includes
    trace information in the response, addressing Issue #10081.

    Path parameters:
        - agent_id: The agent ID

    Request body:
        {
            "question": "User question",
            "session_id": "Optional session ID",
            "stream": true,
            "trace_level": "standard",  // minimal, standard, detailed, debug
            "include_trace": true
        }

    Returns:
        Agent response with trace information. When ``stream`` is true the
        trace is sent as a final ``trace`` event before ``data:[DONE]``;
        otherwise it is attached to the last answer under ``trace``.
    """
    try:
        from api.db.services.canvas_service import completion as agent_completion

        req = await get_request_json()
        stream = req.get("stream", True)
        trace_level = req.get("trace_level", "standard")
        include_trace = req.get("include_trace", True)

        from common.misc_utils import get_uuid
        task_id = get_uuid()

        if include_trace:
            created, _ = TraceService.create_trace_session(
                task_id=task_id,
                agent_id=agent_id,
                session_id=req.get("session_id", ""),
                user_id=current_user.id,
                tenant_id=current_user.id,
                trace_level=trace_level,
            )
            # Without a session there is nothing to save or format later on.
            include_trace = bool(created)

        def run_completion():
            return agent_completion(
                tenant_id=current_user.id,
                agent_id=agent_id,
                **req
            )

        if stream:
            async def generate():
                state = {"content": "", "reference": {}}

                async for answer in run_completion():
                    # Malformed chunks are dropped rather than forwarded.
                    if _consume_chunk(answer, state) is not None:
                        yield answer

                if include_trace:
                    TraceService.save_trace_session(task_id)
                    trace_data = TraceService.format_trace(task_id, "compact")
                    yield f"data:{json.dumps({'event': 'trace', 'data': trace_data}, ensure_ascii=False)}\n\n"

                yield "data:[DONE]\n\n"

            return _sse_response(generate())

        state = {"content": "", "reference": {}}
        final_ans = None

        async for answer in run_completion():
            ans = _consume_chunk(answer, state)
            if ans is not None:
                final_ans = ans

        if final_ans:
            final_ans["data"]["content"] = state["content"]
            final_ans["data"]["reference"] = state["reference"]

        if include_trace:
            TraceService.save_trace_session(task_id)
            trace_data = TraceService.format_trace(task_id, "compact")
            if final_ans:
                final_ans["trace"] = trace_data

        return get_json_result(data=final_ans)
    except Exception as e:
        return server_error_response(e)
import json
import logging
import time
from datetime import datetime, timedelta, timezone
from typing import Any, Optional, Tuple
from collections import defaultdict

from agent.trace.trace_models import (
    TraceEventType,
    TraceLevel,
    TraceSession,
    TraceEvent,
    TraceMetadata,
)
from agent.trace.trace_collector import (
    TraceCollector,
    get_trace_collector,
    create_trace_collector,
)
from agent.trace.trace_formatter import (
    TraceFormatterFactory,
    format_trace_for_api,
)
from rag.utils.redis_conn import REDIS_CONN


TRACE_KEY_PREFIX = "agent_trace:"
TRACE_SESSION_TTL = 86400 * 7  # seconds: session records live for 7 days
TRACE_EVENT_TTL = 86400 * 3  # seconds: event lists live for 3 days


class TraceService:
    """
    Service for managing agent execution traces.

    Provides methods for creating, storing, retrieving, and analyzing
    trace data from agent executions. Live traces are read from the in-memory
    collector; finished traces are read from Redis under
    ``agent_trace:session:<task_id>`` and ``agent_trace:events:<task_id>``.
    """

    # Fields that list_traces/cleanup_old_traces read from a stored record.
    _INDEX_FIELDS = ("agent_id", "session_id", "user_id", "tenant_id", "trace_level", "created_at")

    @staticmethod
    def _session_key(task_id: str) -> str:
        """Return the Redis key of the session record for *task_id*."""
        return f"{TRACE_KEY_PREFIX}session:{task_id}"

    @staticmethod
    def _events_key(task_id: str) -> str:
        """Return the Redis key of the event list for *task_id*."""
        return f"{TRACE_KEY_PREFIX}events:{task_id}"

    @staticmethod
    def _record_field(record: dict[str, Any], name: str) -> Any:
        """Read *name* from a stored record, falling back to its ``metadata``.

        Records written at creation time keep these fields at the top level;
        records serialized from a trace session keep them under ``metadata``.
        """
        value = record.get(name)
        if value is None:
            meta = record.get("metadata")
            if isinstance(meta, dict):
                value = meta.get(name)
        return value

    @staticmethod
    def _to_naive_utc(moment: Optional[datetime]) -> Optional[datetime]:
        """Convert an aware datetime to naive UTC; naive values pass through.

        Stored timestamps come from ``datetime.utcnow()`` (naive), so filters
        must be brought to the same form before they can be compared.
        """
        if moment is not None and moment.tzinfo is not None:
            return moment.astimezone(timezone.utc).replace(tzinfo=None)
        return moment

    @staticmethod
    def _created_at(record: dict[str, Any]) -> Optional[datetime]:
        """Return the record's creation time as naive UTC, or None if unusable."""
        raw = TraceService._record_field(record, "created_at") or record.get("started_at")
        if not isinstance(raw, str):
            return None
        try:
            return TraceService._to_naive_utc(datetime.fromisoformat(raw))
        except ValueError:
            return None

    @staticmethod
    def _load_records() -> list[Tuple[str, dict[str, Any]]]:
        """Return ``(task_id, record)`` for every readable stored session.

        Records that are missing, not valid JSON, or not JSON objects are
        skipped so that one corrupt entry cannot hide all the others.
        """
        # NOTE(review): the Redis KEYS command scans the whole keyspace; consider
        # an incremental SCAN if the REDIS_CONN wrapper exposes one.
        prefix = TraceService._session_key("")
        records = []
        for key in REDIS_CONN.keys(f"{prefix}*"):
            key_str = key.decode("utf-8") if isinstance(key, bytes) else key
            data = REDIS_CONN.get(key)
            if not data:
                continue
            try:
                record = json.loads(data)
            except ValueError:
                logging.warning(f"Skipping unreadable trace record: {key_str}")
                continue
            if isinstance(record, dict):
                records.append((key_str.removeprefix(prefix), record))
        return records

    @staticmethod
    def create_trace_session(
        task_id: str,
        agent_id: str,
        session_id: str,
        user_id: str,
        tenant_id: str,
        trace_level: str = "standard",
    ) -> Tuple[bool, str]:
        """
        Create a new trace session for an agent execution.

        Args:
            task_id: Unique identifier for the task
            agent_id: ID of the agent being executed
            session_id: ID of the conversation session
            user_id: ID of the user
            tenant_id: ID of the tenant
            trace_level: Verbosity level (minimal, standard, detailed, debug);
                unknown values fall back to "standard"

        Returns:
            Tuple of (success, trace_id or error message)
        """
        try:
            level = TraceLevel(trace_level)
        except ValueError:
            level = TraceLevel.STANDARD

        try:
            create_trace_collector(
                task_id=task_id,
                agent_id=agent_id,
                session_id=session_id,
                user_id=user_id,
                tenant_id=tenant_id,
                trace_level=level,
            )

            session_data = {
                "task_id": task_id,
                "agent_id": agent_id,
                "session_id": session_id,
                "user_id": user_id,
                "tenant_id": tenant_id,
                # Store the level actually in effect, not the raw request value.
                "trace_level": level.value,
                "created_at": datetime.utcnow().isoformat(),
                "status": "running",
            }

            REDIS_CONN.setex(TraceService._session_key(task_id), TRACE_SESSION_TTL, json.dumps(session_data))

            return True, task_id
        except Exception as e:
            logging.exception(f"Failed to create trace session: {e}")
            return False, str(e)

    @staticmethod
    def get_trace_collector(task_id: str) -> Optional[TraceCollector]:
        """Get an active trace collector by task ID."""
        return get_trace_collector(task_id)

    @staticmethod
    def save_trace_session(task_id: str) -> Tuple[bool, str]:
        """
        Save the current trace session to persistent storage.

        The serialized session replaces the record written at creation time.
        The fields used for listing and cleanup (tenant, agent, user, creation
        time, ...) are copied to the top level of the record so that saved
        traces remain visible to list_traces and cleanup_old_traces.

        Args:
            task_id: ID of the task/trace session

        Returns:
            Tuple of (success, message)
        """
        try:
            collector = get_trace_collector(task_id)
            if not collector:
                return False, "Trace collector not found"

            session = collector.get_session()
            session_dict = session.to_dict()

            meta = session_dict.get("metadata") or {}
            session_dict.setdefault("task_id", task_id)
            for name in TraceService._INDEX_FIELDS:
                if name in meta:
                    session_dict.setdefault(name, meta[name])

            REDIS_CONN.setex(TraceService._session_key(task_id), TRACE_SESSION_TTL, json.dumps(session_dict))

            events_data = json.dumps([e.to_dict() for e in session.events])
            REDIS_CONN.setex(TraceService._events_key(task_id), TRACE_EVENT_TTL, events_data)

            return True, "Trace session saved successfully"
        except Exception as e:
            logging.exception(f"Failed to save trace session: {e}")
            return False, str(e)

    @staticmethod
    def get_trace_session(task_id: str) -> Optional[dict[str, Any]]:
        """
        Retrieve a trace session by task ID.

        An active collector takes precedence over the stored record.

        Args:
            task_id: ID of the task/trace session

        Returns:
            Trace session data or None if not found (or on error)
        """
        try:
            collector = get_trace_collector(task_id)
            if collector:
                return collector.get_session().to_dict()

            data = REDIS_CONN.get(TraceService._session_key(task_id))
            if data:
                return json.loads(data)

            return None
        except Exception as e:
            logging.exception(f"Failed to get trace session: {e}")
            return None

    @staticmethod
    def get_trace_events(
        task_id: str,
        event_types: Optional[list[str]] = None,
        component_id: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """
        Retrieve trace events with optional filtering.

        Filters are applied before ``offset``/``limit``.

        Args:
            task_id: ID of the task/trace session
            event_types: Filter by event types
            component_id: Filter by component ID
            limit: Maximum number of events to return
            offset: Number of events to skip

        Returns:
            List of trace events (empty when not found or on error)
        """
        try:
            collector = get_trace_collector(task_id)
            if collector:
                events = collector.get_events()
            else:
                data = REDIS_CONN.get(TraceService._events_key(task_id))
                if not data:
                    return []
                events = [TraceEvent.from_dict(e) for e in json.loads(data)]

            if event_types:
                type_set = set(event_types)
                events = [e for e in events if e.event_type.value in type_set]

            if component_id:
                events = [e for e in events if e.component_id == component_id]

            events = events[offset:offset + limit]
            return [e.to_dict() for e in events]
        except Exception as e:
            logging.exception(f"Failed to get trace events: {e}")
            return []

    @staticmethod
    def get_trace_summary(task_id: str) -> Optional[dict[str, Any]]:
        """
        Get a summary of the trace session.

        Args:
            task_id: ID of the task/trace session

        Returns:
            Summary data or None if not found
        """
        try:
            collector = get_trace_collector(task_id)
            if collector:
                return collector.get_summary()

            session_data = TraceService.get_trace_session(task_id)
            if session_data and "summary" in session_data:
                return session_data["summary"]

            return None
        except Exception as e:
            logging.exception(f"Failed to get trace summary: {e}")
            return None

    @staticmethod
    def format_trace(
        task_id: str,
        format_type: str = "streaming",
        **kwargs
    ) -> Optional[dict[str, Any]]:
        """
        Format a trace session using the specified formatter.

        Only traces with an active collector are run through the formatter;
        a trace that exists only in storage is returned as stored.

        Args:
            task_id: ID of the task/trace session
            format_type: Type of formatter (streaming, compact, detailed)
            **kwargs: Additional formatter options

        Returns:
            Formatted trace data or None if not found
        """
        try:
            session_data = TraceService.get_trace_session(task_id)
            if not session_data:
                return None

            collector = get_trace_collector(task_id)
            if collector:
                session = collector.get_session()
                return format_trace_for_api(session, format_type, **kwargs)

            return session_data
        except Exception as e:
            logging.exception(f"Failed to format trace: {e}")
            return None

    @staticmethod
    def list_traces(
        tenant_id: str,
        agent_id: Optional[str] = None,
        user_id: Optional[str] = None,
        status: Optional[str] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        page: int = 1,
        page_size: int = 20,
    ) -> dict[str, Any]:
        """
        List trace sessions with filtering and pagination.

        Args:
            tenant_id: ID of the tenant
            agent_id: Filter by agent ID
            user_id: Filter by user ID
            status: Filter by status (running, completed, failed)
            start_time: Keep sessions created at or after this time
            end_time: Keep sessions created at or before this time
            page: Page number (values below 1 are treated as 1)
            page_size: Items per page (values below 1 are treated as 1)

        Returns:
            Paginated list of trace sessions, newest first. Sessions without a
            usable creation time are excluded when a time filter is given.
        """
        try:
            field_of = TraceService._record_field
            sessions = [
                record for _, record in TraceService._load_records()
                if field_of(record, "tenant_id") == tenant_id
            ]

            if agent_id:
                sessions = [s for s in sessions if field_of(s, "agent_id") == agent_id]
            if user_id:
                sessions = [s for s in sessions if field_of(s, "user_id") == user_id]
            if status:
                sessions = [s for s in sessions if s.get("status") == status]

            lower = TraceService._to_naive_utc(start_time)
            upper = TraceService._to_naive_utc(end_time)
            if lower is not None or upper is not None:
                in_window = []
                for record in sessions:
                    created = TraceService._created_at(record)
                    if created is None:
                        continue
                    if lower is not None and created < lower:
                        continue
                    if upper is not None and created > upper:
                        continue
                    in_window.append(record)
                sessions = in_window

            sessions.sort(key=lambda s: str(field_of(s, "created_at") or ""), reverse=True)

            total = len(sessions)
            # Clamp so a page/page_size below 1 cannot produce a negative slice.
            start_idx = (max(page, 1) - 1) * max(page_size, 1)
            paginated = sessions[start_idx:start_idx + max(page_size, 1)]

            return {
                "total": total,
                "page": page,
                "page_size": page_size,
                "sessions": paginated,
            }
        except Exception as e:
            logging.exception(f"Failed to list traces: {e}")
            return {"total": 0, "page": page, "page_size": page_size, "sessions": []}

    @staticmethod
    def delete_trace(task_id: str) -> Tuple[bool, str]:
        """
        Delete the stored session record and event list of a trace.

        Args:
            task_id: ID of the task/trace session

        Returns:
            Tuple of (success, message)
        """
        try:
            REDIS_CONN.delete(TraceService._session_key(task_id))
            REDIS_CONN.delete(TraceService._events_key(task_id))

            return True, "Trace deleted successfully"
        except Exception as e:
            logging.exception(f"Failed to delete trace: {e}")
            return False, str(e)

    @staticmethod
    def analyze_trace(task_id: str) -> Optional[dict[str, Any]]:
        """
        Analyze a trace session and provide insights.

        At most the first 1000 events are considered. A component is reported
        as a bottleneck when its average elapsed time exceeds 5.0.

        Args:
            task_id: ID of the task/trace session

        Returns:
            Analysis results or None if not found
        """
        try:
            session_data = TraceService.get_trace_session(task_id)
            if not session_data:
                return None

            events = TraceService.get_trace_events(task_id, limit=1000)

            event_distribution = defaultdict(int)
            execution_times = defaultdict(list)
            errors = []

            for event in events:
                event_distribution[event.get("event_type", "unknown")] += 1

                if event.get("error"):
                    errors.append({
                        "component": event.get("component_name"),
                        "error": event.get("error"),
                        "timestamp": event.get("timestamp"),
                    })

                if event.get("elapsed_time"):
                    component = event.get("component_name", "unknown")
                    execution_times[component].append(event.get("elapsed_time"))

            bottlenecks = []
            for component, times in execution_times.items():
                avg_time = sum(times) / len(times)
                if avg_time > 5.0:
                    bottlenecks.append({
                        "component": component,
                        "avg_execution_time": round(avg_time, 2),
                        "executions": len(times),
                    })

            recommendations = []
            if bottlenecks:
                recommendations.append("Consider optimizing slow components or adding caching")
            if errors:
                recommendations.append("Review and fix error-prone components")

            return {
                "task_id": task_id,
                "total_events": len(events),
                "event_distribution": dict(event_distribution),
                "component_execution_times": dict(execution_times),
                "bottlenecks": bottlenecks,
                "errors": errors,
                "recommendations": recommendations,
            }
        except Exception as e:
            logging.exception(f"Failed to analyze trace: {e}")
            return None

    @staticmethod
    def cleanup_old_traces(days: int = 7) -> Tuple[int, str]:
        """
        Clean up trace sessions older than specified days.

        Only traces whose deletion actually succeeded are counted. Records
        without a usable creation time are left untouched.

        Args:
            days: Number of days to keep traces

        Returns:
            Tuple of (deleted count, message)
        """
        try:
            cutoff = datetime.utcnow() - timedelta(days=days)

            deleted = 0
            for task_id, record in TraceService._load_records():
                created = TraceService._created_at(record)
                if created is None or created >= cutoff:
                    continue
                success, _ = TraceService.delete_trace(task_id)
                if success:
                    deleted += 1

            return deleted, f"Deleted {deleted} old trace sessions"
        except Exception as e:
            logging.exception(f"Failed to cleanup traces: {e}")
            return 0, str(e)