""" Citation extraction and footnote generation for LightRAG. This module provides post-processing capabilities to extract citations from LLM responses by matching response sentences to source chunks using embedding similarity. """ import json import logging import os import re from dataclasses import dataclass, field from pathlib import Path from typing import Any, Callable, Optional import numpy as np logger = logging.getLogger(__name__) # Configuration CITATION_MIN_SIMILARITY = float(os.getenv("CITATION_MIN_SIMILARITY", "0.5")) CITATION_MAX_PER_SENTENCE = int(os.getenv("CITATION_MAX_PER_SENTENCE", "3")) @dataclass class CitationSpan: """Represents a span of text that should be attributed to a source.""" start_char: int # Start position in response text end_char: int # End position in response text text: str # The actual text span reference_ids: list[str] # List of reference IDs supporting this claim confidence: float # 0.0-1.0 confidence that this claim is supported @dataclass class SourceReference: """Enhanced reference with full metadata for footnotes.""" reference_id: str file_path: str document_title: Optional[str] = None section_title: Optional[str] = None page_range: Optional[str] = None excerpt: Optional[str] = None chunk_ids: list[str] = field(default_factory=list) @dataclass class CitationResult: """Complete citation analysis result.""" original_response: str # Raw LLM response annotated_response: str # Response with [n] markers inserted footnotes: list[str] # Formatted footnote strings citations: list[CitationSpan] # Detailed citation spans references: list[SourceReference] # Enhanced reference list uncited_claims: list[str] = field(default_factory=list) # Claims without sources def extract_title_from_path(file_path: str) -> str: """Extract a human-readable title from a file path.""" if not file_path: return "Unknown Source" path = Path(file_path) # Get filename without extension name = path.stem # Convert snake_case or kebab-case to Title Case name = name.replace("_", " ").replace("-", " ") return name.title() def split_into_sentences(text: str) -> list[dict[str, Any]]: """Split text into sentences with their positions. Returns list of dicts with: - text: The sentence text - start: Start character position - end: End character position """ # Improved sentence splitting that handles common edge cases # Matches: .!? followed by space and capital letter, or end of string sentence_pattern = r"(?<=[.!?])\s+(?=[A-Z])|(?<=[.!?])$" sentences = [] current_pos = 0 # Split on sentence boundaries parts = re.split(sentence_pattern, text) for part in parts: part = part.strip() if not part: continue # Find the actual position in original text start = text.find(part, current_pos) if start == -1: start = current_pos end = start + len(part) sentences.append({"text": part, "start": start, "end": end}) current_pos = end return sentences def compute_similarity(vec1: list[float], vec2: list[float]) -> float: """Compute cosine similarity between two vectors.""" a = np.array(vec1) b = np.array(vec2) norm_a = np.linalg.norm(a) norm_b = np.linalg.norm(b) if norm_a == 0 or norm_b == 0: return 0.0 return float(np.dot(a, b) / (norm_a * norm_b)) class CitationExtractor: """Post-processor to extract and format citations from LLM responses.""" def __init__( self, chunks: list[dict[str, Any]], references: list[dict[str, str]], embedding_func: Callable, min_similarity: float = CITATION_MIN_SIMILARITY, ): """Initialize the citation extractor. Args: chunks: List of chunk dictionaries with 'content', 'file_path', etc. references: List of reference dicts with 'reference_id', 'file_path' embedding_func: Async function to compute embeddings min_similarity: Minimum similarity threshold for citation matching """ self.chunks = chunks self.references = references self.embedding_func = embedding_func self.min_similarity = min_similarity # Build lookup structures self._build_chunk_index() def _build_chunk_index(self): """Build index mapping chunk content to reference IDs.""" self.chunk_to_ref: dict[str, str] = {} self.ref_to_chunks: dict[str, list[dict]] = {} # Map file_path to reference_id path_to_ref: dict[str, str] = {} for ref in self.references: path_to_ref[ref.get("file_path", "")] = ref.get("reference_id", "") # Index chunks by reference for chunk in self.chunks: file_path = chunk.get("file_path", "") ref_id = path_to_ref.get(file_path, "") if ref_id: self.chunk_to_ref[chunk.get("content", "")[:100]] = ref_id if ref_id not in self.ref_to_chunks: self.ref_to_chunks[ref_id] = [] self.ref_to_chunks[ref_id].append(chunk) async def _find_supporting_chunks( self, sentence: str, sentence_embedding: list[float] ) -> list[dict[str, Any]]: """Find chunks that support a given sentence. Args: sentence: The sentence text sentence_embedding: Pre-computed embedding for the sentence Returns: List of matches with reference_id and similarity score """ matches = [] for chunk in self.chunks: chunk_content = chunk.get("content", "") chunk_embedding = chunk.get("embedding") # Skip chunks without embeddings (handle both None and empty arrays) if chunk_embedding is None or (hasattr(chunk_embedding, "__len__") and len(chunk_embedding) == 0): continue similarity = compute_similarity(sentence_embedding, chunk_embedding) if similarity >= self.min_similarity: file_path = chunk.get("file_path", "") # Find reference_id for this chunk ref_id = None for ref in self.references: if ref.get("file_path") == file_path: ref_id = ref.get("reference_id") break if ref_id: matches.append( { "reference_id": ref_id, "similarity": similarity, "chunk_excerpt": chunk_content[:100], } ) # Sort by similarity and deduplicate by reference_id matches.sort(key=lambda x: x["similarity"], reverse=True) seen_refs = set() unique_matches = [] for match in matches: if match["reference_id"] not in seen_refs: seen_refs.add(match["reference_id"]) unique_matches.append(match) if len(unique_matches) >= CITATION_MAX_PER_SENTENCE: break return unique_matches async def extract_citations( self, response: str, chunk_embeddings: Optional[dict[str, list[float]]] = None ) -> CitationResult: """Extract citations by matching response sentences to source chunks. Algorithm: 1. Split response into sentences 2. For each sentence, compute embedding similarity to all chunks 3. Assign reference_id from best-matching chunk(s) above threshold 4. Generate inline markers and footnotes Args: response: The LLM response text chunk_embeddings: Optional pre-computed chunk embeddings keyed by chunk_id Returns: CitationResult with annotated response and footnotes """ sentences = split_into_sentences(response) citations: list[CitationSpan] = [] used_refs: set[str] = set() uncited_claims: list[str] = [] # Compute embeddings for all sentences at once (batch) sentence_texts = [s["text"] for s in sentences] if sentence_texts: try: sentence_embeddings = await self.embedding_func(sentence_texts) except Exception as e: logger.warning(f"Failed to compute sentence embeddings: {e}") sentence_embeddings = [None] * len(sentence_texts) else: sentence_embeddings = [] # Pre-compute or use provided chunk embeddings if chunk_embeddings is not None and len(chunk_embeddings) > 0: for chunk in self.chunks: chunk_id = chunk.get("id", chunk.get("content", "")[:50]) if chunk_id in chunk_embeddings: chunk["embedding"] = chunk_embeddings[chunk_id] else: # Compute chunk embeddings if not provided chunk_contents = [c.get("content", "") for c in self.chunks] if chunk_contents: try: computed_embeddings = await self.embedding_func(chunk_contents) for i, chunk in enumerate(self.chunks): chunk["embedding"] = computed_embeddings[i] except Exception as e: logger.warning(f"Failed to compute chunk embeddings: {e}") # Match sentences to chunks for i, sentence in enumerate(sentences): sentence_emb = sentence_embeddings[i] if i < len(sentence_embeddings) else None if sentence_emb is None: uncited_claims.append(sentence["text"]) continue matches = await self._find_supporting_chunks(sentence["text"], sentence_emb) if matches: ref_ids = [m["reference_id"] for m in matches] confidence = matches[0]["similarity"] if matches else 0.0 citations.append( CitationSpan( start_char=sentence["start"], end_char=sentence["end"], text=sentence["text"], reference_ids=ref_ids, confidence=confidence, ) ) used_refs.update(ref_ids) else: uncited_claims.append(sentence["text"]) # Generate annotated response with inline markers annotated = self._insert_citation_markers(response, citations) # Build enhanced references enhanced_refs = self._enhance_references(used_refs) # Format footnotes footnotes = self._format_footnotes(enhanced_refs) return CitationResult( original_response=response, annotated_response=annotated, footnotes=footnotes, citations=citations, references=enhanced_refs, uncited_claims=uncited_claims, ) def _insert_citation_markers( self, response: str, citations: list[CitationSpan] ) -> str: """Insert [n] citation markers into response text. Processes citations in reverse order to preserve character positions. """ # Sort by position (descending) to insert from end to beginning sorted_citations = sorted(citations, key=lambda c: c.end_char, reverse=True) result = response for citation in sorted_citations: if not citation.reference_ids: continue # Create marker like [1] or [1,2] for multiple refs marker = "[" + ",".join(citation.reference_ids) + "]" # Insert marker after the sentence (at end_char position) result = result[: citation.end_char] + marker + result[citation.end_char :] return result def _enhance_references(self, used_refs: set[str]) -> list[SourceReference]: """Build enhanced reference objects with metadata.""" enhanced = [] for ref in self.references: ref_id = ref.get("reference_id", "") if ref_id not in used_refs: continue file_path = ref.get("file_path", "") chunks = self.ref_to_chunks.get(ref_id, []) # Extract first chunk as excerpt excerpt = None if chunks: first_chunk = chunks[0] content = first_chunk.get("content", "") excerpt = content[:150] + "..." if len(content) > 150 else content enhanced.append( SourceReference( reference_id=ref_id, file_path=file_path, document_title=ref.get("document_title") or extract_title_from_path(file_path), section_title=ref.get("section_title"), page_range=ref.get("page_range"), excerpt=excerpt, chunk_ids=[c.get("id", "") for c in chunks if c.get("id")], ) ) return enhanced def _format_footnotes(self, references: list[SourceReference]) -> list[str]: """Format references as footnote strings. Format: [n] "Document Title", Section X, pp. Y-Z """ footnotes = [] for ref in sorted(references, key=lambda r: int(r.reference_id or "0")): parts = [f'[{ref.reference_id}] "{ref.document_title}"'] if ref.section_title: parts.append(f"Section: {ref.section_title}") if ref.page_range: parts.append(f"pp. {ref.page_range}") footnotes.append(", ".join(parts)) return footnotes async def extract_citations_from_response( response: str, chunks: list[dict[str, Any]], references: list[dict[str, str]], embedding_func: Callable, min_similarity: float = CITATION_MIN_SIMILARITY, ) -> CitationResult: """Convenience function to extract citations from a response. Args: response: The LLM response text chunks: List of chunk dictionaries references: List of reference dicts embedding_func: Async function to compute embeddings min_similarity: Minimum similarity threshold Returns: CitationResult with annotated response and footnotes """ extractor = CitationExtractor( chunks=chunks, references=references, embedding_func=embedding_func, min_similarity=min_similarity, ) return await extractor.extract_citations(response)