LightRAG/lightrag/citation.py
clssck 663ada943a chore: add citation system and enhance RAG UI components
Add citation tracking and display system across backend and frontend components.
Backend changes include citation.py for document attribution, enhanced query routes
with citation metadata, improved prompt templates, and PostgreSQL schema updates.
Frontend includes CitationMarker component, HoverCard UI, QuerySettings refinements,
and ChatMessage enhancements for displaying document sources. Update dependencies
and docker-compose test configuration for improved development workflow.
2025-12-01 17:50:00 +01:00

431 lines
14 KiB
Python

"""
Citation extraction and footnote generation for LightRAG.
This module provides post-processing capabilities to extract citations from LLM responses
by matching response sentences to source chunks using embedding similarity.
"""
import json
import logging
import os
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Optional
import numpy as np
logger = logging.getLogger(name=__name__)

# Tunables read once at import time from the environment.
# Minimum cosine similarity between a response sentence and a chunk for that
# chunk's reference to be cited (default "0.5").
CITATION_MIN_SIMILARITY = float(os.environ.get("CITATION_MIN_SIMILARITY", "0.5"))
# Maximum number of distinct references attached to one sentence (default "3").
CITATION_MAX_PER_SENTENCE = int(os.environ.get("CITATION_MAX_PER_SENTENCE", "3"))
@dataclass
class CitationSpan:
    """A contiguous slice of the response and the references that back it.

    Attributes:
        start_char: Offset in the response text where the span begins.
        end_char: Offset in the response text where the span ends.
        text: The text of the span itself.
        reference_ids: IDs of the references supporting the span.
        confidence: Confidence, from 0.0 to 1.0, that the span is supported.
    """

    start_char: int
    end_char: int
    text: str
    reference_ids: list[str]
    confidence: float
@dataclass
class SourceReference:
    """A cited source document plus the metadata used to render its footnote.

    Only ``reference_id`` and ``file_path`` are required. The descriptive
    fields default to ``None``, and ``chunk_ids`` defaults to a fresh empty
    list for every instance.
    """

    reference_id: str
    file_path: str
    document_title: Optional[str] = field(default=None)
    section_title: Optional[str] = field(default=None)
    page_range: Optional[str] = field(default=None)
    excerpt: Optional[str] = field(default=None)
    # default_factory gives each instance its own list (never shared).
    chunk_ids: list[str] = field(default_factory=list)
@dataclass
class CitationResult:
    """Everything produced by one citation-extraction pass.

    Attributes:
        original_response: The LLM response exactly as received.
        annotated_response: The response with ``[n]`` markers inserted.
        footnotes: Formatted footnote strings.
        citations: Detailed per-span citation information.
        references: Metadata-enriched list of the cited references.
        uncited_claims: Claims (sentences) for which no source was found;
            empty by default.
    """

    original_response: str
    annotated_response: str
    footnotes: list[str]
    citations: list[CitationSpan]
    references: list[SourceReference]
    uncited_claims: list[str] = field(default_factory=list)
def extract_title_from_path(file_path: str) -> str:
    """Derive a human-readable document title from a file path.

    The directory part and the final extension are dropped. ``_`` and ``-``
    are treated as word separators, runs of whitespace are collapsed, and
    the result is title-cased.

    Args:
        file_path: Path of the source document. May be empty.

    Returns:
        The derived title, or ``"Unknown Source"`` when ``file_path`` is empty
        or yields no words (e.g. ``"/"``, ``"."`` or ``"__"``).

    Example:
        >>> extract_title_from_path("docs/annual_report-2023.pdf")
        'Annual Report 2023'
    """
    if not file_path:
        return "Unknown Source"
    # Path.stem drops the directory and only the last suffix
    # ("a.tar.gz" -> "a.tar").
    stem = Path(file_path).stem
    # split()/join() collapses repeated separators ("a__b" -> "a b") and trims
    # leading/trailing ones, so the title never has stray spaces.
    words = stem.replace("_", " ").replace("-", " ").split()
    if not words:
        # Otherwise the footnote would render an empty "" title.
        return "Unknown Source"
    return " ".join(words).title()
def split_into_sentences(text: str) -> list[dict[str, Any]]:
    """Break ``text`` into sentences and record where each one sits.

    A boundary is either whitespace that follows ``.``, ``!`` or ``?`` and
    precedes an uppercase ASCII letter, or the regex end-of-string (``$``)
    immediately after such a mark. Pieces are stripped of surrounding
    whitespace, and empty pieces are dropped.

    Returns:
        A list of dicts, one per sentence, with keys:
        - text: the stripped sentence text
        - start: character offset of the sentence in ``text``
        - end: ``start + len(sentence)``
    """
    boundary = r"(?<=[.!?])\s+(?=[A-Z])|(?<=[.!?])$"
    result: list[dict[str, Any]] = []
    cursor = 0
    for raw_piece in re.split(boundary, text):
        piece = raw_piece.strip()
        if piece == "":
            continue
        # Search forward from the end of the previous sentence so offsets
        # advance monotonically through the original text.
        begin = text.find(piece, cursor)
        if begin < 0:
            begin = cursor
        finish = begin + len(piece)
        result.append({"text": piece, "start": begin, "end": finish})
        cursor = finish
    return result
def compute_similarity(vec1: list[float], vec2: list[float]) -> float:
    """Return the cosine similarity of ``vec1`` and ``vec2`` as a Python float.

    If either vector has zero norm, its direction is undefined and 0.0 is
    returned instead of dividing by zero.
    """
    left = np.array(vec1)
    right = np.array(vec2)
    left_norm = np.linalg.norm(left)
    right_norm = np.linalg.norm(right)
    if left_norm == 0 or right_norm == 0:
        return 0.0
    denominator = left_norm * right_norm
    return float(np.dot(left, right) / denominator)
class CitationExtractor:
    """Post-processor that attaches source citations to an LLM response.

    Every sentence of the response is embedded and scored (cosine similarity)
    against an embedding of every retrieved chunk. A chunk scoring at least
    ``min_similarity`` contributes its document's ``reference_id`` to the
    sentence. At most ``max_per_sentence`` distinct references are kept per
    sentence. Chunks are tied to references through their ``file_path``.

    The caller's chunk dicts are not modified. Embeddings supplied or computed
    during extraction are stored on the extractor rather than written back
    into the chunks.
    """

    def __init__(
        self,
        chunks: list[dict[str, Any]],
        references: list[dict[str, str]],
        embedding_func: Callable,
        min_similarity: float = CITATION_MIN_SIMILARITY,
        max_per_sentence: int = CITATION_MAX_PER_SENTENCE,
    ):
        """Initialize the citation extractor.

        Args:
            chunks: Chunk dicts. Keys read: ``content``, ``file_path``, ``id``
                and, optionally, a precomputed ``embedding``.
            references: Reference dicts. Keys read: ``reference_id`` and
                ``file_path``, plus the optional ``document_title``,
                ``section_title`` and ``page_range`` used in footnotes.
            embedding_func: Async callable that takes a list of strings and
                returns one embedding per string, in the same order.
            min_similarity: Minimum cosine similarity for a chunk to count as
                supporting a sentence.
            max_per_sentence: Maximum number of distinct references attached
                to a single sentence.
        """
        self.chunks = chunks
        self.references = references
        self.embedding_func = embedding_func
        self.min_similarity = min_similarity
        self.max_per_sentence = max_per_sentence
        # Chunk embeddings keyed by position in self.chunks. They are filled
        # by extract_citations() so the caller's chunk dicts stay untouched.
        self._chunk_embeddings: dict[int, Any] = {}
        self._build_chunk_index()

    def _build_chunk_index(self):
        """Build the lookup tables that link chunks to reference IDs.

        Sets:
            _path_to_ref: file path -> reference ID. The first reference that
                names a path wins, and references without an ID are skipped.
                Every other lookup in this class goes through this one table,
                so all of them agree.
            chunk_to_ref: first 100 characters of a chunk's content ->
                reference ID.
            ref_to_chunks: reference ID -> chunks from that reference's file,
                in input order.
        """
        self._path_to_ref: dict[str, str] = {}
        for ref in self.references:
            ref_id = ref.get("reference_id", "")
            if ref_id:
                self._path_to_ref.setdefault(ref.get("file_path", ""), ref_id)

        self.chunk_to_ref: dict[str, str] = {}
        self.ref_to_chunks: dict[str, list[dict]] = {}
        for chunk in self.chunks:
            ref_id = self._path_to_ref.get(chunk.get("file_path", ""))
            if not ref_id:
                continue
            self.chunk_to_ref[chunk.get("content", "")[:100]] = ref_id
            self.ref_to_chunks.setdefault(ref_id, []).append(chunk)

    @staticmethod
    def _is_usable_embedding(embedding: Any) -> bool:
        """Return False for ``None`` and for zero-length embeddings."""
        if embedding is None:
            return False
        # Test the length explicitly: truthiness of a NumPy array is ambiguous.
        return not (hasattr(embedding, "__len__") and len(embedding) == 0)

    def _embedding_for(self, index: int) -> Any:
        """Return the usable embedding of ``self.chunks[index]``, or ``None``.

        An embedding stored by extract_citations() takes precedence over an
        ``embedding`` key already present on the chunk dict.
        """
        for candidate in (
            self._chunk_embeddings.get(index),
            self.chunks[index].get("embedding"),
        ):
            if self._is_usable_embedding(candidate):
                return candidate
        return None

    async def _find_supporting_chunks(
        self, sentence: str, sentence_embedding: list[float]
    ) -> list[dict[str, Any]]:
        """Find the references whose chunks best support a sentence.

        Args:
            sentence: The sentence text. Scoring uses only the embedding; the
                text is accepted for interface compatibility.
            sentence_embedding: Precomputed embedding of ``sentence``.

        Returns:
            At most ``max_per_sentence`` dicts with keys ``reference_id``,
            ``similarity`` and ``chunk_excerpt`` (the first 100 characters of
            the chunk). There is one dict per reference, taken from its
            best-scoring chunk, ordered by descending similarity.
        """
        matches = []
        for index, chunk in enumerate(self.chunks):
            chunk_embedding = self._embedding_for(index)
            if chunk_embedding is None:
                continue
            # O(1) dict lookup instead of scanning self.references per chunk.
            ref_id = self._path_to_ref.get(chunk.get("file_path", ""))
            if not ref_id:
                continue
            similarity = compute_similarity(sentence_embedding, chunk_embedding)
            if similarity >= self.min_similarity:
                matches.append(
                    {
                        "reference_id": ref_id,
                        "similarity": similarity,
                        "chunk_excerpt": chunk.get("content", "")[:100],
                    }
                )

        # Best first, so the first hit seen for a reference is its best one.
        matches.sort(key=lambda m: m["similarity"], reverse=True)
        seen_refs: set[str] = set()
        unique_matches = []
        for match in matches:
            if match["reference_id"] in seen_refs:
                continue
            seen_refs.add(match["reference_id"])
            unique_matches.append(match)
            if len(unique_matches) >= self.max_per_sentence:
                break
        return unique_matches

    async def _embed_sentences(self, texts: list[str]) -> Any:
        """Embed ``texts`` in one batch.

        Returns an empty list for no texts. If the embedding call raises, it
        returns ``None`` for every text, which leaves all sentences uncited
        instead of propagating the error.
        """
        if not texts:
            return []
        try:
            return await self.embedding_func(texts)
        except Exception as e:
            logger.warning("Failed to compute sentence embeddings: %s", e)
            return [None] * len(texts)

    async def _prepare_chunk_embeddings(
        self, chunk_embeddings: Optional[dict[str, list[float]]]
    ) -> None:
        """Make sure every chunk has an embedding, embedding only the gaps.

        Entries of ``chunk_embeddings`` are matched by the chunk's ``id``, or
        by the first 50 characters of its content when it has no ``id``.
        Chunks still lacking a usable embedding are embedded in one batch with
        ``embedding_func``. That covers both the case where no map was given
        and the case where the map does not include them. A failure is logged
        and leaves those chunks without an embedding, so they cannot be cited.
        """
        if chunk_embeddings:
            for index, chunk in enumerate(self.chunks):
                key = chunk.get("id", chunk.get("content", "")[:50])
                if key in chunk_embeddings:
                    self._chunk_embeddings[index] = chunk_embeddings[key]

        missing = [
            index
            for index in range(len(self.chunks))
            if self._embedding_for(index) is None
        ]
        if not missing:
            return
        try:
            computed = await self.embedding_func(
                [self.chunks[index].get("content", "") for index in missing]
            )
        except Exception as e:
            logger.warning("Failed to compute chunk embeddings: %s", e)
            return
        if len(computed) != len(missing):
            logger.warning(
                "Embedding function returned %d embeddings for %d chunks",
                len(computed),
                len(missing),
            )
        for index, embedding in zip(missing, computed):
            self._chunk_embeddings[index] = embedding

    async def extract_citations(
        self, response: str, chunk_embeddings: Optional[dict[str, list[float]]] = None
    ) -> CitationResult:
        """Extract citations by matching response sentences to source chunks.

        Algorithm:
            1. Split the response into sentences.
            2. Embed all sentences in one batch, and make sure every chunk has
               an embedding (see ``chunk_embeddings``).
            3. Attach to each sentence the references of its best-matching
               chunks at or above ``min_similarity``.
            4. Insert inline markers and build footnotes for the cited
               references.

        Args:
            response: The LLM response text.
            chunk_embeddings: Optional precomputed chunk embeddings keyed by
                chunk ``id``, or by the first 50 characters of the chunk's
                content when it has no ``id``. Chunks it does not cover are
                embedded with ``embedding_func``.

        Returns:
            CitationResult with the annotated response, footnotes, citation
            spans, cited references and the sentences left uncited.
        """
        sentences = split_into_sentences(response)
        sentence_embeddings = await self._embed_sentences(
            [s["text"] for s in sentences]
        )
        await self._prepare_chunk_embeddings(chunk_embeddings)

        citations: list[CitationSpan] = []
        used_refs: set[str] = set()
        uncited_claims: list[str] = []
        for i, sentence in enumerate(sentences):
            # Tolerate an embedding function that returned too few vectors.
            sentence_emb = (
                sentence_embeddings[i] if i < len(sentence_embeddings) else None
            )
            if sentence_emb is None:
                uncited_claims.append(sentence["text"])
                continue
            matches = await self._find_supporting_chunks(sentence["text"], sentence_emb)
            if not matches:
                uncited_claims.append(sentence["text"])
                continue
            ref_ids = [m["reference_id"] for m in matches]
            citations.append(
                CitationSpan(
                    start_char=sentence["start"],
                    end_char=sentence["end"],
                    text=sentence["text"],
                    reference_ids=ref_ids,
                    # Matches are sorted best-first.
                    confidence=matches[0]["similarity"],
                )
            )
            used_refs.update(ref_ids)

        enhanced_refs = self._enhance_references(used_refs)
        return CitationResult(
            original_response=response,
            annotated_response=self._insert_citation_markers(response, citations),
            footnotes=self._format_footnotes(enhanced_refs),
            citations=citations,
            references=enhanced_refs,
            uncited_claims=uncited_claims,
        )

    def _insert_citation_markers(
        self, response: str, citations: list[CitationSpan]
    ) -> str:
        """Insert ``[id]`` or ``[id1,id2]`` markers right after each cited span.

        Citations are applied from the end of the text backwards, so inserting
        a marker never shifts the offsets of citations not yet processed.
        """
        result = response
        for citation in sorted(citations, key=lambda c: c.end_char, reverse=True):
            if not citation.reference_ids:
                continue
            marker = "[" + ",".join(citation.reference_ids) + "]"
            result = result[: citation.end_char] + marker + result[citation.end_char :]
        return result

    def _enhance_references(self, used_refs: set[str]) -> list[SourceReference]:
        """Build SourceReference objects for the cited references.

        Output follows the order of ``self.references``. The excerpt is the
        first 150 characters of the reference's first indexed chunk, with
        "..." appended when truncated. The document title falls back to one
        derived from the file path.
        """
        enhanced = []
        for ref in self.references:
            ref_id = ref.get("reference_id", "")
            if ref_id not in used_refs:
                continue
            file_path = ref.get("file_path", "")
            chunks = self.ref_to_chunks.get(ref_id, [])
            excerpt = None
            if chunks:
                content = chunks[0].get("content", "")
                excerpt = content[:150] + "..." if len(content) > 150 else content
            enhanced.append(
                SourceReference(
                    reference_id=ref_id,
                    file_path=file_path,
                    document_title=ref.get("document_title")
                    or extract_title_from_path(file_path),
                    section_title=ref.get("section_title"),
                    page_range=ref.get("page_range"),
                    excerpt=excerpt,
                    chunk_ids=[c.get("id", "") for c in chunks if c.get("id")],
                )
            )
        return enhanced

    @staticmethod
    def _reference_sort_key(reference_id: Optional[str]) -> tuple[int, int, str]:
        """Sort key: numeric IDs in numeric order, then other IDs by text.

        A missing or empty ID sorts as ``0``. A non-numeric ID no longer
        raises ``ValueError``; it sorts after all numeric IDs instead.
        """
        text = reference_id or "0"
        try:
            return (0, int(text), "")
        except (TypeError, ValueError):
            return (1, 0, str(text))

    def _format_footnotes(self, references: list[SourceReference]) -> list[str]:
        """Format references as footnote strings, ordered by reference ID.

        Format: ``[n] "Document Title", Section: X, pp. Y-Z``. The section and
        page parts are included only when present.
        """
        footnotes = []
        ordered = sorted(
            references, key=lambda r: self._reference_sort_key(r.reference_id)
        )
        for ref in ordered:
            parts = [f'[{ref.reference_id}] "{ref.document_title}"']
            if ref.section_title:
                parts.append(f"Section: {ref.section_title}")
            if ref.page_range:
                parts.append(f"pp. {ref.page_range}")
            footnotes.append(", ".join(parts))
        return footnotes
async def extract_citations_from_response(
    response: str,
    chunks: list[dict[str, Any]],
    references: list[dict[str, str]],
    embedding_func: Callable,
    min_similarity: float = CITATION_MIN_SIMILARITY,
) -> CitationResult:
    """One-shot helper: build a CitationExtractor and run it on ``response``.

    Args:
        response: The LLM response text to annotate.
        chunks: Retrieved chunk dictionaries.
        references: Reference dictionaries (``reference_id``, ``file_path``).
        embedding_func: Async callable used to embed sentences and chunks.
        min_similarity: Similarity threshold a chunk must reach to be cited.

    Returns:
        The CitationResult produced by ``CitationExtractor.extract_citations``.
    """
    citation_extractor = CitationExtractor(
        chunks, references, embedding_func, min_similarity
    )
    result = await citation_extractor.extract_citations(response)
    return result