From 6872f085d1667438288f25954df70fcaf381775d Mon Sep 17 00:00:00 2001 From: Saswat Date: Thu, 9 Oct 2025 17:38:43 +0530 Subject: [PATCH 1/2] feat: Enhance document processing with page tracking and reference validation - Added optional page tracking fields (start_page, end_page, pages) to TextChunkSchema. - Updated LightRAG class to handle page metadata during document processing. - Implemented validation for LLM responses to ensure only valid reference IDs are used. - Enhanced chunking functions to include page data for better context management. - Improved reference generation to include page ranges for citations. - Added PDF extraction methods to capture page-level data using PyPDF2 and Docling. --- lightrag/api/routers/document_routes.py | 109 +++++-- lightrag/base.py | 4 + lightrag/kg/postgres_impl.py | 40 ++- lightrag/lightrag.py | 33 ++- lightrag/operate.py | 374 +++++++++++++++++------- lightrag/prompt.py | 60 ++-- lightrag/utils.py | 129 ++++---- 7 files changed, 520 insertions(+), 229 deletions(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 7e44b57d..5dd8f651 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -859,6 +859,68 @@ def get_unique_filename_in_enqueued(target_dir: Path, original_name: str) -> str return f"{base_name}_{timestamp}{extension}" +async def _extract_pdf_with_docling(file_path: Path) -> str: + """Extract text from PDF using Docling engine.""" + if not pm.is_installed("docling"): # type: ignore + pm.install("docling") + from docling.document_converter import DocumentConverter # type: ignore + + converter = DocumentConverter() + result = converter.convert(file_path) + return result.document.export_to_markdown() + + +async def _extract_pdf_with_pypdf2(file_bytes: bytes) -> tuple[str, list[dict]]: + """Extract text and page data from PDF using PyPDF2.""" + if not pm.is_installed("pypdf2"): # type: ignore + pm.install("pypdf2") + from PyPDF2 import PdfReader # type: ignore + from io import BytesIO + + pdf_file = BytesIO(file_bytes) + reader = PdfReader(pdf_file) + content = "" + page_data = [] + char_position = 0 + + for page_num, page in enumerate(reader.pages, start=1): + page_text = page.extract_text() + "\n" + page_start = char_position + page_end = char_position + len(page_text) + + page_data.append({ + "page_number": page_num, + "content": page_text, + "char_start": page_start, + "char_end": page_end, + }) + + content += page_text + char_position = page_end + + return content, page_data + + +async def _handle_file_processing_error( + rag: LightRAG, + filename: str, + error_type: str, + error_msg: str, + file_size: int, + track_id: str +) -> None: + """Handle file processing errors consistently.""" + error_files = [{ + "file_path": filename, + "error_description": f"[File Extraction]{error_type}", + "original_error": error_msg, + "file_size": file_size, + }] + + await rag.apipeline_enqueue_error_documents(error_files, track_id) + logger.error(f"[File Extraction]{error_type} for {filename}: {error_msg}") + + async def pipeline_enqueue_file( rag: LightRAG, file_path: Path, track_id: str = None ) -> tuple[bool, str]: @@ -878,6 +940,7 @@ async def pipeline_enqueue_file( try: content = "" + page_data = None # Initialize page data (will be populated for PDFs) ext = file_path.suffix.lower() file_size = 0 @@ -1029,38 +1092,15 @@ async def pipeline_enqueue_file( case ".pdf": try: + page_data = [] if global_args.document_loading_engine == "DOCLING": - if not pm.is_installed("docling"): # type: ignore - pm.install("docling") - from docling.document_converter import DocumentConverter # type: ignore - - converter = DocumentConverter() - result = converter.convert(file_path) - content = result.document.export_to_markdown() + content = await _extract_pdf_with_docling(file_path) + # TODO: Extract page-level data from Docling else: - if not pm.is_installed("pypdf2"): # type: ignore - pm.install("pypdf2") - from PyPDF2 import PdfReader # type: ignore - from io import BytesIO - - pdf_file = BytesIO(file) - reader = PdfReader(pdf_file) - for page in reader.pages: - content += page.extract_text() + "\n" + content, page_data = await _extract_pdf_with_pypdf2(file) except Exception as e: - error_files = [ - { - "file_path": str(file_path.name), - "error_description": "[File Extraction]PDF processing error", - "original_error": f"Failed to extract text from PDF: {str(e)}", - "file_size": file_size, - } - ] - await rag.apipeline_enqueue_error_documents( - error_files, track_id - ) - logger.error( - f"[File Extraction]Error processing PDF {file_path.name}: {str(e)}" + await _handle_file_processing_error( + rag, file_path.name, "PDF processing error", str(e), file_size, track_id ) return False, track_id @@ -1239,8 +1279,17 @@ async def pipeline_enqueue_file( return False, track_id try: + # Pass page_data if it was collected (only for PDFs with PyPDF2) + page_data_to_pass = [page_data] if page_data is not None and len(page_data) > 0 else None + + # Debug logging + if page_data_to_pass: + logger.info(f"Passing page metadata for {file_path.name}: {len(page_data_to_pass[0])} pages") + else: + logger.debug(f"No page metadata for {file_path.name} (non-PDF or extraction failed)") + await rag.apipeline_enqueue_documents( - content, file_paths=file_path.name, track_id=track_id + content, file_paths=file_path.name, track_id=track_id, page_data_list=page_data_to_pass ) logger.info( diff --git a/lightrag/base.py b/lightrag/base.py index b9ebeca8..d25dfc51 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -77,6 +77,10 @@ class TextChunkSchema(TypedDict): content: str full_doc_id: str chunk_order_index: int + # Page tracking fields (optional for backward compatibility) + start_page: int | None # 1-based page number where chunk starts + end_page: int | None # 1-based page number where chunk ends (inclusive) + pages: list[int] | None # List of all pages this chunk spans [start, end] T = TypeVar("T") diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 50c2108f..e3f333db 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1782,6 +1782,9 @@ class PGKVStorage(BaseKVStorage): "content": v["content"], "file_path": v["file_path"], "llm_cache_list": json.dumps(v.get("llm_cache_list", [])), + "start_page": v.get("start_page"), # Optional page fields + "end_page": v.get("end_page"), + "pages": json.dumps(v.get("pages")) if v.get("pages") is not None else None, "create_time": current_time, "update_time": current_time, } @@ -1794,6 +1797,7 @@ class PGKVStorage(BaseKVStorage): "content": v["content"], "doc_name": v.get("file_path", ""), # Map file_path to doc_name "workspace": self.workspace, + "page_data": json.dumps(v.get("page_data")) if v.get("page_data") is not None else None, } await self.db.execute(upsert_sql, _data) elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): @@ -1949,6 +1953,9 @@ class PGVectorStorage(BaseVectorStorage): "content": item["content"], "content_vector": json.dumps(item["__vector__"].tolist()), "file_path": item["file_path"], + "start_page": item.get("start_page"), # Optional page fields + "end_page": item.get("end_page"), + "pages": json.dumps(item.get("pages")) if item.get("pages") is not None else None, "create_time": current_time, "update_time": current_time, } @@ -4508,6 +4515,7 @@ TABLES = { doc_name VARCHAR(1024), content TEXT, meta JSONB, + page_data JSONB, create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, CONSTRAINT LIGHTRAG_DOC_FULL_PK PRIMARY KEY (workspace, id) @@ -4523,6 +4531,9 @@ TABLES = { content TEXT, file_path TEXT NULL, llm_cache_list JSONB NULL DEFAULT '[]'::jsonb, + start_page INTEGER NULL, + end_page INTEGER NULL, + pages JSONB NULL, create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id) @@ -4538,6 +4549,9 @@ TABLES = { content TEXT, content_vector VECTOR({os.environ.get("EMBEDDING_DIM", 1024)}), file_path TEXT NULL, + start_page INTEGER NULL, + end_page INTEGER NULL, + pages JSONB NULL, create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, CONSTRAINT LIGHTRAG_VDB_CHUNKS_PK PRIMARY KEY (workspace, id) @@ -4632,12 +4646,14 @@ TABLES = { SQL_TEMPLATES = { # SQL for KVStorage "get_by_id_full_docs": """SELECT id, COALESCE(content, '') as content, - COALESCE(doc_name, '') as file_path + COALESCE(doc_name, '') as file_path, + page_data FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2 """, "get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content, chunk_order_index, full_doc_id, file_path, COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list, + start_page, end_page, pages, EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, EXTRACT(EPOCH FROM update_time)::BIGINT as update_time FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2 @@ -4684,11 +4700,12 @@ SQL_TEMPLATES = { FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id IN ({ids}) """, "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})", - "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace) - VALUES ($1, $2, $3, $4) + "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace, page_data) + VALUES ($1, $2, $3, $4, $5) ON CONFLICT (workspace,id) DO UPDATE SET content = $2, doc_name = $3, + page_data = $5, update_time = CURRENT_TIMESTAMP """, "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,chunk_id,cache_type,queryparam) @@ -4703,8 +4720,8 @@ SQL_TEMPLATES = { """, "upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens, chunk_order_index, full_doc_id, content, file_path, llm_cache_list, - create_time, update_time) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + start_page, end_page, pages, create_time, update_time) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) ON CONFLICT (workspace,id) DO UPDATE SET tokens=EXCLUDED.tokens, chunk_order_index=EXCLUDED.chunk_order_index, @@ -4712,6 +4729,9 @@ SQL_TEMPLATES = { content = EXCLUDED.content, file_path=EXCLUDED.file_path, llm_cache_list=EXCLUDED.llm_cache_list, + start_page=EXCLUDED.start_page, + end_page=EXCLUDED.end_page, + pages=EXCLUDED.pages, update_time = EXCLUDED.update_time """, "upsert_full_entities": """INSERT INTO LIGHTRAG_FULL_ENTITIES (workspace, id, entity_names, count, @@ -4733,8 +4753,8 @@ SQL_TEMPLATES = { # SQL for VectorStorage "upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens, chunk_order_index, full_doc_id, content, content_vector, file_path, - create_time, update_time) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + start_page, end_page, pages, create_time, update_time) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) ON CONFLICT (workspace,id) DO UPDATE SET tokens=EXCLUDED.tokens, chunk_order_index=EXCLUDED.chunk_order_index, @@ -4742,6 +4762,9 @@ SQL_TEMPLATES = { content = EXCLUDED.content, content_vector=EXCLUDED.content_vector, file_path=EXCLUDED.file_path, + start_page=EXCLUDED.start_page, + end_page=EXCLUDED.end_page, + pages=EXCLUDED.pages, update_time = EXCLUDED.update_time """, "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, @@ -4790,6 +4813,9 @@ SQL_TEMPLATES = { SELECT c.id, c.content, c.file_path, + c.start_page, + c.end_page, + c.pages, EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at FROM LIGHTRAG_VDB_CHUNKS c WHERE c.workspace = $1 diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index b4345405..18c793ae 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -550,7 +550,7 @@ class LightRAG: namespace=NameSpace.VECTOR_STORE_CHUNKS, workspace=self.workspace, embedding_func=self.embedding_func, - meta_fields={"full_doc_id", "content", "file_path"}, + meta_fields={"full_doc_id", "content", "file_path", "start_page", "end_page", "pages"}, ) # Initialize document status storage @@ -1011,6 +1011,7 @@ class LightRAG: ids: list[str] | None = None, file_paths: str | list[str] | None = None, track_id: str | None = None, + page_data_list: list[list[dict[str, Any]]] | None = None, ) -> str: """ Pipeline for Processing Documents @@ -1025,6 +1026,8 @@ class LightRAG: ids: list of unique document IDs, if not provided, MD5 hash IDs will be generated file_paths: list of file paths corresponding to each document, used for citation track_id: tracking ID for monitoring processing status, if not provided, will be generated with "enqueue" prefix + page_data_list: Optional list of page metadata for each document. Each entry is a list of dicts with + {"page_number": int, "content": str, "char_start": int, "char_end": int} Returns: str: tracking ID for monitoring processing status @@ -1050,6 +1053,16 @@ class LightRAG: else: # If no file paths provided, use placeholder file_paths = ["unknown_source"] * len(input) + + # Handle page_data_list + if page_data_list is not None: + if len(page_data_list) != len(input): + raise ValueError( + "Number of page_data entries must match the number of documents" + ) + else: + # If no page data provided, use empty lists + page_data_list = [None] * len(input) # 1. Validate ids if provided or generate MD5 hash IDs and remove duplicate contents if ids is not None: @@ -1063,31 +1076,32 @@ class LightRAG: # Generate contents dict and remove duplicates in one pass unique_contents = {} - for id_, doc, path in zip(ids, input, file_paths): + for id_, doc, path, page_data in zip(ids, input, file_paths, page_data_list): cleaned_content = sanitize_text_for_encoding(doc) if cleaned_content not in unique_contents: - unique_contents[cleaned_content] = (id_, path) + unique_contents[cleaned_content] = (id_, path, page_data) # Reconstruct contents with unique content contents = { - id_: {"content": content, "file_path": file_path} - for content, (id_, file_path) in unique_contents.items() + id_: {"content": content, "file_path": file_path, "page_data": page_data} + for content, (id_, file_path, page_data) in unique_contents.items() } else: # Clean input text and remove duplicates in one pass unique_content_with_paths = {} - for doc, path in zip(input, file_paths): + for doc, path, page_data in zip(input, file_paths, page_data_list): cleaned_content = sanitize_text_for_encoding(doc) if cleaned_content not in unique_content_with_paths: - unique_content_with_paths[cleaned_content] = path + unique_content_with_paths[cleaned_content] = (path, page_data) # Generate contents dict of MD5 hash IDs and documents with paths contents = { compute_mdhash_id(content, prefix="doc-"): { "content": content, "file_path": path, + "page_data": page_data, } - for content, path in unique_content_with_paths.items() + for content, (path, page_data) in unique_content_with_paths.items() } # 2. Generate document initial status (without content) @@ -1142,6 +1156,7 @@ class LightRAG: doc_id: { "content": contents[doc_id]["content"], "file_path": contents[doc_id]["file_path"], + "page_data": contents[doc_id].get("page_data"), # Optional page metadata } for doc_id in new_docs.keys() } @@ -1525,6 +1540,7 @@ class LightRAG: f"Document content not found in full_docs for doc_id: {doc_id}" ) content = content_data["content"] + page_data = content_data.get("page_data") # Optional page metadata # Generate chunks from document chunks: dict[str, Any] = { @@ -1541,6 +1557,7 @@ class LightRAG: split_by_character_only, self.chunk_overlap_token_size, self.chunk_token_size, + page_data, # Pass page metadata to chunking function ) } diff --git a/lightrag/operate.py b/lightrag/operate.py index 0551fdb5..34981c06 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -63,6 +63,38 @@ from dotenv import load_dotenv load_dotenv(dotenv_path=".env", override=False) +def validate_llm_references(response: str, valid_ref_ids: set[str]) -> tuple[str, bool]: + """ + Validate that LLM response only uses valid reference IDs. + + Args: + response: The LLM response text + valid_ref_ids: Set of valid reference IDs from the reference list + + Returns: + Tuple of (cleaned_response, is_valid) + """ + import re + + # Find all reference patterns like [1], [2], etc. + ref_pattern = r'\[(\d+)\]' + matches = re.findall(ref_pattern, response) + + invalid_refs = [] + for ref_id in matches: + if ref_id not in valid_ref_ids: + invalid_refs.append(ref_id) + + if invalid_refs: + logger.warning(f"LLM generated invalid references: {invalid_refs}. Valid refs: {sorted(valid_ref_ids)}") + # Remove invalid references from the response + for invalid_ref in invalid_refs: + response = re.sub(rf'\[{invalid_ref}\](?:\s*\([^)]*\))?', '', response) + return response, False + + return response, True + + def chunking_by_token_size( tokenizer: Tokenizer, content: str, @@ -70,51 +102,102 @@ def chunking_by_token_size( split_by_character_only: bool = False, overlap_token_size: int = 128, max_token_size: int = 1024, + page_data: list[dict[str, Any]] | None = None, ) -> list[dict[str, Any]]: + """Chunk content by token size with optional page tracking.""" + + def _calculate_page_range(start_char: int, end_char: int) -> tuple[int | None, int | None, list[int] | None]: + if not page_data: + return None, None, None + + pages = set() + start_page = end_page = None + + for page in page_data: + page_num = page["page_number"] + page_start = page["char_start"] + page_end = page["char_end"] + + if start_char < page_end and end_char > page_start: + pages.add(page_num) + start_page = min(start_page, page_num) if start_page else page_num + end_page = max(end_page, page_num) if end_page else page_num + + return start_page, end_page, sorted(pages) if pages else None + + def _estimate_char_positions(token_start: int, token_end: int, total_tokens: int, total_chars: int) -> tuple[int, int]: + if total_tokens == 0: + return 0, total_chars + start_char = int((token_start / total_tokens) * total_chars) + end_char = int((token_end / total_tokens) * total_chars) + return start_char, end_char + + def _create_chunk_dict(token_count: int, content: str, index: int, start_char: int, end_char: int) -> dict[str, Any]: + chunk = { + "tokens": token_count, + "content": content.strip(), + "chunk_order_index": index, + } + + if page_data: + start_page, end_page, pages = _calculate_page_range(start_char, end_char) + chunk.update({ + "start_page": start_page, + "end_page": end_page, + "pages": pages + }) + + return chunk + tokens = tokenizer.encode(content) - results: list[dict[str, Any]] = [] + total_tokens = len(tokens) + total_chars = len(content) + results = [] + if split_by_character: raw_chunks = content.split(split_by_character) - new_chunks = [] - if split_by_character_only: - for chunk in raw_chunks: - _tokens = tokenizer.encode(chunk) - new_chunks.append((len(_tokens), chunk)) - else: - for chunk in raw_chunks: - _tokens = tokenizer.encode(chunk) - if len(_tokens) > max_token_size: - for start in range( - 0, len(_tokens), max_token_size - overlap_token_size - ): - chunk_content = tokenizer.decode( - _tokens[start : start + max_token_size] - ) - new_chunks.append( - (min(max_token_size, len(_tokens) - start), chunk_content) - ) - else: - new_chunks.append((len(_tokens), chunk)) - for index, (_len, chunk) in enumerate(new_chunks): - results.append( - { - "tokens": _len, - "content": chunk.strip(), - "chunk_order_index": index, - } - ) + chunks_with_positions = [] + char_pos = 0 + + for chunk_text in raw_chunks: + chunk_tokens = tokenizer.encode(chunk_text) + chunk_start = char_pos + chunk_end = char_pos + len(chunk_text) + + if split_by_character_only or len(chunk_tokens) <= max_token_size: + chunks_with_positions.append((len(chunk_tokens), chunk_text, chunk_start, chunk_end)) + else: + # Split large chunks by tokens + for token_start in range(0, len(chunk_tokens), max_token_size - overlap_token_size): + token_end = min(token_start + max_token_size, len(chunk_tokens)) + chunk_content = tokenizer.decode(chunk_tokens[token_start:token_end]) + + # Estimate character positions within the chunk + ratio_start = token_start / len(chunk_tokens) + ratio_end = token_end / len(chunk_tokens) + sub_start = chunk_start + int(len(chunk_text) * ratio_start) + sub_end = chunk_start + int(len(chunk_text) * ratio_end) + + chunks_with_positions.append(( + token_end - token_start, + chunk_content, + sub_start, + sub_end + )) + + char_pos = chunk_end + len(split_by_character) + + for index, (token_count, chunk_text, start_char, end_char) in enumerate(chunks_with_positions): + results.append(_create_chunk_dict(token_count, chunk_text, index, start_char, end_char)) else: - for index, start in enumerate( - range(0, len(tokens), max_token_size - overlap_token_size) - ): - chunk_content = tokenizer.decode(tokens[start : start + max_token_size]) - results.append( - { - "tokens": min(max_token_size, len(tokens) - start), - "content": chunk_content.strip(), - "chunk_order_index": index, - } - ) + # Token-based chunking + for index, token_start in enumerate(range(0, total_tokens, max_token_size - overlap_token_size)): + token_end = min(token_start + max_token_size, total_tokens) + chunk_content = tokenizer.decode(tokens[token_start:token_end]) + start_char, end_char = _estimate_char_positions(token_start, token_end, total_tokens, total_chars) + + results.append(_create_chunk_dict(token_end - token_start, chunk_content, index, start_char, end_char)) + return results @@ -2384,6 +2467,13 @@ async def kg_query( " == LLM cache == Query cache hit, using cached response as query result" ) response = cached_response + + # Validate references in cached response too + valid_ref_ids = global_config.get('_valid_reference_ids', set()) + if valid_ref_ids: + response, is_valid = validate_llm_references(response, valid_ref_ids) + if not is_valid: + logger.warning("Cached LLM response contained invalid references and has been cleaned") else: response = await use_model_func( user_query, @@ -2392,6 +2482,13 @@ async def kg_query( enable_cot=True, stream=query_param.stream, ) + + # Validate references in the response + valid_ref_ids = global_config.get('_valid_reference_ids', set()) + if valid_ref_ids: + response, is_valid = validate_llm_references(response, valid_ref_ids) + if not is_valid: + logger.warning("LLM response contained invalid references and has been cleaned") if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): queryparam_dict = { @@ -2628,6 +2725,10 @@ async def _get_vector_context( "file_path": result.get("file_path", "unknown_source"), "source_type": "vector", # Mark the source type "chunk_id": result.get("id"), # Add chunk_id for deduplication + # Include page metadata if available + "start_page": result.get("start_page"), + "end_page": result.get("end_page"), + "pages": result.get("pages"), } valid_chunks.append(chunk_with_metadata) @@ -3027,60 +3128,43 @@ async def _merge_all_chunks( query_embedding=query_embedding, ) - # Round-robin merge chunks from different sources with deduplication - merged_chunks = [] - seen_chunk_ids = set() - max_len = max(len(vector_chunks), len(entity_chunks), len(relation_chunks)) - origin_len = len(vector_chunks) + len(entity_chunks) + len(relation_chunks) + def _extract_chunk_metadata(chunk: dict) -> dict: + """Extract and preserve essential chunk metadata including page tracking.""" + chunk_id = chunk.get("chunk_id") or chunk.get("id") + metadata = { + "content": chunk["content"], + "file_path": chunk.get("file_path", "unknown_source"), + "chunk_id": chunk_id, + } + + # Preserve page metadata if available + for field in ["start_page", "end_page", "pages"]: + if chunk.get(field) is not None: + metadata[field] = chunk.get(field) + + return metadata - for i in range(max_len): - # Add from vector chunks first (Naive mode) - if i < len(vector_chunks): - chunk = vector_chunks[i] - chunk_id = chunk.get("chunk_id") or chunk.get("id") - if chunk_id and chunk_id not in seen_chunk_ids: - seen_chunk_ids.add(chunk_id) - merged_chunks.append( - { - "content": chunk["content"], - "file_path": chunk.get("file_path", "unknown_source"), - "chunk_id": chunk_id, - } - ) + def _merge_chunks_round_robin(chunk_sources: list[list[dict]]) -> list[dict]: + """Merge chunks from multiple sources using round-robin with deduplication.""" + merged = [] + seen_ids = set() + max_len = max(len(source) for source in chunk_sources) + total_original = sum(len(source) for source in chunk_sources) + + for i in range(max_len): + for source in chunk_sources: + if i < len(source): + chunk = source[i] + chunk_id = chunk.get("chunk_id") or chunk.get("id") + + if chunk_id and chunk_id not in seen_ids: + seen_ids.add(chunk_id) + merged.append(_extract_chunk_metadata(chunk)) + + logger.info(f"Round-robin merged chunks: {total_original} -> {len(merged)} (deduplicated {total_original - len(merged)})") + return merged - # Add from entity chunks (Local mode) - if i < len(entity_chunks): - chunk = entity_chunks[i] - chunk_id = chunk.get("chunk_id") or chunk.get("id") - if chunk_id and chunk_id not in seen_chunk_ids: - seen_chunk_ids.add(chunk_id) - merged_chunks.append( - { - "content": chunk["content"], - "file_path": chunk.get("file_path", "unknown_source"), - "chunk_id": chunk_id, - } - ) - - # Add from relation chunks (Global mode) - if i < len(relation_chunks): - chunk = relation_chunks[i] - chunk_id = chunk.get("chunk_id") or chunk.get("id") - if chunk_id and chunk_id not in seen_chunk_ids: - seen_chunk_ids.add(chunk_id) - merged_chunks.append( - { - "content": chunk["content"], - "file_path": chunk.get("file_path", "unknown_source"), - "chunk_id": chunk_id, - } - ) - - logger.info( - f"Round-robin merged chunks: {origin_len} -> {len(merged_chunks)} (deduplicated {origin_len - len(merged_chunks)})" - ) - - return merged_chunks + return _merge_chunks_round_robin([vector_chunks, entity_chunks, relation_chunks]) async def _build_llm_context( @@ -3179,6 +3263,12 @@ async def _build_llm_context( ) # Generate reference list from truncated chunks using the new common function + # Debug: Check if chunks have pages before reference generation + if truncated_chunks: + sample_chunk = truncated_chunks[0] + has_pages = "pages" in sample_chunk + logger.info(f"Before reference gen: chunks have pages={has_pages}, keys={list(sample_chunk.keys())[:12]}") + reference_list, truncated_chunks = generate_reference_list_from_chunks( truncated_chunks ) @@ -3233,11 +3323,41 @@ async def _build_llm_context( text_units_str = "\n".join( json.dumps(text_unit, ensure_ascii=False) for text_unit in text_units_context ) - reference_list_str = "\n".join( - f"[{ref['reference_id']}] {ref['file_path']}" - for ref in reference_list - if ref["reference_id"] - ) + + # Format reference list with page numbers if available + formatted_references = [] + for ref in reference_list: + if not ref.get("reference_id"): + continue + + file_path = ref['file_path'] + ref_id = ref['reference_id'] + + # Add page numbers if available + pages = ref.get('pages') + if pages and len(pages) > 0: + if len(pages) == 1: + # Single page: "document.pdf (p. 5)" + citation = f"[{ref_id}] {file_path} (p. {pages[0]})" + else: + # Multiple pages: "document.pdf (pp. {first}-{last})" + citation = f"[{ref_id}] {file_path} (pp. {pages[0]}-{pages[-1]})" + else: + # No page info: "document.txt" + citation = f"[{ref_id}] {file_path}" + + formatted_references.append(citation) + + reference_list_str = "\n".join(formatted_references) + + # Debug: Log what references are being sent to the LLM + logger.info(f"Reference list for LLM ({len(formatted_references)} refs):") + for ref_line in formatted_references[:3]: # Show first 3 + logger.info(f" {ref_line}") + + # Store valid reference IDs for validation + valid_ref_ids = {ref['reference_id'] for ref in reference_list if ref.get('reference_id')} + global_config['_valid_reference_ids'] = valid_ref_ids result = kg_context_template.format( entities_str=entities_str, @@ -3646,6 +3766,12 @@ async def _find_related_text_unit_from_entities( chunk_data_copy = chunk_data.copy() chunk_data_copy["source_type"] = "entity" chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication + + # Debug: Check if page metadata is present + if i == 0: # Log first chunk only + has_pages = "pages" in chunk_data_copy + logger.info(f"Entity chunk has pages field: {has_pages}, keys: {list(chunk_data_copy.keys())[:10]}") + result_chunks.append(chunk_data_copy) # Update chunk tracking if provided @@ -4124,11 +4250,41 @@ async def naive_query( text_units_str = "\n".join( json.dumps(text_unit, ensure_ascii=False) for text_unit in text_units_context ) - reference_list_str = "\n".join( - f"[{ref['reference_id']}] {ref['file_path']}" - for ref in reference_list - if ref["reference_id"] - ) + + # Format reference list with page numbers if available + formatted_references = [] + for ref in reference_list: + if not ref.get("reference_id"): + continue + + file_path = ref['file_path'] + ref_id = ref['reference_id'] + + # Add page numbers if available + pages = ref.get('pages') + if pages and len(pages) > 0: + if len(pages) == 1: + # Single page: "document.pdf (p. 5)" + citation = f"[{ref_id}] {file_path} (p. {pages[0]})" + else: + # Multiple pages: "document.pdf (pp. {first}-{last})" + citation = f"[{ref_id}] {file_path} (pp. {pages[0]}-{pages[-1]})" + else: + # No page info: "document.txt" + citation = f"[{ref_id}] {file_path}" + + formatted_references.append(citation) + + reference_list_str = "\n".join(formatted_references) + + # Debug: Log what references are being sent to the LLM + logger.info(f"Reference list for LLM ({len(formatted_references)} refs):") + for ref_line in formatted_references[:3]: # Show first 3 + logger.info(f" {ref_line}") + + # Store valid reference IDs for validation + valid_ref_ids = {ref['reference_id'] for ref in reference_list if ref.get('reference_id')} + global_config['_valid_reference_ids'] = valid_ref_ids naive_context_template = PROMPTS["naive_query_context"] context_content = naive_context_template.format( @@ -4173,6 +4329,13 @@ async def naive_query( " == LLM cache == Query cache hit, using cached response as query result" ) response = cached_response + + # Validate references in cached response too + valid_ref_ids = global_config.get('_valid_reference_ids', set()) + if valid_ref_ids: + response, is_valid = validate_llm_references(response, valid_ref_ids) + if not is_valid: + logger.warning("Cached LLM response contained invalid references and has been cleaned") else: response = await use_model_func( user_query, @@ -4181,6 +4344,13 @@ async def naive_query( enable_cot=True, stream=query_param.stream, ) + + # Validate references in the response + valid_ref_ids = global_config.get('_valid_reference_ids', set()) + if valid_ref_ids: + response, is_valid = validate_llm_references(response, valid_ref_ids) + if not is_valid: + logger.warning("LLM response contained invalid references and has been cleaned") if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): queryparam_dict = { diff --git a/lightrag/prompt.py b/lightrag/prompt.py index 8faebbf4..b4b6be84 100644 --- a/lightrag/prompt.py +++ b/lightrag/prompt.py @@ -228,7 +228,7 @@ Consider the conversation history if provided to maintain conversational flow an - Scrutinize both `Knowledge Graph Data` and `Document Chunks` in the **Context**. Identify and extract all pieces of information that are directly relevant to answering the user query. - Weave the extracted facts into a coherent and logical response. Your own knowledge must ONLY be used to formulate fluent sentences and connect ideas, NOT to introduce any external information. - Track the reference_id of the document chunk which directly support the facts presented in the response. Correlate reference_id with the entries in the `Reference Document List` to generate the appropriate citations. - - Generate a references section at the end of the response. Each reference document must directly support the facts presented in the response. + - **MANDATORY**: Generate a "### References" section at the end of EVERY response. Copy the citations EXACTLY as they appear in the `Reference Document List` section of the Context. Only use references that are explicitly provided in the Reference Document List - do not create or invent any references. - Do not generate anything after the reference section. 2. Content & Grounding: @@ -240,21 +240,27 @@ Consider the conversation history if provided to maintain conversational flow an - The response MUST utilize Markdown formatting for enhanced clarity and structure (e.g., headings, bold text, bullet points). - The response should be presented in {response_type}. -4. References Section Format: - - The References section should be under heading: `### References` - - Reference list entries should adhere to the format: `* [n] Document Title`. Do not include a caret (`^`) after opening square bracket (`[`). - - The Document Title in the citation must retain its original language. - - Output each citation on an individual line - - Provide maximum of 5 most relevant citations. - - Do not generate footnotes section or any comment, summary, or explanation after the references. +4. References Section Format (MANDATORY - MUST ALWAYS INCLUDE): + - **ALWAYS** end your response with a `### References` section - THIS IS REQUIRED + - **ONLY USE REFERENCES FROM THE PROVIDED REFERENCE DOCUMENT LIST** - Do not create or invent any references + - **COPY the citations EXACTLY as shown in the `Reference Document List`** - including any page numbers like "(p. 5)" or "(pp. 3-5)" + - **CRITICAL**: When you see page numbers in citations like "(p. 5)" or "(pp. 3-5)", you MUST include them in your References section + - Reference list entries should adhere to the format: `- [n] Document Title (p. X)` or `- [n] Document Title (pp. X-Y)` + - Do not include a caret (`^`) after opening square bracket (`[`) + - The Document Title in the citation must retain its original language + - Output each citation on an individual line starting with a dash (`-`) + - Provide maximum of 5 most relevant citations from the Reference Document List + - **NEVER** omit page numbers if they are present in the Reference Document List + - **NEVER** create references that are not in the Reference Document List + - Do not generate footnotes section or any comment, summary, or explanation after the references -5. Reference Section Example: +5. Reference Section Example (copy the format shown in Reference Document List): ``` ### References -- [1] Document Title One -- [2] Document Title Two -- [3] Document Title Three +- [1] document.pdf (pp. 1-3) +- [2] report.pdf (p. 5) +- [3] notes.txt ``` 6. Additional Instructions: {user_prompt} @@ -282,7 +288,7 @@ Consider the conversation history if provided to maintain conversational flow an - Scrutinize `Document Chunks` in the **Context**. Identify and extract all pieces of information that are directly relevant to answering the user query. - Weave the extracted facts into a coherent and logical response. Your own knowledge must ONLY be used to formulate fluent sentences and connect ideas, NOT to introduce any external information. - Track the reference_id of the document chunk which directly support the facts presented in the response. Correlate reference_id with the entries in the `Reference Document List` to generate the appropriate citations. - - Generate a **References** section at the end of the response. Each reference document must directly support the facts presented in the response. + - **MANDATORY**: Generate a "### References" section at the end of EVERY response. Copy the citations EXACTLY as they appear in the `Reference Document List` section of the Context. Only use references that are explicitly provided in the Reference Document List - do not create or invent any references. - Do not generate anything after the reference section. 2. Content & Grounding: @@ -294,21 +300,27 @@ Consider the conversation history if provided to maintain conversational flow an - The response MUST utilize Markdown formatting for enhanced clarity and structure (e.g., headings, bold text, bullet points). - The response should be presented in {response_type}. -4. References Section Format: - - The References section should be under heading: `### References` - - Reference list entries should adhere to the format: `* [n] Document Title`. Do not include a caret (`^`) after opening square bracket (`[`). - - The Document Title in the citation must retain its original language. - - Output each citation on an individual line - - Provide maximum of 5 most relevant citations. - - Do not generate footnotes section or any comment, summary, or explanation after the references. +4. References Section Format (MANDATORY - MUST ALWAYS INCLUDE): + - **ALWAYS** end your response with a `### References` section - THIS IS REQUIRED + - **ONLY USE REFERENCES FROM THE PROVIDED REFERENCE DOCUMENT LIST** - Do not create or invent any references + - **COPY the citations EXACTLY as shown in the `Reference Document List`** - including any page numbers like "(p. 5)" or "(pp. 3-5)" + - **CRITICAL**: When you see page numbers in citations like "(p. 5)" or "(pp. 3-5)", you MUST include them in your References section + - Reference list entries should adhere to the format: `- [n] Document Title (p. X)` or `- [n] Document Title (pp. X-Y)` + - Do not include a caret (`^`) after opening square bracket (`[`) + - The Document Title in the citation must retain its original language + - Output each citation on an individual line starting with a dash (`-`) + - Provide maximum of 5 most relevant citations from the Reference Document List + - **NEVER** omit page numbers if they are present in the Reference Document List + - **NEVER** create references that are not in the Reference Document List + - Do not generate footnotes section or any comment, summary, or explanation after the references -5. Reference Section Example: +5. Reference Section Example (copy the format shown in Reference Document List): ``` ### References -- [1] Document Title One -- [2] Document Title Two -- [3] Document Title Three +- [1] document.pdf (pp. 1-3) +- [2] report.pdf (p. 5) +- [3] notes.txt ``` 6. Additional Instructions: {user_prompt} diff --git a/lightrag/utils.py b/lightrag/utils.py index 60542e43..610ba4aa 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -2817,6 +2817,13 @@ def convert_to_user_format( "file_path": chunk.get("file_path", "unknown_source"), "chunk_id": chunk.get("chunk_id", ""), } + # Add page metadata if available + if chunk.get("start_page") is not None: + chunk_data["start_page"] = chunk.get("start_page") + if chunk.get("end_page") is not None: + chunk_data["end_page"] = chunk.get("end_page") + if chunk.get("pages") is not None: + chunk_data["pages"] = chunk.get("pages") formatted_chunks.append(chunk_data) logger.debug( @@ -2845,67 +2852,73 @@ def convert_to_user_format( } -def generate_reference_list_from_chunks( - chunks: list[dict], -) -> tuple[list[dict], list[dict]]: - """ - Generate reference list from chunks, prioritizing by occurrence frequency. - - This function extracts file_paths from chunks, counts their occurrences, - sorts by frequency and first appearance order, creates reference_id mappings, - and builds a reference_list structure. - - Args: - chunks: List of chunk dictionaries with file_path information - - Returns: - tuple: (reference_list, updated_chunks_with_reference_ids) - - reference_list: List of dicts with reference_id and file_path - - updated_chunks_with_reference_ids: Original chunks with reference_id field added - """ +def generate_reference_list_from_chunks(chunks: list[dict]) -> tuple[list[dict], list[dict]]: + """Generate reference list from chunks, showing exact chunk page ranges.""" if not chunks: return [], [] - # 1. Extract all valid file_paths and count their occurrences - file_path_counts = {} - for chunk in chunks: - file_path = chunk.get("file_path", "") - if file_path and file_path != "unknown_source": - file_path_counts[file_path] = file_path_counts.get(file_path, 0) + 1 + def _create_chunk_references(chunks: list[dict]) -> tuple[list[dict], dict[str, str]]: + """Create references based on actual chunk page ranges instead of file aggregation.""" + chunk_ref_map = {} # Maps (file_path, page_range) -> reference_id + references = [] + ref_id_counter = 1 + + for chunk in chunks: + file_path = chunk.get("file_path", "") + if file_path == "unknown_source": + continue + + # Get page data for this specific chunk + chunk_pages = chunk.get("pages") + if chunk_pages and isinstance(chunk_pages, list): + # Create a unique key for this file + page range combination + page_range_key = (file_path, tuple(sorted(chunk_pages))) + + if page_range_key not in chunk_ref_map: + # Create new reference for this file + page range + chunk_ref_map[page_range_key] = str(ref_id_counter) + + # Build page range display + sorted_pages = sorted(chunk_pages) + if len(sorted_pages) == 1: + page_display = {"pages": sorted_pages, "start_page": sorted_pages[0], "end_page": sorted_pages[0]} + else: + page_display = {"pages": sorted_pages, "start_page": sorted_pages[0], "end_page": sorted_pages[-1]} + + references.append({ + "reference_id": str(ref_id_counter), + "file_path": file_path, + **page_display + }) + ref_id_counter += 1 + + return references, {f"{file_path}_{'-'.join(map(str, pages))}": ref_id + for (file_path, pages), ref_id in chunk_ref_map.items()} - # 2. Sort file paths by frequency (descending), then by first appearance order - # Create a list of (file_path, count, first_index) tuples - file_path_with_indices = [] - seen_paths = set() - for i, chunk in enumerate(chunks): - file_path = chunk.get("file_path", "") - if file_path and file_path != "unknown_source" and file_path not in seen_paths: - file_path_with_indices.append((file_path, file_path_counts[file_path], i)) - seen_paths.add(file_path) - - # Sort by count (descending), then by first appearance index (ascending) - sorted_file_paths = sorted(file_path_with_indices, key=lambda x: (-x[1], x[2])) - unique_file_paths = [item[0] for item in sorted_file_paths] - - # 3. Create mapping from file_path to reference_id (prioritized by frequency) - file_path_to_ref_id = {} - for i, file_path in enumerate(unique_file_paths): - file_path_to_ref_id[file_path] = str(i + 1) - - # 4. Add reference_id field to each chunk - updated_chunks = [] - for chunk in chunks: - chunk_copy = chunk.copy() - file_path = chunk_copy.get("file_path", "") - if file_path and file_path != "unknown_source": - chunk_copy["reference_id"] = file_path_to_ref_id[file_path] - else: - chunk_copy["reference_id"] = "" - updated_chunks.append(chunk_copy) - - # 5. Build reference_list - reference_list = [] - for i, file_path in enumerate(unique_file_paths): - reference_list.append({"reference_id": str(i + 1), "file_path": file_path}) + def _add_reference_ids_to_chunks(chunks: list[dict], chunk_ref_map: dict[str, str]) -> list[dict]: + """Add reference_id field to chunks based on their specific page ranges.""" + updated = [] + for chunk in chunks: + chunk_copy = chunk.copy() + file_path = chunk_copy.get("file_path", "") + + if file_path != "unknown_source": + chunk_pages = chunk_copy.get("pages") + if chunk_pages and isinstance(chunk_pages, list): + # Create the same key used in reference creation + page_key = f"{file_path}_{'-'.join(map(str, sorted(chunk_pages)))}" + chunk_copy["reference_id"] = chunk_ref_map.get(page_key, "") + else: + # Fallback: find any reference for this file (no page data) + chunk_copy["reference_id"] = "" + else: + chunk_copy["reference_id"] = "" + + updated.append(chunk_copy) + return updated + # Main execution flow + reference_list, chunk_ref_map = _create_chunk_references(chunks) + updated_chunks = _add_reference_ids_to_chunks(chunks, chunk_ref_map) + return reference_list, updated_chunks From 7864a75bdac61f920533c411d88c267adc7cf90a Mon Sep 17 00:00:00 2001 From: Saswat Date: Fri, 10 Oct 2025 13:03:09 +0530 Subject: [PATCH 2/2] code formatting --- lightrag/api/routers/document_routes.py | 84 +++++---- lightrag/kg/postgres_impl.py | 12 +- lightrag/lightrag.py | 29 ++- lightrag/operate.py | 233 ++++++++++++++---------- lightrag/utils.py | 62 ++++--- 5 files changed, 260 insertions(+), 160 deletions(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 5dd8f651..0c45f8ac 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -864,7 +864,7 @@ async def _extract_pdf_with_docling(file_path: Path) -> str: if not pm.is_installed("docling"): # type: ignore pm.install("docling") from docling.document_converter import DocumentConverter # type: ignore - + converter = DocumentConverter() result = converter.convert(file_path) return result.document.export_to_markdown() @@ -876,47 +876,51 @@ async def _extract_pdf_with_pypdf2(file_bytes: bytes) -> tuple[str, list[dict]]: pm.install("pypdf2") from PyPDF2 import PdfReader # type: ignore from io import BytesIO - + pdf_file = BytesIO(file_bytes) reader = PdfReader(pdf_file) content = "" page_data = [] char_position = 0 - + for page_num, page in enumerate(reader.pages, start=1): page_text = page.extract_text() + "\n" page_start = char_position page_end = char_position + len(page_text) - - page_data.append({ - "page_number": page_num, - "content": page_text, - "char_start": page_start, - "char_end": page_end, - }) - + + page_data.append( + { + "page_number": page_num, + "content": page_text, + "char_start": page_start, + "char_end": page_end, + } + ) + content += page_text char_position = page_end - + return content, page_data async def _handle_file_processing_error( - rag: LightRAG, - filename: str, - error_type: str, - error_msg: str, - file_size: int, - track_id: str + rag: LightRAG, + filename: str, + error_type: str, + error_msg: str, + file_size: int, + track_id: str, ) -> None: """Handle file processing errors consistently.""" - error_files = [{ - "file_path": filename, - "error_description": f"[File Extraction]{error_type}", - "original_error": error_msg, - "file_size": file_size, - }] - + error_files = [ + { + "file_path": filename, + "error_description": f"[File Extraction]{error_type}", + "original_error": error_msg, + "file_size": file_size, + } + ] + await rag.apipeline_enqueue_error_documents(error_files, track_id) logger.error(f"[File Extraction]{error_type} for {filename}: {error_msg}") @@ -1100,7 +1104,12 @@ async def pipeline_enqueue_file( content, page_data = await _extract_pdf_with_pypdf2(file) except Exception as e: await _handle_file_processing_error( - rag, file_path.name, "PDF processing error", str(e), file_size, track_id + rag, + file_path.name, + "PDF processing error", + str(e), + file_size, + track_id, ) return False, track_id @@ -1280,16 +1289,27 @@ async def pipeline_enqueue_file( try: # Pass page_data if it was collected (only for PDFs with PyPDF2) - page_data_to_pass = [page_data] if page_data is not None and len(page_data) > 0 else None - + page_data_to_pass = ( + [page_data] + if page_data is not None and len(page_data) > 0 + else None + ) + # Debug logging if page_data_to_pass: - logger.info(f"Passing page metadata for {file_path.name}: {len(page_data_to_pass[0])} pages") + logger.info( + f"Passing page metadata for {file_path.name}: {len(page_data_to_pass[0])} pages" + ) else: - logger.debug(f"No page metadata for {file_path.name} (non-PDF or extraction failed)") - + logger.debug( + f"No page metadata for {file_path.name} (non-PDF or extraction failed)" + ) + await rag.apipeline_enqueue_documents( - content, file_paths=file_path.name, track_id=track_id, page_data_list=page_data_to_pass + content, + file_paths=file_path.name, + track_id=track_id, + page_data_list=page_data_to_pass, ) logger.info( diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index e3f333db..6459fb91 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1784,7 +1784,9 @@ class PGKVStorage(BaseKVStorage): "llm_cache_list": json.dumps(v.get("llm_cache_list", [])), "start_page": v.get("start_page"), # Optional page fields "end_page": v.get("end_page"), - "pages": json.dumps(v.get("pages")) if v.get("pages") is not None else None, + "pages": json.dumps(v.get("pages")) + if v.get("pages") is not None + else None, "create_time": current_time, "update_time": current_time, } @@ -1797,7 +1799,9 @@ class PGKVStorage(BaseKVStorage): "content": v["content"], "doc_name": v.get("file_path", ""), # Map file_path to doc_name "workspace": self.workspace, - "page_data": json.dumps(v.get("page_data")) if v.get("page_data") is not None else None, + "page_data": json.dumps(v.get("page_data")) + if v.get("page_data") is not None + else None, } await self.db.execute(upsert_sql, _data) elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): @@ -1955,7 +1959,9 @@ class PGVectorStorage(BaseVectorStorage): "file_path": item["file_path"], "start_page": item.get("start_page"), # Optional page fields "end_page": item.get("end_page"), - "pages": json.dumps(item.get("pages")) if item.get("pages") is not None else None, + "pages": json.dumps(item.get("pages")) + if item.get("pages") is not None + else None, "create_time": current_time, "update_time": current_time, } diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 18c793ae..542dcf42 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -550,7 +550,14 @@ class LightRAG: namespace=NameSpace.VECTOR_STORE_CHUNKS, workspace=self.workspace, embedding_func=self.embedding_func, - meta_fields={"full_doc_id", "content", "file_path", "start_page", "end_page", "pages"}, + meta_fields={ + "full_doc_id", + "content", + "file_path", + "start_page", + "end_page", + "pages", + }, ) # Initialize document status storage @@ -1053,7 +1060,7 @@ class LightRAG: else: # If no file paths provided, use placeholder file_paths = ["unknown_source"] * len(input) - + # Handle page_data_list if page_data_list is not None: if len(page_data_list) != len(input): @@ -1076,14 +1083,20 @@ class LightRAG: # Generate contents dict and remove duplicates in one pass unique_contents = {} - for id_, doc, path, page_data in zip(ids, input, file_paths, page_data_list): + for id_, doc, path, page_data in zip( + ids, input, file_paths, page_data_list + ): cleaned_content = sanitize_text_for_encoding(doc) if cleaned_content not in unique_contents: unique_contents[cleaned_content] = (id_, path, page_data) # Reconstruct contents with unique content contents = { - id_: {"content": content, "file_path": file_path, "page_data": page_data} + id_: { + "content": content, + "file_path": file_path, + "page_data": page_data, + } for content, (id_, file_path, page_data) in unique_contents.items() } else: @@ -1156,7 +1169,9 @@ class LightRAG: doc_id: { "content": contents[doc_id]["content"], "file_path": contents[doc_id]["file_path"], - "page_data": contents[doc_id].get("page_data"), # Optional page metadata + "page_data": contents[doc_id].get( + "page_data" + ), # Optional page metadata } for doc_id in new_docs.keys() } @@ -1540,7 +1555,9 @@ class LightRAG: f"Document content not found in full_docs for doc_id: {doc_id}" ) content = content_data["content"] - page_data = content_data.get("page_data") # Optional page metadata + page_data = content_data.get( + "page_data" + ) # Optional page metadata # Generate chunks from document chunks: dict[str, Any] = { diff --git a/lightrag/operate.py b/lightrag/operate.py index 34981c06..97148a18 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -66,32 +66,34 @@ load_dotenv(dotenv_path=".env", override=False) def validate_llm_references(response: str, valid_ref_ids: set[str]) -> tuple[str, bool]: """ Validate that LLM response only uses valid reference IDs. - + Args: response: The LLM response text valid_ref_ids: Set of valid reference IDs from the reference list - + Returns: Tuple of (cleaned_response, is_valid) """ import re - + # Find all reference patterns like [1], [2], etc. - ref_pattern = r'\[(\d+)\]' + ref_pattern = r"\[(\d+)\]" matches = re.findall(ref_pattern, response) - + invalid_refs = [] for ref_id in matches: if ref_id not in valid_ref_ids: invalid_refs.append(ref_id) - + if invalid_refs: - logger.warning(f"LLM generated invalid references: {invalid_refs}. Valid refs: {sorted(valid_ref_ids)}") + logger.warning( + f"LLM generated invalid references: {invalid_refs}. Valid refs: {sorted(valid_ref_ids)}" + ) # Remove invalid references from the response for invalid_ref in invalid_refs: - response = re.sub(rf'\[{invalid_ref}\](?:\s*\([^)]*\))?', '', response) + response = re.sub(rf"\[{invalid_ref}\](?:\s*\([^)]*\))?", "", response) return response, False - + return response, True @@ -105,99 +107,118 @@ def chunking_by_token_size( page_data: list[dict[str, Any]] | None = None, ) -> list[dict[str, Any]]: """Chunk content by token size with optional page tracking.""" - - def _calculate_page_range(start_char: int, end_char: int) -> tuple[int | None, int | None, list[int] | None]: + + def _calculate_page_range( + start_char: int, end_char: int + ) -> tuple[int | None, int | None, list[int] | None]: if not page_data: return None, None, None - + pages = set() start_page = end_page = None - + for page in page_data: page_num = page["page_number"] page_start = page["char_start"] page_end = page["char_end"] - + if start_char < page_end and end_char > page_start: pages.add(page_num) start_page = min(start_page, page_num) if start_page else page_num end_page = max(end_page, page_num) if end_page else page_num - + return start_page, end_page, sorted(pages) if pages else None - - def _estimate_char_positions(token_start: int, token_end: int, total_tokens: int, total_chars: int) -> tuple[int, int]: + + def _estimate_char_positions( + token_start: int, token_end: int, total_tokens: int, total_chars: int + ) -> tuple[int, int]: if total_tokens == 0: return 0, total_chars start_char = int((token_start / total_tokens) * total_chars) end_char = int((token_end / total_tokens) * total_chars) return start_char, end_char - - def _create_chunk_dict(token_count: int, content: str, index: int, start_char: int, end_char: int) -> dict[str, Any]: + + def _create_chunk_dict( + token_count: int, content: str, index: int, start_char: int, end_char: int + ) -> dict[str, Any]: chunk = { "tokens": token_count, "content": content.strip(), "chunk_order_index": index, } - + if page_data: start_page, end_page, pages = _calculate_page_range(start_char, end_char) - chunk.update({ - "start_page": start_page, - "end_page": end_page, - "pages": pages - }) - + chunk.update( + {"start_page": start_page, "end_page": end_page, "pages": pages} + ) + return chunk - + tokens = tokenizer.encode(content) total_tokens = len(tokens) total_chars = len(content) results = [] - + if split_by_character: raw_chunks = content.split(split_by_character) chunks_with_positions = [] char_pos = 0 - + for chunk_text in raw_chunks: chunk_tokens = tokenizer.encode(chunk_text) chunk_start = char_pos chunk_end = char_pos + len(chunk_text) - + if split_by_character_only or len(chunk_tokens) <= max_token_size: - chunks_with_positions.append((len(chunk_tokens), chunk_text, chunk_start, chunk_end)) + chunks_with_positions.append( + (len(chunk_tokens), chunk_text, chunk_start, chunk_end) + ) else: # Split large chunks by tokens - for token_start in range(0, len(chunk_tokens), max_token_size - overlap_token_size): + for token_start in range( + 0, len(chunk_tokens), max_token_size - overlap_token_size + ): token_end = min(token_start + max_token_size, len(chunk_tokens)) - chunk_content = tokenizer.decode(chunk_tokens[token_start:token_end]) - + chunk_content = tokenizer.decode( + chunk_tokens[token_start:token_end] + ) + # Estimate character positions within the chunk ratio_start = token_start / len(chunk_tokens) ratio_end = token_end / len(chunk_tokens) sub_start = chunk_start + int(len(chunk_text) * ratio_start) sub_end = chunk_start + int(len(chunk_text) * ratio_end) - - chunks_with_positions.append(( - token_end - token_start, - chunk_content, - sub_start, - sub_end - )) - + + chunks_with_positions.append( + (token_end - token_start, chunk_content, sub_start, sub_end) + ) + char_pos = chunk_end + len(split_by_character) - - for index, (token_count, chunk_text, start_char, end_char) in enumerate(chunks_with_positions): - results.append(_create_chunk_dict(token_count, chunk_text, index, start_char, end_char)) + + for index, (token_count, chunk_text, start_char, end_char) in enumerate( + chunks_with_positions + ): + results.append( + _create_chunk_dict(token_count, chunk_text, index, start_char, end_char) + ) else: # Token-based chunking - for index, token_start in enumerate(range(0, total_tokens, max_token_size - overlap_token_size)): + for index, token_start in enumerate( + range(0, total_tokens, max_token_size - overlap_token_size) + ): token_end = min(token_start + max_token_size, total_tokens) chunk_content = tokenizer.decode(tokens[token_start:token_end]) - start_char, end_char = _estimate_char_positions(token_start, token_end, total_tokens, total_chars) - - results.append(_create_chunk_dict(token_end - token_start, chunk_content, index, start_char, end_char)) - + start_char, end_char = _estimate_char_positions( + token_start, token_end, total_tokens, total_chars + ) + + results.append( + _create_chunk_dict( + token_end - token_start, chunk_content, index, start_char, end_char + ) + ) + return results @@ -2467,13 +2488,15 @@ async def kg_query( " == LLM cache == Query cache hit, using cached response as query result" ) response = cached_response - + # Validate references in cached response too - valid_ref_ids = global_config.get('_valid_reference_ids', set()) + valid_ref_ids = global_config.get("_valid_reference_ids", set()) if valid_ref_ids: response, is_valid = validate_llm_references(response, valid_ref_ids) if not is_valid: - logger.warning("Cached LLM response contained invalid references and has been cleaned") + logger.warning( + "Cached LLM response contained invalid references and has been cleaned" + ) else: response = await use_model_func( user_query, @@ -2482,13 +2505,15 @@ async def kg_query( enable_cot=True, stream=query_param.stream, ) - + # Validate references in the response - valid_ref_ids = global_config.get('_valid_reference_ids', set()) + valid_ref_ids = global_config.get("_valid_reference_ids", set()) if valid_ref_ids: response, is_valid = validate_llm_references(response, valid_ref_ids) if not is_valid: - logger.warning("LLM response contained invalid references and has been cleaned") + logger.warning( + "LLM response contained invalid references and has been cleaned" + ) if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): queryparam_dict = { @@ -3136,12 +3161,12 @@ async def _merge_all_chunks( "file_path": chunk.get("file_path", "unknown_source"), "chunk_id": chunk_id, } - + # Preserve page metadata if available for field in ["start_page", "end_page", "pages"]: if chunk.get(field) is not None: metadata[field] = chunk.get(field) - + return metadata def _merge_chunks_round_robin(chunk_sources: list[list[dict]]) -> list[dict]: @@ -3150,18 +3175,20 @@ async def _merge_all_chunks( seen_ids = set() max_len = max(len(source) for source in chunk_sources) total_original = sum(len(source) for source in chunk_sources) - + for i in range(max_len): for source in chunk_sources: if i < len(source): chunk = source[i] chunk_id = chunk.get("chunk_id") or chunk.get("id") - + if chunk_id and chunk_id not in seen_ids: seen_ids.add(chunk_id) merged.append(_extract_chunk_metadata(chunk)) - - logger.info(f"Round-robin merged chunks: {total_original} -> {len(merged)} (deduplicated {total_original - len(merged)})") + + logger.info( + f"Round-robin merged chunks: {total_original} -> {len(merged)} (deduplicated {total_original - len(merged)})" + ) return merged return _merge_chunks_round_robin([vector_chunks, entity_chunks, relation_chunks]) @@ -3267,8 +3294,10 @@ async def _build_llm_context( if truncated_chunks: sample_chunk = truncated_chunks[0] has_pages = "pages" in sample_chunk - logger.info(f"Before reference gen: chunks have pages={has_pages}, keys={list(sample_chunk.keys())[:12]}") - + logger.info( + f"Before reference gen: chunks have pages={has_pages}, keys={list(sample_chunk.keys())[:12]}" + ) + reference_list, truncated_chunks = generate_reference_list_from_chunks( truncated_chunks ) @@ -3323,18 +3352,18 @@ async def _build_llm_context( text_units_str = "\n".join( json.dumps(text_unit, ensure_ascii=False) for text_unit in text_units_context ) - + # Format reference list with page numbers if available formatted_references = [] for ref in reference_list: if not ref.get("reference_id"): continue - - file_path = ref['file_path'] - ref_id = ref['reference_id'] - + + file_path = ref["file_path"] + ref_id = ref["reference_id"] + # Add page numbers if available - pages = ref.get('pages') + pages = ref.get("pages") if pages and len(pages) > 0: if len(pages) == 1: # Single page: "document.pdf (p. 5)" @@ -3345,19 +3374,21 @@ async def _build_llm_context( else: # No page info: "document.txt" citation = f"[{ref_id}] {file_path}" - + formatted_references.append(citation) - + reference_list_str = "\n".join(formatted_references) - + # Debug: Log what references are being sent to the LLM logger.info(f"Reference list for LLM ({len(formatted_references)} refs):") for ref_line in formatted_references[:3]: # Show first 3 logger.info(f" {ref_line}") - + # Store valid reference IDs for validation - valid_ref_ids = {ref['reference_id'] for ref in reference_list if ref.get('reference_id')} - global_config['_valid_reference_ids'] = valid_ref_ids + valid_ref_ids = { + ref["reference_id"] for ref in reference_list if ref.get("reference_id") + } + global_config["_valid_reference_ids"] = valid_ref_ids result = kg_context_template.format( entities_str=entities_str, @@ -3766,12 +3797,14 @@ async def _find_related_text_unit_from_entities( chunk_data_copy = chunk_data.copy() chunk_data_copy["source_type"] = "entity" chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication - + # Debug: Check if page metadata is present if i == 0: # Log first chunk only has_pages = "pages" in chunk_data_copy - logger.info(f"Entity chunk has pages field: {has_pages}, keys: {list(chunk_data_copy.keys())[:10]}") - + logger.info( + f"Entity chunk has pages field: {has_pages}, keys: {list(chunk_data_copy.keys())[:10]}" + ) + result_chunks.append(chunk_data_copy) # Update chunk tracking if provided @@ -4250,18 +4283,18 @@ async def naive_query( text_units_str = "\n".join( json.dumps(text_unit, ensure_ascii=False) for text_unit in text_units_context ) - + # Format reference list with page numbers if available formatted_references = [] for ref in reference_list: if not ref.get("reference_id"): continue - - file_path = ref['file_path'] - ref_id = ref['reference_id'] - + + file_path = ref["file_path"] + ref_id = ref["reference_id"] + # Add page numbers if available - pages = ref.get('pages') + pages = ref.get("pages") if pages and len(pages) > 0: if len(pages) == 1: # Single page: "document.pdf (p. 5)" @@ -4272,19 +4305,21 @@ async def naive_query( else: # No page info: "document.txt" citation = f"[{ref_id}] {file_path}" - + formatted_references.append(citation) - + reference_list_str = "\n".join(formatted_references) - + # Debug: Log what references are being sent to the LLM logger.info(f"Reference list for LLM ({len(formatted_references)} refs):") for ref_line in formatted_references[:3]: # Show first 3 logger.info(f" {ref_line}") - + # Store valid reference IDs for validation - valid_ref_ids = {ref['reference_id'] for ref in reference_list if ref.get('reference_id')} - global_config['_valid_reference_ids'] = valid_ref_ids + valid_ref_ids = { + ref["reference_id"] for ref in reference_list if ref.get("reference_id") + } + global_config["_valid_reference_ids"] = valid_ref_ids naive_context_template = PROMPTS["naive_query_context"] context_content = naive_context_template.format( @@ -4329,13 +4364,15 @@ async def naive_query( " == LLM cache == Query cache hit, using cached response as query result" ) response = cached_response - + # Validate references in cached response too - valid_ref_ids = global_config.get('_valid_reference_ids', set()) + valid_ref_ids = global_config.get("_valid_reference_ids", set()) if valid_ref_ids: response, is_valid = validate_llm_references(response, valid_ref_ids) if not is_valid: - logger.warning("Cached LLM response contained invalid references and has been cleaned") + logger.warning( + "Cached LLM response contained invalid references and has been cleaned" + ) else: response = await use_model_func( user_query, @@ -4344,13 +4381,15 @@ async def naive_query( enable_cot=True, stream=query_param.stream, ) - + # Validate references in the response - valid_ref_ids = global_config.get('_valid_reference_ids', set()) + valid_ref_ids = global_config.get("_valid_reference_ids", set()) if valid_ref_ids: response, is_valid = validate_llm_references(response, valid_ref_ids) if not is_valid: - logger.warning("LLM response contained invalid references and has been cleaned") + logger.warning( + "LLM response contained invalid references and has been cleaned" + ) if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): queryparam_dict = { diff --git a/lightrag/utils.py b/lightrag/utils.py index 610ba4aa..288c80a6 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -2852,56 +2852,74 @@ def convert_to_user_format( } -def generate_reference_list_from_chunks(chunks: list[dict]) -> tuple[list[dict], list[dict]]: +def generate_reference_list_from_chunks( + chunks: list[dict], +) -> tuple[list[dict], list[dict]]: """Generate reference list from chunks, showing exact chunk page ranges.""" if not chunks: return [], [] - def _create_chunk_references(chunks: list[dict]) -> tuple[list[dict], dict[str, str]]: + def _create_chunk_references( + chunks: list[dict], + ) -> tuple[list[dict], dict[str, str]]: """Create references based on actual chunk page ranges instead of file aggregation.""" chunk_ref_map = {} # Maps (file_path, page_range) -> reference_id references = [] ref_id_counter = 1 - + for chunk in chunks: file_path = chunk.get("file_path", "") if file_path == "unknown_source": continue - + # Get page data for this specific chunk chunk_pages = chunk.get("pages") if chunk_pages and isinstance(chunk_pages, list): # Create a unique key for this file + page range combination page_range_key = (file_path, tuple(sorted(chunk_pages))) - + if page_range_key not in chunk_ref_map: # Create new reference for this file + page range chunk_ref_map[page_range_key] = str(ref_id_counter) - + # Build page range display sorted_pages = sorted(chunk_pages) if len(sorted_pages) == 1: - page_display = {"pages": sorted_pages, "start_page": sorted_pages[0], "end_page": sorted_pages[0]} + page_display = { + "pages": sorted_pages, + "start_page": sorted_pages[0], + "end_page": sorted_pages[0], + } else: - page_display = {"pages": sorted_pages, "start_page": sorted_pages[0], "end_page": sorted_pages[-1]} - - references.append({ - "reference_id": str(ref_id_counter), - "file_path": file_path, - **page_display - }) - ref_id_counter += 1 - - return references, {f"{file_path}_{'-'.join(map(str, pages))}": ref_id - for (file_path, pages), ref_id in chunk_ref_map.items()} + page_display = { + "pages": sorted_pages, + "start_page": sorted_pages[0], + "end_page": sorted_pages[-1], + } - def _add_reference_ids_to_chunks(chunks: list[dict], chunk_ref_map: dict[str, str]) -> list[dict]: + references.append( + { + "reference_id": str(ref_id_counter), + "file_path": file_path, + **page_display, + } + ) + ref_id_counter += 1 + + return references, { + f"{file_path}_{'-'.join(map(str, pages))}": ref_id + for (file_path, pages), ref_id in chunk_ref_map.items() + } + + def _add_reference_ids_to_chunks( + chunks: list[dict], chunk_ref_map: dict[str, str] + ) -> list[dict]: """Add reference_id field to chunks based on their specific page ranges.""" updated = [] for chunk in chunks: chunk_copy = chunk.copy() file_path = chunk_copy.get("file_path", "") - + if file_path != "unknown_source": chunk_pages = chunk_copy.get("pages") if chunk_pages and isinstance(chunk_pages, list): @@ -2913,12 +2931,12 @@ def generate_reference_list_from_chunks(chunks: list[dict]) -> tuple[list[dict], chunk_copy["reference_id"] = "" else: chunk_copy["reference_id"] = "" - + updated.append(chunk_copy) return updated # Main execution flow reference_list, chunk_ref_map = _create_chunk_references(chunks) updated_chunks = _add_reference_ids_to_chunks(chunks, chunk_ref_map) - + return reference_list, updated_chunks