code formatting
This commit is contained in:
parent
6872f085d1
commit
7864a75bda
5 changed files with 260 additions and 160 deletions
|
|
@ -864,7 +864,7 @@ async def _extract_pdf_with_docling(file_path: Path) -> str:
|
|||
if not pm.is_installed("docling"): # type: ignore
|
||||
pm.install("docling")
|
||||
from docling.document_converter import DocumentConverter # type: ignore
|
||||
|
||||
|
||||
converter = DocumentConverter()
|
||||
result = converter.convert(file_path)
|
||||
return result.document.export_to_markdown()
|
||||
|
|
@ -876,47 +876,51 @@ async def _extract_pdf_with_pypdf2(file_bytes: bytes) -> tuple[str, list[dict]]:
|
|||
pm.install("pypdf2")
|
||||
from PyPDF2 import PdfReader # type: ignore
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
pdf_file = BytesIO(file_bytes)
|
||||
reader = PdfReader(pdf_file)
|
||||
content = ""
|
||||
page_data = []
|
||||
char_position = 0
|
||||
|
||||
|
||||
for page_num, page in enumerate(reader.pages, start=1):
|
||||
page_text = page.extract_text() + "\n"
|
||||
page_start = char_position
|
||||
page_end = char_position + len(page_text)
|
||||
|
||||
page_data.append({
|
||||
"page_number": page_num,
|
||||
"content": page_text,
|
||||
"char_start": page_start,
|
||||
"char_end": page_end,
|
||||
})
|
||||
|
||||
|
||||
page_data.append(
|
||||
{
|
||||
"page_number": page_num,
|
||||
"content": page_text,
|
||||
"char_start": page_start,
|
||||
"char_end": page_end,
|
||||
}
|
||||
)
|
||||
|
||||
content += page_text
|
||||
char_position = page_end
|
||||
|
||||
|
||||
return content, page_data
|
||||
|
||||
|
||||
async def _handle_file_processing_error(
|
||||
rag: LightRAG,
|
||||
filename: str,
|
||||
error_type: str,
|
||||
error_msg: str,
|
||||
file_size: int,
|
||||
track_id: str
|
||||
rag: LightRAG,
|
||||
filename: str,
|
||||
error_type: str,
|
||||
error_msg: str,
|
||||
file_size: int,
|
||||
track_id: str,
|
||||
) -> None:
|
||||
"""Handle file processing errors consistently."""
|
||||
error_files = [{
|
||||
"file_path": filename,
|
||||
"error_description": f"[File Extraction]{error_type}",
|
||||
"original_error": error_msg,
|
||||
"file_size": file_size,
|
||||
}]
|
||||
|
||||
error_files = [
|
||||
{
|
||||
"file_path": filename,
|
||||
"error_description": f"[File Extraction]{error_type}",
|
||||
"original_error": error_msg,
|
||||
"file_size": file_size,
|
||||
}
|
||||
]
|
||||
|
||||
await rag.apipeline_enqueue_error_documents(error_files, track_id)
|
||||
logger.error(f"[File Extraction]{error_type} for {filename}: {error_msg}")
|
||||
|
||||
|
|
@ -1100,7 +1104,12 @@ async def pipeline_enqueue_file(
|
|||
content, page_data = await _extract_pdf_with_pypdf2(file)
|
||||
except Exception as e:
|
||||
await _handle_file_processing_error(
|
||||
rag, file_path.name, "PDF processing error", str(e), file_size, track_id
|
||||
rag,
|
||||
file_path.name,
|
||||
"PDF processing error",
|
||||
str(e),
|
||||
file_size,
|
||||
track_id,
|
||||
)
|
||||
return False, track_id
|
||||
|
||||
|
|
@ -1280,16 +1289,27 @@ async def pipeline_enqueue_file(
|
|||
|
||||
try:
|
||||
# Pass page_data if it was collected (only for PDFs with PyPDF2)
|
||||
page_data_to_pass = [page_data] if page_data is not None and len(page_data) > 0 else None
|
||||
|
||||
page_data_to_pass = (
|
||||
[page_data]
|
||||
if page_data is not None and len(page_data) > 0
|
||||
else None
|
||||
)
|
||||
|
||||
# Debug logging
|
||||
if page_data_to_pass:
|
||||
logger.info(f"Passing page metadata for {file_path.name}: {len(page_data_to_pass[0])} pages")
|
||||
logger.info(
|
||||
f"Passing page metadata for {file_path.name}: {len(page_data_to_pass[0])} pages"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"No page metadata for {file_path.name} (non-PDF or extraction failed)")
|
||||
|
||||
logger.debug(
|
||||
f"No page metadata for {file_path.name} (non-PDF or extraction failed)"
|
||||
)
|
||||
|
||||
await rag.apipeline_enqueue_documents(
|
||||
content, file_paths=file_path.name, track_id=track_id, page_data_list=page_data_to_pass
|
||||
content,
|
||||
file_paths=file_path.name,
|
||||
track_id=track_id,
|
||||
page_data_list=page_data_to_pass,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
|
|
|||
|
|
@ -1784,7 +1784,9 @@ class PGKVStorage(BaseKVStorage):
|
|||
"llm_cache_list": json.dumps(v.get("llm_cache_list", [])),
|
||||
"start_page": v.get("start_page"), # Optional page fields
|
||||
"end_page": v.get("end_page"),
|
||||
"pages": json.dumps(v.get("pages")) if v.get("pages") is not None else None,
|
||||
"pages": json.dumps(v.get("pages"))
|
||||
if v.get("pages") is not None
|
||||
else None,
|
||||
"create_time": current_time,
|
||||
"update_time": current_time,
|
||||
}
|
||||
|
|
@ -1797,7 +1799,9 @@ class PGKVStorage(BaseKVStorage):
|
|||
"content": v["content"],
|
||||
"doc_name": v.get("file_path", ""), # Map file_path to doc_name
|
||||
"workspace": self.workspace,
|
||||
"page_data": json.dumps(v.get("page_data")) if v.get("page_data") is not None else None,
|
||||
"page_data": json.dumps(v.get("page_data"))
|
||||
if v.get("page_data") is not None
|
||||
else None,
|
||||
}
|
||||
await self.db.execute(upsert_sql, _data)
|
||||
elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
||||
|
|
@ -1955,7 +1959,9 @@ class PGVectorStorage(BaseVectorStorage):
|
|||
"file_path": item["file_path"],
|
||||
"start_page": item.get("start_page"), # Optional page fields
|
||||
"end_page": item.get("end_page"),
|
||||
"pages": json.dumps(item.get("pages")) if item.get("pages") is not None else None,
|
||||
"pages": json.dumps(item.get("pages"))
|
||||
if item.get("pages") is not None
|
||||
else None,
|
||||
"create_time": current_time,
|
||||
"update_time": current_time,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -550,7 +550,14 @@ class LightRAG:
|
|||
namespace=NameSpace.VECTOR_STORE_CHUNKS,
|
||||
workspace=self.workspace,
|
||||
embedding_func=self.embedding_func,
|
||||
meta_fields={"full_doc_id", "content", "file_path", "start_page", "end_page", "pages"},
|
||||
meta_fields={
|
||||
"full_doc_id",
|
||||
"content",
|
||||
"file_path",
|
||||
"start_page",
|
||||
"end_page",
|
||||
"pages",
|
||||
},
|
||||
)
|
||||
|
||||
# Initialize document status storage
|
||||
|
|
@ -1053,7 +1060,7 @@ class LightRAG:
|
|||
else:
|
||||
# If no file paths provided, use placeholder
|
||||
file_paths = ["unknown_source"] * len(input)
|
||||
|
||||
|
||||
# Handle page_data_list
|
||||
if page_data_list is not None:
|
||||
if len(page_data_list) != len(input):
|
||||
|
|
@ -1076,14 +1083,20 @@ class LightRAG:
|
|||
|
||||
# Generate contents dict and remove duplicates in one pass
|
||||
unique_contents = {}
|
||||
for id_, doc, path, page_data in zip(ids, input, file_paths, page_data_list):
|
||||
for id_, doc, path, page_data in zip(
|
||||
ids, input, file_paths, page_data_list
|
||||
):
|
||||
cleaned_content = sanitize_text_for_encoding(doc)
|
||||
if cleaned_content not in unique_contents:
|
||||
unique_contents[cleaned_content] = (id_, path, page_data)
|
||||
|
||||
# Reconstruct contents with unique content
|
||||
contents = {
|
||||
id_: {"content": content, "file_path": file_path, "page_data": page_data}
|
||||
id_: {
|
||||
"content": content,
|
||||
"file_path": file_path,
|
||||
"page_data": page_data,
|
||||
}
|
||||
for content, (id_, file_path, page_data) in unique_contents.items()
|
||||
}
|
||||
else:
|
||||
|
|
@ -1156,7 +1169,9 @@ class LightRAG:
|
|||
doc_id: {
|
||||
"content": contents[doc_id]["content"],
|
||||
"file_path": contents[doc_id]["file_path"],
|
||||
"page_data": contents[doc_id].get("page_data"), # Optional page metadata
|
||||
"page_data": contents[doc_id].get(
|
||||
"page_data"
|
||||
), # Optional page metadata
|
||||
}
|
||||
for doc_id in new_docs.keys()
|
||||
}
|
||||
|
|
@ -1540,7 +1555,9 @@ class LightRAG:
|
|||
f"Document content not found in full_docs for doc_id: {doc_id}"
|
||||
)
|
||||
content = content_data["content"]
|
||||
page_data = content_data.get("page_data") # Optional page metadata
|
||||
page_data = content_data.get(
|
||||
"page_data"
|
||||
) # Optional page metadata
|
||||
|
||||
# Generate chunks from document
|
||||
chunks: dict[str, Any] = {
|
||||
|
|
|
|||
|
|
@ -66,32 +66,34 @@ load_dotenv(dotenv_path=".env", override=False)
|
|||
def validate_llm_references(response: str, valid_ref_ids: set[str]) -> tuple[str, bool]:
|
||||
"""
|
||||
Validate that LLM response only uses valid reference IDs.
|
||||
|
||||
|
||||
Args:
|
||||
response: The LLM response text
|
||||
valid_ref_ids: Set of valid reference IDs from the reference list
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (cleaned_response, is_valid)
|
||||
"""
|
||||
import re
|
||||
|
||||
|
||||
# Find all reference patterns like [1], [2], etc.
|
||||
ref_pattern = r'\[(\d+)\]'
|
||||
ref_pattern = r"\[(\d+)\]"
|
||||
matches = re.findall(ref_pattern, response)
|
||||
|
||||
|
||||
invalid_refs = []
|
||||
for ref_id in matches:
|
||||
if ref_id not in valid_ref_ids:
|
||||
invalid_refs.append(ref_id)
|
||||
|
||||
|
||||
if invalid_refs:
|
||||
logger.warning(f"LLM generated invalid references: {invalid_refs}. Valid refs: {sorted(valid_ref_ids)}")
|
||||
logger.warning(
|
||||
f"LLM generated invalid references: {invalid_refs}. Valid refs: {sorted(valid_ref_ids)}"
|
||||
)
|
||||
# Remove invalid references from the response
|
||||
for invalid_ref in invalid_refs:
|
||||
response = re.sub(rf'\[{invalid_ref}\](?:\s*\([^)]*\))?', '', response)
|
||||
response = re.sub(rf"\[{invalid_ref}\](?:\s*\([^)]*\))?", "", response)
|
||||
return response, False
|
||||
|
||||
|
||||
return response, True
|
||||
|
||||
|
||||
|
|
@ -105,99 +107,118 @@ def chunking_by_token_size(
|
|||
page_data: list[dict[str, Any]] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Chunk content by token size with optional page tracking."""
|
||||
|
||||
def _calculate_page_range(start_char: int, end_char: int) -> tuple[int | None, int | None, list[int] | None]:
|
||||
|
||||
def _calculate_page_range(
|
||||
start_char: int, end_char: int
|
||||
) -> tuple[int | None, int | None, list[int] | None]:
|
||||
if not page_data:
|
||||
return None, None, None
|
||||
|
||||
|
||||
pages = set()
|
||||
start_page = end_page = None
|
||||
|
||||
|
||||
for page in page_data:
|
||||
page_num = page["page_number"]
|
||||
page_start = page["char_start"]
|
||||
page_end = page["char_end"]
|
||||
|
||||
|
||||
if start_char < page_end and end_char > page_start:
|
||||
pages.add(page_num)
|
||||
start_page = min(start_page, page_num) if start_page else page_num
|
||||
end_page = max(end_page, page_num) if end_page else page_num
|
||||
|
||||
|
||||
return start_page, end_page, sorted(pages) if pages else None
|
||||
|
||||
def _estimate_char_positions(token_start: int, token_end: int, total_tokens: int, total_chars: int) -> tuple[int, int]:
|
||||
|
||||
def _estimate_char_positions(
|
||||
token_start: int, token_end: int, total_tokens: int, total_chars: int
|
||||
) -> tuple[int, int]:
|
||||
if total_tokens == 0:
|
||||
return 0, total_chars
|
||||
start_char = int((token_start / total_tokens) * total_chars)
|
||||
end_char = int((token_end / total_tokens) * total_chars)
|
||||
return start_char, end_char
|
||||
|
||||
def _create_chunk_dict(token_count: int, content: str, index: int, start_char: int, end_char: int) -> dict[str, Any]:
|
||||
|
||||
def _create_chunk_dict(
|
||||
token_count: int, content: str, index: int, start_char: int, end_char: int
|
||||
) -> dict[str, Any]:
|
||||
chunk = {
|
||||
"tokens": token_count,
|
||||
"content": content.strip(),
|
||||
"chunk_order_index": index,
|
||||
}
|
||||
|
||||
|
||||
if page_data:
|
||||
start_page, end_page, pages = _calculate_page_range(start_char, end_char)
|
||||
chunk.update({
|
||||
"start_page": start_page,
|
||||
"end_page": end_page,
|
||||
"pages": pages
|
||||
})
|
||||
|
||||
chunk.update(
|
||||
{"start_page": start_page, "end_page": end_page, "pages": pages}
|
||||
)
|
||||
|
||||
return chunk
|
||||
|
||||
|
||||
tokens = tokenizer.encode(content)
|
||||
total_tokens = len(tokens)
|
||||
total_chars = len(content)
|
||||
results = []
|
||||
|
||||
|
||||
if split_by_character:
|
||||
raw_chunks = content.split(split_by_character)
|
||||
chunks_with_positions = []
|
||||
char_pos = 0
|
||||
|
||||
|
||||
for chunk_text in raw_chunks:
|
||||
chunk_tokens = tokenizer.encode(chunk_text)
|
||||
chunk_start = char_pos
|
||||
chunk_end = char_pos + len(chunk_text)
|
||||
|
||||
|
||||
if split_by_character_only or len(chunk_tokens) <= max_token_size:
|
||||
chunks_with_positions.append((len(chunk_tokens), chunk_text, chunk_start, chunk_end))
|
||||
chunks_with_positions.append(
|
||||
(len(chunk_tokens), chunk_text, chunk_start, chunk_end)
|
||||
)
|
||||
else:
|
||||
# Split large chunks by tokens
|
||||
for token_start in range(0, len(chunk_tokens), max_token_size - overlap_token_size):
|
||||
for token_start in range(
|
||||
0, len(chunk_tokens), max_token_size - overlap_token_size
|
||||
):
|
||||
token_end = min(token_start + max_token_size, len(chunk_tokens))
|
||||
chunk_content = tokenizer.decode(chunk_tokens[token_start:token_end])
|
||||
|
||||
chunk_content = tokenizer.decode(
|
||||
chunk_tokens[token_start:token_end]
|
||||
)
|
||||
|
||||
# Estimate character positions within the chunk
|
||||
ratio_start = token_start / len(chunk_tokens)
|
||||
ratio_end = token_end / len(chunk_tokens)
|
||||
sub_start = chunk_start + int(len(chunk_text) * ratio_start)
|
||||
sub_end = chunk_start + int(len(chunk_text) * ratio_end)
|
||||
|
||||
chunks_with_positions.append((
|
||||
token_end - token_start,
|
||||
chunk_content,
|
||||
sub_start,
|
||||
sub_end
|
||||
))
|
||||
|
||||
|
||||
chunks_with_positions.append(
|
||||
(token_end - token_start, chunk_content, sub_start, sub_end)
|
||||
)
|
||||
|
||||
char_pos = chunk_end + len(split_by_character)
|
||||
|
||||
for index, (token_count, chunk_text, start_char, end_char) in enumerate(chunks_with_positions):
|
||||
results.append(_create_chunk_dict(token_count, chunk_text, index, start_char, end_char))
|
||||
|
||||
for index, (token_count, chunk_text, start_char, end_char) in enumerate(
|
||||
chunks_with_positions
|
||||
):
|
||||
results.append(
|
||||
_create_chunk_dict(token_count, chunk_text, index, start_char, end_char)
|
||||
)
|
||||
else:
|
||||
# Token-based chunking
|
||||
for index, token_start in enumerate(range(0, total_tokens, max_token_size - overlap_token_size)):
|
||||
for index, token_start in enumerate(
|
||||
range(0, total_tokens, max_token_size - overlap_token_size)
|
||||
):
|
||||
token_end = min(token_start + max_token_size, total_tokens)
|
||||
chunk_content = tokenizer.decode(tokens[token_start:token_end])
|
||||
start_char, end_char = _estimate_char_positions(token_start, token_end, total_tokens, total_chars)
|
||||
|
||||
results.append(_create_chunk_dict(token_end - token_start, chunk_content, index, start_char, end_char))
|
||||
|
||||
start_char, end_char = _estimate_char_positions(
|
||||
token_start, token_end, total_tokens, total_chars
|
||||
)
|
||||
|
||||
results.append(
|
||||
_create_chunk_dict(
|
||||
token_end - token_start, chunk_content, index, start_char, end_char
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
|
|
@ -2467,13 +2488,15 @@ async def kg_query(
|
|||
" == LLM cache == Query cache hit, using cached response as query result"
|
||||
)
|
||||
response = cached_response
|
||||
|
||||
|
||||
# Validate references in cached response too
|
||||
valid_ref_ids = global_config.get('_valid_reference_ids', set())
|
||||
valid_ref_ids = global_config.get("_valid_reference_ids", set())
|
||||
if valid_ref_ids:
|
||||
response, is_valid = validate_llm_references(response, valid_ref_ids)
|
||||
if not is_valid:
|
||||
logger.warning("Cached LLM response contained invalid references and has been cleaned")
|
||||
logger.warning(
|
||||
"Cached LLM response contained invalid references and has been cleaned"
|
||||
)
|
||||
else:
|
||||
response = await use_model_func(
|
||||
user_query,
|
||||
|
|
@ -2482,13 +2505,15 @@ async def kg_query(
|
|||
enable_cot=True,
|
||||
stream=query_param.stream,
|
||||
)
|
||||
|
||||
|
||||
# Validate references in the response
|
||||
valid_ref_ids = global_config.get('_valid_reference_ids', set())
|
||||
valid_ref_ids = global_config.get("_valid_reference_ids", set())
|
||||
if valid_ref_ids:
|
||||
response, is_valid = validate_llm_references(response, valid_ref_ids)
|
||||
if not is_valid:
|
||||
logger.warning("LLM response contained invalid references and has been cleaned")
|
||||
logger.warning(
|
||||
"LLM response contained invalid references and has been cleaned"
|
||||
)
|
||||
|
||||
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
|
||||
queryparam_dict = {
|
||||
|
|
@ -3136,12 +3161,12 @@ async def _merge_all_chunks(
|
|||
"file_path": chunk.get("file_path", "unknown_source"),
|
||||
"chunk_id": chunk_id,
|
||||
}
|
||||
|
||||
|
||||
# Preserve page metadata if available
|
||||
for field in ["start_page", "end_page", "pages"]:
|
||||
if chunk.get(field) is not None:
|
||||
metadata[field] = chunk.get(field)
|
||||
|
||||
|
||||
return metadata
|
||||
|
||||
def _merge_chunks_round_robin(chunk_sources: list[list[dict]]) -> list[dict]:
|
||||
|
|
@ -3150,18 +3175,20 @@ async def _merge_all_chunks(
|
|||
seen_ids = set()
|
||||
max_len = max(len(source) for source in chunk_sources)
|
||||
total_original = sum(len(source) for source in chunk_sources)
|
||||
|
||||
|
||||
for i in range(max_len):
|
||||
for source in chunk_sources:
|
||||
if i < len(source):
|
||||
chunk = source[i]
|
||||
chunk_id = chunk.get("chunk_id") or chunk.get("id")
|
||||
|
||||
|
||||
if chunk_id and chunk_id not in seen_ids:
|
||||
seen_ids.add(chunk_id)
|
||||
merged.append(_extract_chunk_metadata(chunk))
|
||||
|
||||
logger.info(f"Round-robin merged chunks: {total_original} -> {len(merged)} (deduplicated {total_original - len(merged)})")
|
||||
|
||||
logger.info(
|
||||
f"Round-robin merged chunks: {total_original} -> {len(merged)} (deduplicated {total_original - len(merged)})"
|
||||
)
|
||||
return merged
|
||||
|
||||
return _merge_chunks_round_robin([vector_chunks, entity_chunks, relation_chunks])
|
||||
|
|
@ -3267,8 +3294,10 @@ async def _build_llm_context(
|
|||
if truncated_chunks:
|
||||
sample_chunk = truncated_chunks[0]
|
||||
has_pages = "pages" in sample_chunk
|
||||
logger.info(f"Before reference gen: chunks have pages={has_pages}, keys={list(sample_chunk.keys())[:12]}")
|
||||
|
||||
logger.info(
|
||||
f"Before reference gen: chunks have pages={has_pages}, keys={list(sample_chunk.keys())[:12]}"
|
||||
)
|
||||
|
||||
reference_list, truncated_chunks = generate_reference_list_from_chunks(
|
||||
truncated_chunks
|
||||
)
|
||||
|
|
@ -3323,18 +3352,18 @@ async def _build_llm_context(
|
|||
text_units_str = "\n".join(
|
||||
json.dumps(text_unit, ensure_ascii=False) for text_unit in text_units_context
|
||||
)
|
||||
|
||||
|
||||
# Format reference list with page numbers if available
|
||||
formatted_references = []
|
||||
for ref in reference_list:
|
||||
if not ref.get("reference_id"):
|
||||
continue
|
||||
|
||||
file_path = ref['file_path']
|
||||
ref_id = ref['reference_id']
|
||||
|
||||
|
||||
file_path = ref["file_path"]
|
||||
ref_id = ref["reference_id"]
|
||||
|
||||
# Add page numbers if available
|
||||
pages = ref.get('pages')
|
||||
pages = ref.get("pages")
|
||||
if pages and len(pages) > 0:
|
||||
if len(pages) == 1:
|
||||
# Single page: "document.pdf (p. 5)"
|
||||
|
|
@ -3345,19 +3374,21 @@ async def _build_llm_context(
|
|||
else:
|
||||
# No page info: "document.txt"
|
||||
citation = f"[{ref_id}] {file_path}"
|
||||
|
||||
|
||||
formatted_references.append(citation)
|
||||
|
||||
|
||||
reference_list_str = "\n".join(formatted_references)
|
||||
|
||||
|
||||
# Debug: Log what references are being sent to the LLM
|
||||
logger.info(f"Reference list for LLM ({len(formatted_references)} refs):")
|
||||
for ref_line in formatted_references[:3]: # Show first 3
|
||||
logger.info(f" {ref_line}")
|
||||
|
||||
|
||||
# Store valid reference IDs for validation
|
||||
valid_ref_ids = {ref['reference_id'] for ref in reference_list if ref.get('reference_id')}
|
||||
global_config['_valid_reference_ids'] = valid_ref_ids
|
||||
valid_ref_ids = {
|
||||
ref["reference_id"] for ref in reference_list if ref.get("reference_id")
|
||||
}
|
||||
global_config["_valid_reference_ids"] = valid_ref_ids
|
||||
|
||||
result = kg_context_template.format(
|
||||
entities_str=entities_str,
|
||||
|
|
@ -3766,12 +3797,14 @@ async def _find_related_text_unit_from_entities(
|
|||
chunk_data_copy = chunk_data.copy()
|
||||
chunk_data_copy["source_type"] = "entity"
|
||||
chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication
|
||||
|
||||
|
||||
# Debug: Check if page metadata is present
|
||||
if i == 0: # Log first chunk only
|
||||
has_pages = "pages" in chunk_data_copy
|
||||
logger.info(f"Entity chunk has pages field: {has_pages}, keys: {list(chunk_data_copy.keys())[:10]}")
|
||||
|
||||
logger.info(
|
||||
f"Entity chunk has pages field: {has_pages}, keys: {list(chunk_data_copy.keys())[:10]}"
|
||||
)
|
||||
|
||||
result_chunks.append(chunk_data_copy)
|
||||
|
||||
# Update chunk tracking if provided
|
||||
|
|
@ -4250,18 +4283,18 @@ async def naive_query(
|
|||
text_units_str = "\n".join(
|
||||
json.dumps(text_unit, ensure_ascii=False) for text_unit in text_units_context
|
||||
)
|
||||
|
||||
|
||||
# Format reference list with page numbers if available
|
||||
formatted_references = []
|
||||
for ref in reference_list:
|
||||
if not ref.get("reference_id"):
|
||||
continue
|
||||
|
||||
file_path = ref['file_path']
|
||||
ref_id = ref['reference_id']
|
||||
|
||||
|
||||
file_path = ref["file_path"]
|
||||
ref_id = ref["reference_id"]
|
||||
|
||||
# Add page numbers if available
|
||||
pages = ref.get('pages')
|
||||
pages = ref.get("pages")
|
||||
if pages and len(pages) > 0:
|
||||
if len(pages) == 1:
|
||||
# Single page: "document.pdf (p. 5)"
|
||||
|
|
@ -4272,19 +4305,21 @@ async def naive_query(
|
|||
else:
|
||||
# No page info: "document.txt"
|
||||
citation = f"[{ref_id}] {file_path}"
|
||||
|
||||
|
||||
formatted_references.append(citation)
|
||||
|
||||
|
||||
reference_list_str = "\n".join(formatted_references)
|
||||
|
||||
|
||||
# Debug: Log what references are being sent to the LLM
|
||||
logger.info(f"Reference list for LLM ({len(formatted_references)} refs):")
|
||||
for ref_line in formatted_references[:3]: # Show first 3
|
||||
logger.info(f" {ref_line}")
|
||||
|
||||
|
||||
# Store valid reference IDs for validation
|
||||
valid_ref_ids = {ref['reference_id'] for ref in reference_list if ref.get('reference_id')}
|
||||
global_config['_valid_reference_ids'] = valid_ref_ids
|
||||
valid_ref_ids = {
|
||||
ref["reference_id"] for ref in reference_list if ref.get("reference_id")
|
||||
}
|
||||
global_config["_valid_reference_ids"] = valid_ref_ids
|
||||
|
||||
naive_context_template = PROMPTS["naive_query_context"]
|
||||
context_content = naive_context_template.format(
|
||||
|
|
@ -4329,13 +4364,15 @@ async def naive_query(
|
|||
" == LLM cache == Query cache hit, using cached response as query result"
|
||||
)
|
||||
response = cached_response
|
||||
|
||||
|
||||
# Validate references in cached response too
|
||||
valid_ref_ids = global_config.get('_valid_reference_ids', set())
|
||||
valid_ref_ids = global_config.get("_valid_reference_ids", set())
|
||||
if valid_ref_ids:
|
||||
response, is_valid = validate_llm_references(response, valid_ref_ids)
|
||||
if not is_valid:
|
||||
logger.warning("Cached LLM response contained invalid references and has been cleaned")
|
||||
logger.warning(
|
||||
"Cached LLM response contained invalid references and has been cleaned"
|
||||
)
|
||||
else:
|
||||
response = await use_model_func(
|
||||
user_query,
|
||||
|
|
@ -4344,13 +4381,15 @@ async def naive_query(
|
|||
enable_cot=True,
|
||||
stream=query_param.stream,
|
||||
)
|
||||
|
||||
|
||||
# Validate references in the response
|
||||
valid_ref_ids = global_config.get('_valid_reference_ids', set())
|
||||
valid_ref_ids = global_config.get("_valid_reference_ids", set())
|
||||
if valid_ref_ids:
|
||||
response, is_valid = validate_llm_references(response, valid_ref_ids)
|
||||
if not is_valid:
|
||||
logger.warning("LLM response contained invalid references and has been cleaned")
|
||||
logger.warning(
|
||||
"LLM response contained invalid references and has been cleaned"
|
||||
)
|
||||
|
||||
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
|
||||
queryparam_dict = {
|
||||
|
|
|
|||
|
|
@ -2852,56 +2852,74 @@ def convert_to_user_format(
|
|||
}
|
||||
|
||||
|
||||
def generate_reference_list_from_chunks(chunks: list[dict]) -> tuple[list[dict], list[dict]]:
|
||||
def generate_reference_list_from_chunks(
|
||||
chunks: list[dict],
|
||||
) -> tuple[list[dict], list[dict]]:
|
||||
"""Generate reference list from chunks, showing exact chunk page ranges."""
|
||||
if not chunks:
|
||||
return [], []
|
||||
|
||||
def _create_chunk_references(chunks: list[dict]) -> tuple[list[dict], dict[str, str]]:
|
||||
def _create_chunk_references(
|
||||
chunks: list[dict],
|
||||
) -> tuple[list[dict], dict[str, str]]:
|
||||
"""Create references based on actual chunk page ranges instead of file aggregation."""
|
||||
chunk_ref_map = {} # Maps (file_path, page_range) -> reference_id
|
||||
references = []
|
||||
ref_id_counter = 1
|
||||
|
||||
|
||||
for chunk in chunks:
|
||||
file_path = chunk.get("file_path", "")
|
||||
if file_path == "unknown_source":
|
||||
continue
|
||||
|
||||
|
||||
# Get page data for this specific chunk
|
||||
chunk_pages = chunk.get("pages")
|
||||
if chunk_pages and isinstance(chunk_pages, list):
|
||||
# Create a unique key for this file + page range combination
|
||||
page_range_key = (file_path, tuple(sorted(chunk_pages)))
|
||||
|
||||
|
||||
if page_range_key not in chunk_ref_map:
|
||||
# Create new reference for this file + page range
|
||||
chunk_ref_map[page_range_key] = str(ref_id_counter)
|
||||
|
||||
|
||||
# Build page range display
|
||||
sorted_pages = sorted(chunk_pages)
|
||||
if len(sorted_pages) == 1:
|
||||
page_display = {"pages": sorted_pages, "start_page": sorted_pages[0], "end_page": sorted_pages[0]}
|
||||
page_display = {
|
||||
"pages": sorted_pages,
|
||||
"start_page": sorted_pages[0],
|
||||
"end_page": sorted_pages[0],
|
||||
}
|
||||
else:
|
||||
page_display = {"pages": sorted_pages, "start_page": sorted_pages[0], "end_page": sorted_pages[-1]}
|
||||
|
||||
references.append({
|
||||
"reference_id": str(ref_id_counter),
|
||||
"file_path": file_path,
|
||||
**page_display
|
||||
})
|
||||
ref_id_counter += 1
|
||||
|
||||
return references, {f"{file_path}_{'-'.join(map(str, pages))}": ref_id
|
||||
for (file_path, pages), ref_id in chunk_ref_map.items()}
|
||||
page_display = {
|
||||
"pages": sorted_pages,
|
||||
"start_page": sorted_pages[0],
|
||||
"end_page": sorted_pages[-1],
|
||||
}
|
||||
|
||||
def _add_reference_ids_to_chunks(chunks: list[dict], chunk_ref_map: dict[str, str]) -> list[dict]:
|
||||
references.append(
|
||||
{
|
||||
"reference_id": str(ref_id_counter),
|
||||
"file_path": file_path,
|
||||
**page_display,
|
||||
}
|
||||
)
|
||||
ref_id_counter += 1
|
||||
|
||||
return references, {
|
||||
f"{file_path}_{'-'.join(map(str, pages))}": ref_id
|
||||
for (file_path, pages), ref_id in chunk_ref_map.items()
|
||||
}
|
||||
|
||||
def _add_reference_ids_to_chunks(
|
||||
chunks: list[dict], chunk_ref_map: dict[str, str]
|
||||
) -> list[dict]:
|
||||
"""Add reference_id field to chunks based on their specific page ranges."""
|
||||
updated = []
|
||||
for chunk in chunks:
|
||||
chunk_copy = chunk.copy()
|
||||
file_path = chunk_copy.get("file_path", "")
|
||||
|
||||
|
||||
if file_path != "unknown_source":
|
||||
chunk_pages = chunk_copy.get("pages")
|
||||
if chunk_pages and isinstance(chunk_pages, list):
|
||||
|
|
@ -2913,12 +2931,12 @@ def generate_reference_list_from_chunks(chunks: list[dict]) -> tuple[list[dict],
|
|||
chunk_copy["reference_id"] = ""
|
||||
else:
|
||||
chunk_copy["reference_id"] = ""
|
||||
|
||||
|
||||
updated.append(chunk_copy)
|
||||
return updated
|
||||
|
||||
# Main execution flow
|
||||
reference_list, chunk_ref_map = _create_chunk_references(chunks)
|
||||
updated_chunks = _add_reference_ids_to_chunks(chunks, chunk_ref_map)
|
||||
|
||||
|
||||
return reference_list, updated_chunks
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue