code formatting

This commit is contained in:
Saswat 2025-10-10 13:03:09 +05:30
parent 6872f085d1
commit 7864a75bda
5 changed files with 260 additions and 160 deletions

View file

@ -864,7 +864,7 @@ async def _extract_pdf_with_docling(file_path: Path) -> str:
if not pm.is_installed("docling"): # type: ignore
pm.install("docling")
from docling.document_converter import DocumentConverter # type: ignore
converter = DocumentConverter()
result = converter.convert(file_path)
return result.document.export_to_markdown()
@ -876,47 +876,51 @@ async def _extract_pdf_with_pypdf2(file_bytes: bytes) -> tuple[str, list[dict]]:
pm.install("pypdf2")
from PyPDF2 import PdfReader # type: ignore
from io import BytesIO
pdf_file = BytesIO(file_bytes)
reader = PdfReader(pdf_file)
content = ""
page_data = []
char_position = 0
for page_num, page in enumerate(reader.pages, start=1):
page_text = page.extract_text() + "\n"
page_start = char_position
page_end = char_position + len(page_text)
page_data.append({
"page_number": page_num,
"content": page_text,
"char_start": page_start,
"char_end": page_end,
})
page_data.append(
{
"page_number": page_num,
"content": page_text,
"char_start": page_start,
"char_end": page_end,
}
)
content += page_text
char_position = page_end
return content, page_data
async def _handle_file_processing_error(
    rag: LightRAG,
    filename: str,
    error_type: str,
    error_msg: str,
    file_size: int,
    track_id: str,
) -> None:
    """Record a failed file extraction and log it.

    Builds a single error record for ``filename``, enqueues it through
    ``rag.apipeline_enqueue_error_documents`` under ``track_id``, and then
    emits an error-level log line describing the failure.

    Args:
        rag: Instance whose ``apipeline_enqueue_error_documents`` coroutine
            receives the error record.
        filename: Name of the file that failed; stored as ``file_path``.
        error_type: Short failure category, appended to the
            ``[File Extraction]`` prefix.
        error_msg: Original error text; stored as ``original_error``.
        file_size: Value stored in the record's ``file_size`` field.
        track_id: Tracking identifier forwarded together with the record.
    """
    description = f"[File Extraction]{error_type}"
    failure_record = {
        "file_path": filename,
        "error_description": description,
        "original_error": error_msg,
        "file_size": file_size,
    }

    # The enqueue API takes a list of records; this helper always sends one.
    await rag.apipeline_enqueue_error_documents([failure_record], track_id)
    logger.error(f"[File Extraction]{error_type} for {filename}: {error_msg}")
@ -1100,7 +1104,12 @@ async def pipeline_enqueue_file(
content, page_data = await _extract_pdf_with_pypdf2(file)
except Exception as e:
await _handle_file_processing_error(
rag, file_path.name, "PDF processing error", str(e), file_size, track_id
rag,
file_path.name,
"PDF processing error",
str(e),
file_size,
track_id,
)
return False, track_id
@ -1280,16 +1289,27 @@ async def pipeline_enqueue_file(
try:
# Pass page_data if it was collected (only for PDFs with PyPDF2)
page_data_to_pass = [page_data] if page_data is not None and len(page_data) > 0 else None
page_data_to_pass = (
[page_data]
if page_data is not None and len(page_data) > 0
else None
)
# Debug logging
if page_data_to_pass:
logger.info(f"Passing page metadata for {file_path.name}: {len(page_data_to_pass[0])} pages")
logger.info(
f"Passing page metadata for {file_path.name}: {len(page_data_to_pass[0])} pages"
)
else:
logger.debug(f"No page metadata for {file_path.name} (non-PDF or extraction failed)")
logger.debug(
f"No page metadata for {file_path.name} (non-PDF or extraction failed)"
)
await rag.apipeline_enqueue_documents(
content, file_paths=file_path.name, track_id=track_id, page_data_list=page_data_to_pass
content,
file_paths=file_path.name,
track_id=track_id,
page_data_list=page_data_to_pass,
)
logger.info(

View file

@ -1784,7 +1784,9 @@ class PGKVStorage(BaseKVStorage):
"llm_cache_list": json.dumps(v.get("llm_cache_list", [])),
"start_page": v.get("start_page"), # Optional page fields
"end_page": v.get("end_page"),
"pages": json.dumps(v.get("pages")) if v.get("pages") is not None else None,
"pages": json.dumps(v.get("pages"))
if v.get("pages") is not None
else None,
"create_time": current_time,
"update_time": current_time,
}
@ -1797,7 +1799,9 @@ class PGKVStorage(BaseKVStorage):
"content": v["content"],
"doc_name": v.get("file_path", ""), # Map file_path to doc_name
"workspace": self.workspace,
"page_data": json.dumps(v.get("page_data")) if v.get("page_data") is not None else None,
"page_data": json.dumps(v.get("page_data"))
if v.get("page_data") is not None
else None,
}
await self.db.execute(upsert_sql, _data)
elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
@ -1955,7 +1959,9 @@ class PGVectorStorage(BaseVectorStorage):
"file_path": item["file_path"],
"start_page": item.get("start_page"), # Optional page fields
"end_page": item.get("end_page"),
"pages": json.dumps(item.get("pages")) if item.get("pages") is not None else None,
"pages": json.dumps(item.get("pages"))
if item.get("pages") is not None
else None,
"create_time": current_time,
"update_time": current_time,
}

View file

@ -550,7 +550,14 @@ class LightRAG:
namespace=NameSpace.VECTOR_STORE_CHUNKS,
workspace=self.workspace,
embedding_func=self.embedding_func,
meta_fields={"full_doc_id", "content", "file_path", "start_page", "end_page", "pages"},
meta_fields={
"full_doc_id",
"content",
"file_path",
"start_page",
"end_page",
"pages",
},
)
# Initialize document status storage
@ -1053,7 +1060,7 @@ class LightRAG:
else:
# If no file paths provided, use placeholder
file_paths = ["unknown_source"] * len(input)
# Handle page_data_list
if page_data_list is not None:
if len(page_data_list) != len(input):
@ -1076,14 +1083,20 @@ class LightRAG:
# Generate contents dict and remove duplicates in one pass
unique_contents = {}
for id_, doc, path, page_data in zip(ids, input, file_paths, page_data_list):
for id_, doc, path, page_data in zip(
ids, input, file_paths, page_data_list
):
cleaned_content = sanitize_text_for_encoding(doc)
if cleaned_content not in unique_contents:
unique_contents[cleaned_content] = (id_, path, page_data)
# Reconstruct contents with unique content
contents = {
id_: {"content": content, "file_path": file_path, "page_data": page_data}
id_: {
"content": content,
"file_path": file_path,
"page_data": page_data,
}
for content, (id_, file_path, page_data) in unique_contents.items()
}
else:
@ -1156,7 +1169,9 @@ class LightRAG:
doc_id: {
"content": contents[doc_id]["content"],
"file_path": contents[doc_id]["file_path"],
"page_data": contents[doc_id].get("page_data"), # Optional page metadata
"page_data": contents[doc_id].get(
"page_data"
), # Optional page metadata
}
for doc_id in new_docs.keys()
}
@ -1540,7 +1555,9 @@ class LightRAG:
f"Document content not found in full_docs for doc_id: {doc_id}"
)
content = content_data["content"]
page_data = content_data.get("page_data") # Optional page metadata
page_data = content_data.get(
"page_data"
) # Optional page metadata
# Generate chunks from document
chunks: dict[str, Any] = {

View file

@ -66,32 +66,34 @@ load_dotenv(dotenv_path=".env", override=False)
def validate_llm_references(response: str, valid_ref_ids: set[str]) -> tuple[str, bool]:
    """Strip citations whose reference ID is not in ``valid_ref_ids``.

    The response is scanned for numeric citations of the form ``[1]``,
    ``[2]``, ... Every citation whose number is not a member of
    ``valid_ref_ids`` is removed, together with a parenthesised suffix that
    directly follows it (for example ``[7] (p. 3)``).

    Args:
        response: The LLM response text. A non-string value is returned
            untouched and reported as valid, because it cannot be scanned.
            NOTE(review): callers pass the raw model result, which is
            presumably not a ``str`` when streaming is enabled - confirm.
        valid_ref_ids: Set of valid reference IDs (as strings) from the
            reference list.

    Returns:
        Tuple of ``(cleaned_response, is_valid)``. ``is_valid`` is ``False``
        only when at least one invalid citation was found and removed.
    """
    import re

    # Guard: re.findall() raises TypeError on anything that is not text.
    if not isinstance(response, str):
        return response, True

    # Find all reference patterns like [1], [2], etc.
    cited_ids = re.findall(r"\[(\d+)\]", response)

    # De-duplicate while keeping first-seen order so that each invalid ID is
    # logged and substituted exactly once.
    invalid_refs = list(
        dict.fromkeys(ref_id for ref_id in cited_ids if ref_id not in valid_ref_ids)
    )
    if not invalid_refs:
        return response, True

    logger.warning(
        f"LLM generated invalid references: {invalid_refs}. Valid refs: {sorted(valid_ref_ids)}"
    )

    # Remove invalid references from the response. IDs are pure digits
    # (captured by \d+), so interpolating them into the pattern is safe.
    for invalid_ref in invalid_refs:
        response = re.sub(rf"\[{invalid_ref}\](?:\s*\([^)]*\))?", "", response)
    return response, False
@ -105,99 +107,118 @@ def chunking_by_token_size(
page_data: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
"""Chunk content by token size with optional page tracking."""
def _calculate_page_range(start_char: int, end_char: int) -> tuple[int | None, int | None, list[int] | None]:
def _calculate_page_range(
start_char: int, end_char: int
) -> tuple[int | None, int | None, list[int] | None]:
if not page_data:
return None, None, None
pages = set()
start_page = end_page = None
for page in page_data:
page_num = page["page_number"]
page_start = page["char_start"]
page_end = page["char_end"]
if start_char < page_end and end_char > page_start:
pages.add(page_num)
start_page = min(start_page, page_num) if start_page else page_num
end_page = max(end_page, page_num) if end_page else page_num
return start_page, end_page, sorted(pages) if pages else None
def _estimate_char_positions(token_start: int, token_end: int, total_tokens: int, total_chars: int) -> tuple[int, int]:
def _estimate_char_positions(
token_start: int, token_end: int, total_tokens: int, total_chars: int
) -> tuple[int, int]:
if total_tokens == 0:
return 0, total_chars
start_char = int((token_start / total_tokens) * total_chars)
end_char = int((token_end / total_tokens) * total_chars)
return start_char, end_char
def _create_chunk_dict(token_count: int, content: str, index: int, start_char: int, end_char: int) -> dict[str, Any]:
def _create_chunk_dict(
token_count: int, content: str, index: int, start_char: int, end_char: int
) -> dict[str, Any]:
chunk = {
"tokens": token_count,
"content": content.strip(),
"chunk_order_index": index,
}
if page_data:
start_page, end_page, pages = _calculate_page_range(start_char, end_char)
chunk.update({
"start_page": start_page,
"end_page": end_page,
"pages": pages
})
chunk.update(
{"start_page": start_page, "end_page": end_page, "pages": pages}
)
return chunk
tokens = tokenizer.encode(content)
total_tokens = len(tokens)
total_chars = len(content)
results = []
if split_by_character:
raw_chunks = content.split(split_by_character)
chunks_with_positions = []
char_pos = 0
for chunk_text in raw_chunks:
chunk_tokens = tokenizer.encode(chunk_text)
chunk_start = char_pos
chunk_end = char_pos + len(chunk_text)
if split_by_character_only or len(chunk_tokens) <= max_token_size:
chunks_with_positions.append((len(chunk_tokens), chunk_text, chunk_start, chunk_end))
chunks_with_positions.append(
(len(chunk_tokens), chunk_text, chunk_start, chunk_end)
)
else:
# Split large chunks by tokens
for token_start in range(0, len(chunk_tokens), max_token_size - overlap_token_size):
for token_start in range(
0, len(chunk_tokens), max_token_size - overlap_token_size
):
token_end = min(token_start + max_token_size, len(chunk_tokens))
chunk_content = tokenizer.decode(chunk_tokens[token_start:token_end])
chunk_content = tokenizer.decode(
chunk_tokens[token_start:token_end]
)
# Estimate character positions within the chunk
ratio_start = token_start / len(chunk_tokens)
ratio_end = token_end / len(chunk_tokens)
sub_start = chunk_start + int(len(chunk_text) * ratio_start)
sub_end = chunk_start + int(len(chunk_text) * ratio_end)
chunks_with_positions.append((
token_end - token_start,
chunk_content,
sub_start,
sub_end
))
chunks_with_positions.append(
(token_end - token_start, chunk_content, sub_start, sub_end)
)
char_pos = chunk_end + len(split_by_character)
for index, (token_count, chunk_text, start_char, end_char) in enumerate(chunks_with_positions):
results.append(_create_chunk_dict(token_count, chunk_text, index, start_char, end_char))
for index, (token_count, chunk_text, start_char, end_char) in enumerate(
chunks_with_positions
):
results.append(
_create_chunk_dict(token_count, chunk_text, index, start_char, end_char)
)
else:
# Token-based chunking
for index, token_start in enumerate(range(0, total_tokens, max_token_size - overlap_token_size)):
for index, token_start in enumerate(
range(0, total_tokens, max_token_size - overlap_token_size)
):
token_end = min(token_start + max_token_size, total_tokens)
chunk_content = tokenizer.decode(tokens[token_start:token_end])
start_char, end_char = _estimate_char_positions(token_start, token_end, total_tokens, total_chars)
results.append(_create_chunk_dict(token_end - token_start, chunk_content, index, start_char, end_char))
start_char, end_char = _estimate_char_positions(
token_start, token_end, total_tokens, total_chars
)
results.append(
_create_chunk_dict(
token_end - token_start, chunk_content, index, start_char, end_char
)
)
return results
@ -2467,13 +2488,15 @@ async def kg_query(
" == LLM cache == Query cache hit, using cached response as query result"
)
response = cached_response
# Validate references in cached response too
valid_ref_ids = global_config.get('_valid_reference_ids', set())
valid_ref_ids = global_config.get("_valid_reference_ids", set())
if valid_ref_ids:
response, is_valid = validate_llm_references(response, valid_ref_ids)
if not is_valid:
logger.warning("Cached LLM response contained invalid references and has been cleaned")
logger.warning(
"Cached LLM response contained invalid references and has been cleaned"
)
else:
response = await use_model_func(
user_query,
@ -2482,13 +2505,15 @@ async def kg_query(
enable_cot=True,
stream=query_param.stream,
)
# Validate references in the response
valid_ref_ids = global_config.get('_valid_reference_ids', set())
valid_ref_ids = global_config.get("_valid_reference_ids", set())
if valid_ref_ids:
response, is_valid = validate_llm_references(response, valid_ref_ids)
if not is_valid:
logger.warning("LLM response contained invalid references and has been cleaned")
logger.warning(
"LLM response contained invalid references and has been cleaned"
)
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
queryparam_dict = {
@ -3136,12 +3161,12 @@ async def _merge_all_chunks(
"file_path": chunk.get("file_path", "unknown_source"),
"chunk_id": chunk_id,
}
# Preserve page metadata if available
for field in ["start_page", "end_page", "pages"]:
if chunk.get(field) is not None:
metadata[field] = chunk.get(field)
return metadata
def _merge_chunks_round_robin(chunk_sources: list[list[dict]]) -> list[dict]:
@ -3150,18 +3175,20 @@ async def _merge_all_chunks(
seen_ids = set()
max_len = max(len(source) for source in chunk_sources)
total_original = sum(len(source) for source in chunk_sources)
for i in range(max_len):
for source in chunk_sources:
if i < len(source):
chunk = source[i]
chunk_id = chunk.get("chunk_id") or chunk.get("id")
if chunk_id and chunk_id not in seen_ids:
seen_ids.add(chunk_id)
merged.append(_extract_chunk_metadata(chunk))
logger.info(f"Round-robin merged chunks: {total_original} -> {len(merged)} (deduplicated {total_original - len(merged)})")
logger.info(
f"Round-robin merged chunks: {total_original} -> {len(merged)} (deduplicated {total_original - len(merged)})"
)
return merged
return _merge_chunks_round_robin([vector_chunks, entity_chunks, relation_chunks])
@ -3267,8 +3294,10 @@ async def _build_llm_context(
if truncated_chunks:
sample_chunk = truncated_chunks[0]
has_pages = "pages" in sample_chunk
logger.info(f"Before reference gen: chunks have pages={has_pages}, keys={list(sample_chunk.keys())[:12]}")
logger.info(
f"Before reference gen: chunks have pages={has_pages}, keys={list(sample_chunk.keys())[:12]}"
)
reference_list, truncated_chunks = generate_reference_list_from_chunks(
truncated_chunks
)
@ -3323,18 +3352,18 @@ async def _build_llm_context(
text_units_str = "\n".join(
json.dumps(text_unit, ensure_ascii=False) for text_unit in text_units_context
)
# Format reference list with page numbers if available
formatted_references = []
for ref in reference_list:
if not ref.get("reference_id"):
continue
file_path = ref['file_path']
ref_id = ref['reference_id']
file_path = ref["file_path"]
ref_id = ref["reference_id"]
# Add page numbers if available
pages = ref.get('pages')
pages = ref.get("pages")
if pages and len(pages) > 0:
if len(pages) == 1:
# Single page: "document.pdf (p. 5)"
@ -3345,19 +3374,21 @@ async def _build_llm_context(
else:
# No page info: "document.txt"
citation = f"[{ref_id}] {file_path}"
formatted_references.append(citation)
reference_list_str = "\n".join(formatted_references)
# Debug: Log what references are being sent to the LLM
logger.info(f"Reference list for LLM ({len(formatted_references)} refs):")
for ref_line in formatted_references[:3]: # Show first 3
logger.info(f" {ref_line}")
# Store valid reference IDs for validation
valid_ref_ids = {ref['reference_id'] for ref in reference_list if ref.get('reference_id')}
global_config['_valid_reference_ids'] = valid_ref_ids
valid_ref_ids = {
ref["reference_id"] for ref in reference_list if ref.get("reference_id")
}
global_config["_valid_reference_ids"] = valid_ref_ids
result = kg_context_template.format(
entities_str=entities_str,
@ -3766,12 +3797,14 @@ async def _find_related_text_unit_from_entities(
chunk_data_copy = chunk_data.copy()
chunk_data_copy["source_type"] = "entity"
chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication
# Debug: Check if page metadata is present
if i == 0: # Log first chunk only
has_pages = "pages" in chunk_data_copy
logger.info(f"Entity chunk has pages field: {has_pages}, keys: {list(chunk_data_copy.keys())[:10]}")
logger.info(
f"Entity chunk has pages field: {has_pages}, keys: {list(chunk_data_copy.keys())[:10]}"
)
result_chunks.append(chunk_data_copy)
# Update chunk tracking if provided
@ -4250,18 +4283,18 @@ async def naive_query(
text_units_str = "\n".join(
json.dumps(text_unit, ensure_ascii=False) for text_unit in text_units_context
)
# Format reference list with page numbers if available
formatted_references = []
for ref in reference_list:
if not ref.get("reference_id"):
continue
file_path = ref['file_path']
ref_id = ref['reference_id']
file_path = ref["file_path"]
ref_id = ref["reference_id"]
# Add page numbers if available
pages = ref.get('pages')
pages = ref.get("pages")
if pages and len(pages) > 0:
if len(pages) == 1:
# Single page: "document.pdf (p. 5)"
@ -4272,19 +4305,21 @@ async def naive_query(
else:
# No page info: "document.txt"
citation = f"[{ref_id}] {file_path}"
formatted_references.append(citation)
reference_list_str = "\n".join(formatted_references)
# Debug: Log what references are being sent to the LLM
logger.info(f"Reference list for LLM ({len(formatted_references)} refs):")
for ref_line in formatted_references[:3]: # Show first 3
logger.info(f" {ref_line}")
# Store valid reference IDs for validation
valid_ref_ids = {ref['reference_id'] for ref in reference_list if ref.get('reference_id')}
global_config['_valid_reference_ids'] = valid_ref_ids
valid_ref_ids = {
ref["reference_id"] for ref in reference_list if ref.get("reference_id")
}
global_config["_valid_reference_ids"] = valid_ref_ids
naive_context_template = PROMPTS["naive_query_context"]
context_content = naive_context_template.format(
@ -4329,13 +4364,15 @@ async def naive_query(
" == LLM cache == Query cache hit, using cached response as query result"
)
response = cached_response
# Validate references in cached response too
valid_ref_ids = global_config.get('_valid_reference_ids', set())
valid_ref_ids = global_config.get("_valid_reference_ids", set())
if valid_ref_ids:
response, is_valid = validate_llm_references(response, valid_ref_ids)
if not is_valid:
logger.warning("Cached LLM response contained invalid references and has been cleaned")
logger.warning(
"Cached LLM response contained invalid references and has been cleaned"
)
else:
response = await use_model_func(
user_query,
@ -4344,13 +4381,15 @@ async def naive_query(
enable_cot=True,
stream=query_param.stream,
)
# Validate references in the response
valid_ref_ids = global_config.get('_valid_reference_ids', set())
valid_ref_ids = global_config.get("_valid_reference_ids", set())
if valid_ref_ids:
response, is_valid = validate_llm_references(response, valid_ref_ids)
if not is_valid:
logger.warning("LLM response contained invalid references and has been cleaned")
logger.warning(
"LLM response contained invalid references and has been cleaned"
)
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
queryparam_dict = {

View file

@ -2852,56 +2852,74 @@ def convert_to_user_format(
}
def generate_reference_list_from_chunks(chunks: list[dict]) -> tuple[list[dict], list[dict]]:
def generate_reference_list_from_chunks(
chunks: list[dict],
) -> tuple[list[dict], list[dict]]:
"""Generate reference list from chunks, showing exact chunk page ranges."""
if not chunks:
return [], []
def _create_chunk_references(chunks: list[dict]) -> tuple[list[dict], dict[str, str]]:
def _create_chunk_references(
chunks: list[dict],
) -> tuple[list[dict], dict[str, str]]:
"""Create references based on actual chunk page ranges instead of file aggregation."""
chunk_ref_map = {} # Maps (file_path, page_range) -> reference_id
references = []
ref_id_counter = 1
for chunk in chunks:
file_path = chunk.get("file_path", "")
if file_path == "unknown_source":
continue
# Get page data for this specific chunk
chunk_pages = chunk.get("pages")
if chunk_pages and isinstance(chunk_pages, list):
# Create a unique key for this file + page range combination
page_range_key = (file_path, tuple(sorted(chunk_pages)))
if page_range_key not in chunk_ref_map:
# Create new reference for this file + page range
chunk_ref_map[page_range_key] = str(ref_id_counter)
# Build page range display
sorted_pages = sorted(chunk_pages)
if len(sorted_pages) == 1:
page_display = {"pages": sorted_pages, "start_page": sorted_pages[0], "end_page": sorted_pages[0]}
page_display = {
"pages": sorted_pages,
"start_page": sorted_pages[0],
"end_page": sorted_pages[0],
}
else:
page_display = {"pages": sorted_pages, "start_page": sorted_pages[0], "end_page": sorted_pages[-1]}
references.append({
"reference_id": str(ref_id_counter),
"file_path": file_path,
**page_display
})
ref_id_counter += 1
return references, {f"{file_path}_{'-'.join(map(str, pages))}": ref_id
for (file_path, pages), ref_id in chunk_ref_map.items()}
page_display = {
"pages": sorted_pages,
"start_page": sorted_pages[0],
"end_page": sorted_pages[-1],
}
def _add_reference_ids_to_chunks(chunks: list[dict], chunk_ref_map: dict[str, str]) -> list[dict]:
references.append(
{
"reference_id": str(ref_id_counter),
"file_path": file_path,
**page_display,
}
)
ref_id_counter += 1
return references, {
f"{file_path}_{'-'.join(map(str, pages))}": ref_id
for (file_path, pages), ref_id in chunk_ref_map.items()
}
def _add_reference_ids_to_chunks(
chunks: list[dict], chunk_ref_map: dict[str, str]
) -> list[dict]:
"""Add reference_id field to chunks based on their specific page ranges."""
updated = []
for chunk in chunks:
chunk_copy = chunk.copy()
file_path = chunk_copy.get("file_path", "")
if file_path != "unknown_source":
chunk_pages = chunk_copy.get("pages")
if chunk_pages and isinstance(chunk_pages, list):
@ -2913,12 +2931,12 @@ def generate_reference_list_from_chunks(chunks: list[dict]) -> tuple[list[dict],
chunk_copy["reference_id"] = ""
else:
chunk_copy["reference_id"] = ""
updated.append(chunk_copy)
return updated
# Main execution flow
reference_list, chunk_ref_map = _create_chunk_references(chunks)
updated_chunks = _add_reference_ids_to_chunks(chunks, chunk_ref_map)
return reference_list, updated_chunks