# LightRAG/service/app/services/history_manager.py

from sqlalchemy.orm import Session
from app.models.models import ChatMessage, ChatSession, MessageCitation
from typing import List, Dict, Optional
import uuid
class HistoryManager:
    """Store and retrieve chat sessions, messages, and citations.

    Every operation goes through the SQLAlchemy session passed to the
    constructor. Write methods commit before returning and roll the
    session back if the flush or commit fails.
    """

    def __init__(self, db: "Session") -> None:
        """Keep a reference to the database session used by all methods."""
        self.db = db

    def _commit(self) -> None:
        """Commit the current transaction; roll back and re-raise on failure."""
        try:
            self.db.commit()
        except Exception:
            # A failed flush/commit leaves the session unusable until it is
            # rolled back, so reset it before surfacing the error.
            self.db.rollback()
            raise

    @staticmethod
    def _fit_token_budget(newest_first: List, max_tokens: int) -> List[Dict]:
        """Select the newest messages that fit in ``max_tokens``.

        Args:
            newest_first: Message objects exposing ``role``, ``content`` and
                ``token_count``, ordered newest to oldest.
            max_tokens: Upper bound on the summed token cost.

        Returns:
            ``{"role", "content"}`` dicts in chronological (oldest first)
            order. Selection stops at the first message that would exceed
            the budget, so the result is always a contiguous tail of the
            conversation.
        """
        context: List[Dict] = []
        used = 0
        for msg in newest_first:
            # Prefer the stored count; otherwise estimate ~4 chars per token.
            # ``or ""`` keeps a message without content from raising.
            cost = msg.token_count or len(msg.content or "") // 4
            if used + cost > max_tokens:
                break
            context.append({"role": msg.role, "content": msg.content})
            used += cost
        context.reverse()
        return context

    @staticmethod
    def _citation_text(content) -> str:
        """Normalise a citation's ``content`` value to a single string.

        A missing/``None`` value becomes ``""``, a plain string is returned
        unchanged (joining it would put a newline between every character),
        and any other iterable of strings is joined with newlines.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        return "\n".join(content)

    def get_conversation_context(
        self,
        session_id: uuid.UUID,
        max_tokens: int = 4000,
        max_messages: int = 20,
    ) -> List[Dict]:
        """Return recent conversation history formatted for LLM context.

        Args:
            session_id: Session whose messages are loaded.
            max_tokens: Token budget the returned messages must fit in.
            max_messages: Maximum number of most-recent messages fetched
                from the database before the token budget is applied.

        Returns:
            ``{"role", "content"}`` dicts, oldest first.
        """
        # Fetch newest first so the budget is spent on the latest turns.
        newest_first = (
            self.db.query(ChatMessage)
            .filter(ChatMessage.session_id == session_id)
            .order_by(ChatMessage.created_at.desc())
            .limit(max_messages)
            .all()
        )
        return self._fit_token_budget(newest_first, max_tokens)

    def create_session(
        self,
        user_id: str,
        title: Optional[str] = None,
        rag_config: Optional[dict] = None,
    ) -> "ChatSession":
        """Create, commit, and return a new chat session for ``user_id``.

        ``rag_config`` defaults to an empty dict when omitted or falsy.
        """
        session = ChatSession(
            user_id=user_id,
            title=title,
            rag_config=rag_config or {},
        )
        self.db.add(session)
        self._commit()
        self.db.refresh(session)
        return session

    def get_session(self, session_id: uuid.UUID) -> Optional["ChatSession"]:
        """Return the session with ``session_id``, or ``None`` if not found."""
        return (
            self.db.query(ChatSession)
            .filter(ChatSession.id == session_id)
            .first()
        )

    def list_sessions(
        self, user_id: str, skip: int = 0, limit: int = 100
    ) -> List["ChatSession"]:
        """Return a page of ``user_id``'s sessions, latest activity first.

        Sessions are ordered by ``last_message_at`` descending.
        NOTE(review): where sessions with a NULL ``last_message_at`` sort
        depends on the database backend — confirm this is acceptable.
        """
        return (
            self.db.query(ChatSession)
            .filter(ChatSession.user_id == user_id)
            .order_by(ChatSession.last_message_at.desc())
            .offset(skip)
            .limit(limit)
            .all()
        )

    def save_message(
        self,
        session_id: uuid.UUID,
        role: str,
        content: str,
        token_count: Optional[int] = None,
        processing_time: Optional[float] = None,
    ) -> "ChatMessage":
        """Persist a message and bump its session's ``last_message_at``.

        The message insert and the session update are committed together in
        a single transaction, so the session timestamp cannot lag behind a
        stored message.

        Returns:
            The stored message, refreshed from the database.
        """
        message = ChatMessage(
            session_id=session_id,
            role=role,
            content=content,
            token_count=token_count,
            processing_time=processing_time,
        )
        self.db.add(message)
        try:
            # Flush so the INSERT runs and ``created_at`` can be read before
            # committing (presumably filled by a column default — confirm
            # against the model).
            self.db.flush()
            session = self.get_session(session_id)
            if session:
                session.last_message_at = message.created_at
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise
        self.db.refresh(message)
        return message

    def save_citations(self, message_id: uuid.UUID, citations: List[Dict]) -> None:
        """Store one ``MessageCitation`` row per entry in ``citations``.

        Each citation dict may provide ``reference_id``, ``file_path``,
        ``content`` (a string or a list of strings) and ``relevance_score``;
        a missing ``reference_id`` or ``file_path`` is stored as
        ``"unknown"``. All rows are committed together.
        """
        rows = [
            MessageCitation(
                message_id=message_id,
                source_doc_id=cit.get("reference_id", "unknown"),
                file_path=cit.get("file_path", "unknown"),
                chunk_content=self._citation_text(cit.get("content")),
                relevance_score=cit.get("relevance_score"),
            )
            for cit in citations
        ]
        self.db.add_all(rows)
        self._commit()

    def get_session_history(self, session_id: uuid.UUID) -> List["ChatMessage"]:
        """Return every message of the session in chronological order."""
        return (
            self.db.query(ChatMessage)
            .filter(ChatMessage.session_id == session_id)
            .order_by(ChatMessage.created_at.asc())
            .all()
        )