diff --git a/lightrag/operate.py b/lightrag/operate.py
index fb14277f..faab7f26 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -1,317 +1,83 @@
from __future__ import annotations
-
-import asyncio
-import hashlib
-import json
-import os
-import re
-import time
-from collections import Counter, defaultdict
-from collections.abc import AsyncIterator, Awaitable, Callable
from functools import partial
from pathlib import Path
-from typing import Any, cast
+import asyncio
+import json
import json_repair
-from dotenv import load_dotenv
+from typing import Any, AsyncIterator, overload, Literal
+from collections import Counter, defaultdict
+from lightrag.exceptions import (
+ PipelineCancelledException,
+ ChunkTokenLimitExceededError,
+)
+from lightrag.utils import (
+ logger,
+ compute_mdhash_id,
+ Tokenizer,
+ is_float_regex,
+ sanitize_and_normalize_extracted_text,
+ pack_user_ass_to_openai_messages,
+ split_string_by_multi_markers,
+ truncate_list_by_token_size,
+ compute_args_hash,
+ handle_cache,
+ save_to_cache,
+ CacheData,
+ use_llm_func_with_cache,
+ update_chunk_cache_list,
+ remove_think_tags,
+ pick_by_weighted_polling,
+ pick_by_vector_similarity,
+ process_chunks_unified,
+ safe_vdb_operation_with_exception,
+ create_prefixed_exception,
+ fix_tuple_delimiter_corruption,
+ convert_to_user_format,
+ generate_reference_list_from_chunks,
+ apply_source_ids_limit,
+ merge_source_ids,
+ make_relation_chunk_key,
+)
from lightrag.base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
- QueryContextResult,
+ TextChunkSchema,
QueryParam,
QueryResult,
- TextChunkSchema,
+ QueryContextResult,
)
+from lightrag.prompt import PROMPTS
from lightrag.constants import (
- DEFAULT_ENTITY_NAME_MAX_LENGTH,
- DEFAULT_ENTITY_TYPES,
- DEFAULT_FILE_PATH_MORE_PLACEHOLDER,
- DEFAULT_KG_CHUNK_PICK_METHOD,
+ GRAPH_FIELD_SEP,
DEFAULT_MAX_ENTITY_TOKENS,
- DEFAULT_MAX_FILE_PATHS,
DEFAULT_MAX_RELATION_TOKENS,
- DEFAULT_MAX_SOURCE_IDS_PER_ENTITY,
- DEFAULT_MAX_SOURCE_IDS_PER_RELATION,
DEFAULT_MAX_TOTAL_TOKENS,
DEFAULT_RELATED_CHUNK_NUMBER,
- DEFAULT_RETRIEVAL_MULTIPLIER,
- DEFAULT_SOURCE_IDS_LIMIT_METHOD,
+ DEFAULT_KG_CHUNK_PICK_METHOD,
+ DEFAULT_ENTITY_TYPES,
DEFAULT_SUMMARY_LANGUAGE,
- GRAPH_FIELD_SEP,
- SOURCE_IDS_LIMIT_METHOD_FIFO,
SOURCE_IDS_LIMIT_METHOD_KEEP,
-)
-from lightrag.entity_resolution import (
- EntityResolutionConfig,
- fuzzy_similarity,
- get_cached_alias,
- resolve_entity_with_vdb,
- store_alias,
-)
-from lightrag.exceptions import (
- ChunkTokenLimitExceededError,
- PipelineCancelledException,
+ SOURCE_IDS_LIMIT_METHOD_FIFO,
+ DEFAULT_FILE_PATH_MORE_PLACEHOLDER,
+ DEFAULT_MAX_FILE_PATHS,
+ DEFAULT_ENTITY_NAME_MAX_LENGTH,
)
from lightrag.kg.shared_storage import get_storage_keyed_lock
-from lightrag.prompt import PROMPTS
-from lightrag.utils import (
- CacheData,
- Tokenizer,
- apply_source_ids_limit,
- compute_args_hash,
- compute_mdhash_id,
- convert_to_user_format,
- create_prefixed_exception,
- fix_tuple_delimiter_corruption,
- generate_reference_list_from_chunks,
- handle_cache,
- is_float_regex,
- logger,
- make_relation_chunk_key,
- merge_source_ids,
- pack_user_ass_to_openai_messages,
- pick_by_vector_similarity,
- pick_by_weighted_polling,
- process_chunks_unified,
- remove_think_tags,
- safe_vdb_operation_with_exception,
- sanitize_and_normalize_extracted_text,
- save_to_cache,
- split_string_by_multi_markers,
- truncate_list_by_token_size,
- update_chunk_cache_list,
- use_llm_func_with_cache,
-)
-
-# Query embedding cache configuration (configurable via environment variables)
-QUERY_EMBEDDING_CACHE_TTL = int(os.getenv('QUERY_EMBEDDING_CACHE_TTL', '3600')) # 1 hour
-QUERY_EMBEDDING_CACHE_MAX_SIZE = int(
- os.getenv('QUERY_EMBEDDING_CACHE_MAX_SIZE', os.getenv('QUERY_EMBEDDING_CACHE_SIZE', '10000'))
-)
-
-# Redis cache configuration
-REDIS_EMBEDDING_CACHE_ENABLED = os.getenv('REDIS_EMBEDDING_CACHE', 'false').lower() == 'true'
-REDIS_URI = os.getenv('REDIS_URI', 'redis://localhost:6379')
-
-# Local in-memory cache with LRU eviction
-# Structure: {query_hash: (embedding, timestamp)}
-_query_embedding_cache: dict[str, tuple[list[float], float]] = {}
-_query_embedding_cache_locks: dict[int, asyncio.Lock] = {}
-
-
-def _get_query_embedding_cache_lock() -> asyncio.Lock:
- """Return an event-loop-local lock for the embedding cache."""
- loop = asyncio.get_running_loop()
- lock = _query_embedding_cache_locks.get(id(loop))
- if lock is None:
- lock = asyncio.Lock()
- _query_embedding_cache_locks[id(loop)] = lock
- return lock
-
-
-# Global Redis client (lazy initialized)
-_redis_client = None
-
-
-async def _get_redis_client():
- """Lazy initialize Redis client."""
- global _redis_client
- if _redis_client is None and REDIS_EMBEDDING_CACHE_ENABLED:
- try:
- import redis.asyncio as redis
-
- _redis_client = redis.from_url(REDIS_URI, decode_responses=True)
- # Test connection
- await _redis_client.ping()
- logger.info(f'Redis embedding cache connected: {REDIS_URI}')
- except ImportError:
- logger.warning('Redis package not installed. Install with: pip install redis')
- return None
- except Exception as e:
- logger.warning(f'Failed to connect to Redis: {e}. Falling back to local cache.')
- return None
- return _redis_client
-
-
-async def get_cached_query_embedding(query: str, embedding_func) -> list[float] | None:
- """Get query embedding with caching to avoid redundant API calls.
-
- Supports both local in-memory cache and Redis for cross-worker sharing.
- Redis is used when REDIS_EMBEDDING_CACHE=true environment variable is set.
-
- Args:
- query: The query string to embed
- embedding_func: The embedding function to call on cache miss
-
- Returns:
- The embedding vector, or None if embedding fails
- """
- query_hash = hashlib.sha256(query.encode()).hexdigest()[:16]
- current_time = time.time()
- redis_key = f'lightrag:emb:{query_hash}'
-
- # Try Redis cache first (if enabled)
- if REDIS_EMBEDDING_CACHE_ENABLED:
- try:
- redis_client = await _get_redis_client()
- if redis_client:
- cached_json = await redis_client.get(redis_key)
- if cached_json:
- embedding = json.loads(cached_json)
- logger.debug(f'Redis embedding cache hit for hash {query_hash[:8]}')
- # Also update local cache
- _query_embedding_cache[query_hash] = (embedding, current_time)
- return embedding
- except Exception as e:
- logger.debug(f'Redis cache read error: {e}')
-
- # Check local cache
- cached = _query_embedding_cache.get(query_hash)
- if cached and (current_time - cached[1]) < QUERY_EMBEDDING_CACHE_TTL:
- logger.debug(f'Local embedding cache hit for hash {query_hash[:8]}')
- return cached[0]
-
- # Cache miss - compute embedding
- try:
- embedding = await embedding_func([query])
- embedding_result = embedding[0] # Extract first from batch
-
- # Manage local cache size - LRU eviction of oldest entries
- async with _get_query_embedding_cache_lock():
- if len(_query_embedding_cache) >= QUERY_EMBEDDING_CACHE_MAX_SIZE:
- # Remove oldest 10% of entries
- sorted_entries = sorted(_query_embedding_cache.items(), key=lambda x: x[1][1])
- for old_key, _ in sorted_entries[: QUERY_EMBEDDING_CACHE_MAX_SIZE // 10]:
- del _query_embedding_cache[old_key]
-
- # Store in local cache
- _query_embedding_cache[query_hash] = (embedding_result, current_time)
-
- # Store in Redis (if enabled)
- if REDIS_EMBEDDING_CACHE_ENABLED:
- try:
- redis_client = await _get_redis_client()
- if redis_client:
- await redis_client.setex(
- redis_key,
- QUERY_EMBEDDING_CACHE_TTL,
- json.dumps(embedding_result),
- )
- logger.debug(f'Embedding cached in Redis for hash {query_hash[:8]}')
- except Exception as e:
- logger.debug(f'Redis cache write error: {e}')
-
- logger.debug(f'Query embedding computed and cached for hash {query_hash[:8]}')
- return embedding_result
- except Exception as e:
- logger.warning(f'Failed to compute query embedding: {e}')
- return None
-
+import time
+from dotenv import load_dotenv
# use the .env that is inside the current folder
# allows using a different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
-load_dotenv(dotenv_path=Path(__file__).resolve().parent / '.env', override=False)
+load_dotenv(dotenv_path=Path(__file__).resolve().parent / ".env", override=False)
-class TimingContext:
- """Context manager for collecting timing information during query execution.
-
- Provides both synchronous and context manager interfaces for measuring
- execution time of different query stages.
-
- Example:
- timing = TimingContext()
- timing.start('keyword_extraction')
- # ... do work ...
- timing.stop('keyword_extraction')
-
- # Or with context manager:
- with timing.measure('vector_search'):
- # ... do work ...
-
- print(timing.timings) # {'keyword_extraction': 123.4, 'vector_search': 567.8}
- """
-
- def __init__(self) -> None:
- """Initialize timing context."""
- self.timings: dict[str, float] = {}
- self._start_times: dict[str, float] = {}
- self._total_start: float = time.perf_counter()
-
- def start(self, stage: str) -> None:
- """Start timing a stage.
-
- Args:
- stage: Name of the stage to time
- """
- self._start_times[stage] = time.perf_counter()
-
- def stop(self, stage: str) -> float:
- """Stop timing a stage and record the duration.
-
- Args:
- stage: Name of the stage to stop
-
- Returns:
- Duration in milliseconds
- """
- if stage in self._start_times:
- elapsed = (time.perf_counter() - self._start_times[stage]) * 1000
- self.timings[stage] = elapsed
- del self._start_times[stage]
- return elapsed
- return 0.0
-
- def measure(self, stage: str):
- """Context manager for timing a stage.
-
- Args:
- stage: Name of the stage to time
-
- Yields:
- None
- """
- return _TimingContextManager(self, stage)
-
- def get_total_ms(self) -> float:
- """Get total elapsed time since context creation.
-
- Returns:
- Total time in milliseconds
- """
- return (time.perf_counter() - self._total_start) * 1000
-
- def finalize(self) -> dict[str, float]:
- """Finalize timing and return all recorded durations.
-
- Sets the 'total' timing to elapsed time since context creation.
-
- Returns:
- Dictionary of stage -> duration_ms
- """
- self.timings['total'] = self.get_total_ms()
- return self.timings
-
-
-class _TimingContextManager:
- """Internal context manager for TimingContext.measure()."""
-
- def __init__(self, timing_ctx: TimingContext, stage: str) -> None:
- self._timing_ctx = timing_ctx
- self._stage = stage
-
- def __enter__(self) -> None:
- self._timing_ctx.start(self._stage)
- return None
-
- def __exit__(self, exc_type, exc_val, exc_tb) -> None:
- self._timing_ctx.stop(self._stage)
-
-
-def _truncate_entity_identifier(identifier: str, limit: int, chunk_key: str, identifier_role: str) -> str:
+def _truncate_entity_identifier(
+ identifier: str, limit: int, chunk_key: str, identifier_role: str
+) -> str:
"""Truncate entity identifiers that exceed the configured length limit."""
if len(identifier) <= limit:
@@ -338,25 +104,17 @@ def chunking_by_token_size(
chunk_overlap_token_size: int = 100,
chunk_token_size: int = 1200,
) -> list[dict[str, Any]]:
- """Split content into chunks by token size, tracking character positions.
-
- Returns chunks with char_start and char_end for citation support.
- """
tokens = tokenizer.encode(content)
results: list[dict[str, Any]] = []
if split_by_character:
raw_chunks = content.split(split_by_character)
- # Track character positions: (tokens, chunk_text, char_start, char_end)
- new_chunks: list[tuple[int, str, int, int]] = []
- char_position = 0
- separator_len = len(split_by_character)
-
+ new_chunks = []
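+ # Each entry collected below is a (token_count, chunk_text) pair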
if split_by_character_only:
for chunk in raw_chunks:
_tokens = tokenizer.encode(chunk)
if len(_tokens) > chunk_token_size:
logger.warning(
- 'Chunk split_by_character exceeds token limit: len=%d limit=%d',
+ "Chunk split_by_character exceeds token limit: len=%d limit=%d",
len(_tokens),
chunk_token_size,
)
@@ -365,58 +123,40 @@ def chunking_by_token_size(
chunk_token_limit=chunk_token_size,
chunk_preview=chunk[:120],
)
- chunk_start = char_position
- chunk_end = char_position + len(chunk)
- new_chunks.append((len(_tokens), chunk, chunk_start, chunk_end))
- char_position = chunk_end + separator_len # Skip separator
+ new_chunks.append((len(_tokens), chunk))
else:
for chunk in raw_chunks:
- chunk_start = char_position
_tokens = tokenizer.encode(chunk)
if len(_tokens) > chunk_token_size:
- # Sub-chunking: approximate char positions within the chunk
- sub_char_position = 0
- for start in range(0, len(_tokens), chunk_token_size - chunk_overlap_token_size):
- chunk_content = tokenizer.decode(_tokens[start : start + chunk_token_size])
- # Approximate char position based on content length ratio
- sub_start = chunk_start + sub_char_position
- sub_end = sub_start + len(chunk_content)
- new_chunks.append(
- (min(chunk_token_size, len(_tokens) - start), chunk_content, sub_start, sub_end)
+ for start in range(
+ 0, len(_tokens), chunk_token_size - chunk_overlap_token_size
+ ):
+ chunk_content = tokenizer.decode(
+ _tokens[start : start + chunk_token_size]
+ )
+ new_chunks.append(
+ (min(chunk_token_size, len(_tokens) - start), chunk_content)
)
- sub_char_position += len(chunk_content) - (chunk_overlap_token_size * 4) # Approx overlap
else:
- chunk_end = chunk_start + len(chunk)
- new_chunks.append((len(_tokens), chunk, chunk_start, chunk_end))
- char_position = chunk_start + len(chunk) + separator_len
-
- for index, (_len, chunk, char_start, char_end) in enumerate(new_chunks):
+ new_chunks.append((len(_tokens), chunk))
+ for index, (_len, chunk) in enumerate(new_chunks):
results.append(
{
- 'tokens': _len,
- 'content': chunk.strip(),
- 'chunk_order_index': index,
- 'char_start': char_start,
- 'char_end': char_end,
+ "tokens": _len,
+ "content": chunk.strip(),
+ "chunk_order_index": index,
}
)
else:
- # Token-based chunking: track character positions through decoded content
- char_position = 0
- for index, start in enumerate(range(0, len(tokens), chunk_token_size - chunk_overlap_token_size)):
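+ # Step through tokens in strides of (chunk_token_size - chunk_overlap_token_size) so consecutive chunks overlap by chunk_overlap_token_size tokens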
+ for index, start in enumerate(
+ range(0, len(tokens), chunk_token_size - chunk_overlap_token_size)
+ ):
chunk_content = tokenizer.decode(tokens[start : start + chunk_token_size])
- # For overlapping chunks, approximate positions based on previous chunk
- char_start = 0 if index == 0 else char_position
- char_end = char_start + len(chunk_content)
- char_position = char_start + len(chunk_content) - (chunk_overlap_token_size * 4) # Approx char overlap
-
results.append(
{
- 'tokens': min(chunk_token_size, len(tokens) - start),
- 'content': chunk_content.strip(),
- 'chunk_order_index': index,
- 'char_start': char_start,
- 'char_end': char_end,
+ "tokens": min(chunk_token_size, len(tokens) - start),
+ "content": chunk_content.strip(),
+ "chunk_order_index": index,
}
)
return results
@@ -450,17 +190,17 @@ async def _handle_entity_relation_summary(
"""
# Handle empty input
if not description_list:
- return '', False
+ return "", False
# If only one description, return it directly (no need for LLM call)
if len(description_list) == 1:
return description_list[0], False
# Get configuration
- tokenizer: Tokenizer = global_config['tokenizer']
- summary_context_size = global_config['summary_context_size']
- summary_max_tokens = global_config['summary_max_tokens']
- force_llm_summary_on_merge = global_config['force_llm_summary_on_merge']
+ tokenizer: Tokenizer = global_config["tokenizer"]
+ summary_context_size = global_config["summary_context_size"]
+ summary_max_tokens = global_config["summary_max_tokens"]
+ force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
current_list = description_list[:] # Copy the list to avoid modifying original
llm_was_used = False # Track whether LLM was used during the entire process
@@ -472,13 +212,18 @@ async def _handle_entity_relation_summary(
# If total length is within limits, perform final summarization
if total_tokens <= summary_context_size or len(current_list) <= 2:
- if len(current_list) < force_llm_summary_on_merge and total_tokens < summary_max_tokens:
+ if (
+ len(current_list) < force_llm_summary_on_merge
+ and total_tokens < summary_max_tokens
+ ):
# no LLM needed, just join the descriptions
final_description = seperator.join(current_list)
- return final_description if final_description else '', llm_was_used
+ return final_description if final_description else "", llm_was_used
else:
if total_tokens > summary_context_size and len(current_list) <= 2:
- logger.warning(f'Summarizing {entity_or_relation_name}: Oversize description found')
+ logger.warning(
+ f"Summarizing {entity_or_relation_name}: Oversize descpriton found"
+ )
# Final summarization of remaining descriptions - LLM will be used
final_summary = await _summarize_descriptions(
description_type,
@@ -496,7 +241,7 @@ async def _handle_entity_relation_summary(
current_tokens = 0
# Currently at least 3 descriptions in current_list
- for _i, desc in enumerate(current_list):
+ for i, desc in enumerate(current_list):
desc_tokens = len(tokenizer.encode(desc))
# If adding current description would exceed limit, finalize current chunk
@@ -506,7 +251,9 @@ async def _handle_entity_relation_summary(
# Force add one more description to ensure minimum 2 per chunk
current_chunk.append(desc)
chunks.append(current_chunk)
- logger.warning(f'Summarizing {entity_or_relation_name}: Oversize description found')
+ logger.warning(
+ f"Summarizing {entity_or_relation_name}: Oversize descpriton found"
+ )
current_chunk = [] # next group is empty
current_tokens = 0
else: # current_chunk is ready for summary in reduce phase
@@ -522,15 +269,15 @@ async def _handle_entity_relation_summary(
chunks.append(current_chunk)
logger.info(
- f' Summarizing {entity_or_relation_name}: Map {len(current_list)} descriptions into {len(chunks)} groups'
+ f" Summarizing {entity_or_relation_name}: Map {len(current_list)} descriptions into {len(chunks)} groups"
)
- # Reduce phase: summarize each group from chunks IN PARALLEL
- async def _summarize_single_chunk(chunk: list[str]) -> tuple[str, bool]:
- """Summarize a single chunk, returning (summary, used_llm)."""
+ # Reduce phase: summarize each group from chunks
+ new_summaries = []
+ for chunk in chunks:
if len(chunk) == 1:
# Optimization: single description chunks don't need LLM summarization
- return chunk[0], False
+ new_summaries.append(chunk[0])
else:
# Multiple descriptions need LLM summarization
summary = await _summarize_descriptions(
@@ -540,16 +287,8 @@ async def _handle_entity_relation_summary(
global_config,
llm_response_cache,
)
- return summary, True
-
- # Create tasks for all chunks and run in parallel
- tasks = [asyncio.create_task(_summarize_single_chunk(chunk)) for chunk in chunks]
- results = await asyncio.gather(*tasks)
-
- # Collect results while preserving order
- new_summaries = [result[0] for result in results]
- if any(result[1] for result in results):
- llm_was_used = True # Mark that LLM was used in reduce phase
+ new_summaries.append(summary)
+ llm_was_used = True # Mark that LLM was used in reduce phase
# Update current list with new summaries for next iteration
current_list = new_summaries
@@ -573,22 +312,22 @@ async def _summarize_descriptions(
Returns:
Summarized description string
"""
- use_llm_func: Callable[..., Any] = global_config['llm_model_func']
+ use_llm_func: callable = global_config["llm_model_func"]
# Apply higher priority (8) to entity/relation summary tasks
use_llm_func = partial(use_llm_func, _priority=8)
- language = global_config['addon_params'].get('language', DEFAULT_SUMMARY_LANGUAGE)
+ language = global_config["addon_params"].get("language", DEFAULT_SUMMARY_LANGUAGE)
- summary_length_recommended = global_config['summary_length_recommended']
+ summary_length_recommended = global_config["summary_length_recommended"]
- prompt_template = PROMPTS['summarize_entity_descriptions']
+ prompt_template = PROMPTS["summarize_entity_descriptions"]
# Convert descriptions to JSONL format and apply token-based truncation
- tokenizer = global_config['tokenizer']
- summary_context_size = global_config['summary_context_size']
+ tokenizer = global_config["tokenizer"]
+ summary_context_size = global_config["summary_context_size"]
# Create list of JSON objects with "Description" field
- json_descriptions = [{'Description': desc} for desc in description_list]
+ json_descriptions = [{"Description": desc} for desc in description_list]
# Use truncate_list_by_token_size for length truncation
truncated_json_descriptions = truncate_list_by_token_size(
@@ -599,16 +338,18 @@ async def _summarize_descriptions(
)
# Convert to JSONL format (one JSON object per line)
- joined_descriptions = '\n'.join(json.dumps(desc, ensure_ascii=False) for desc in truncated_json_descriptions)
+ joined_descriptions = "\n".join(
+ json.dumps(desc, ensure_ascii=False) for desc in truncated_json_descriptions
+ )
# Prepare context for the prompt
- context_base = {
- 'description_type': description_type,
- 'description_name': description_name,
- 'description_list': joined_descriptions,
- 'summary_length': summary_length_recommended,
- 'language': language,
- }
+ context_base = dict(
+ description_type=description_type,
+ description_name=description_name,
+ description_list=joined_descriptions,
+ summary_length=summary_length_recommended,
+ language=language,
+ )
use_prompt = prompt_template.format(**context_base)
# Use LLM function with cache (higher priority for summary generation)
@@ -616,20 +357,20 @@ async def _summarize_descriptions(
use_prompt,
use_llm_func,
llm_response_cache=llm_response_cache,
- cache_type='summary',
+ cache_type="summary",
)
# Check summary token length against embedding limit
- embedding_token_limit = global_config.get('embedding_token_limit')
+ embedding_token_limit = global_config.get("embedding_token_limit")
if embedding_token_limit is not None and summary:
- tokenizer = global_config['tokenizer']
+ tokenizer = global_config["tokenizer"]
summary_token_count = len(tokenizer.encode(summary))
threshold = int(embedding_token_limit * 0.9)
if summary_token_count > threshold:
logger.warning(
- f'Summary tokens ({summary_token_count}) exceeds 90% of embedding limit '
- f'({embedding_token_limit}) for {description_type}: {description_name}'
+ f"Summary tokens ({summary_token_count}) exceeds 90% of embedding limit "
+ f"({embedding_token_limit}) for {description_type}: {description_name}"
)
return summary
@@ -639,33 +380,43 @@ async def _handle_single_entity_extraction(
record_attributes: list[str],
chunk_key: str,
timestamp: int,
- file_path: str = 'unknown_source',
+ file_path: str = "unknown_source",
):
- if len(record_attributes) != 4 or 'entity' not in record_attributes[0]:
- if len(record_attributes) > 1 and 'entity' in record_attributes[0]:
+ if len(record_attributes) != 4 or "entity" not in record_attributes[0]:
+ if len(record_attributes) > 1 and "entity" in record_attributes[0]:
logger.warning(
- f'{chunk_key}: LLM output format error; found {len(record_attributes)}/4 fields on ENTITY `{record_attributes[1]}` @ `{record_attributes[2] if len(record_attributes) > 2 else "N/A"}`'
+ f"{chunk_key}: LLM output format error; found {len(record_attributes)}/4 feilds on ENTITY `{record_attributes[1]}` @ `{record_attributes[2] if len(record_attributes) > 2 else 'N/A'}`"
)
logger.debug(record_attributes)
return None
try:
- entity_name = sanitize_and_normalize_extracted_text(record_attributes[1], remove_inner_quotes=True)
+ entity_name = sanitize_and_normalize_extracted_text(
+ record_attributes[1], remove_inner_quotes=True
+ )
# Validate entity name after all cleaning steps
if not entity_name or not entity_name.strip():
- logger.info(f"Empty entity name found after sanitization. Original: '{record_attributes[1]}'")
+ logger.info(
+ f"Empty entity name found after sanitization. Original: '{record_attributes[1]}'"
+ )
return None
# Process entity type with same cleaning pipeline
- entity_type = sanitize_and_normalize_extracted_text(record_attributes[2], remove_inner_quotes=True)
+ entity_type = sanitize_and_normalize_extracted_text(
+ record_attributes[2], remove_inner_quotes=True
+ )
- if not entity_type.strip() or any(char in entity_type for char in ["'", '(', ')', '<', '>', '|', '/', '\\']):
- logger.warning(f'Entity extraction error: invalid entity type in: {record_attributes}')
+ if not entity_type.strip() or any(
+ char in entity_type for char in ["'", "(", ")", "<", ">", "|", "/", "\\"]
+ ):
+ logger.warning(
+ f"Entity extraction error: invalid entity type in: {record_attributes}"
+ )
return None
# Remove spaces and convert to lowercase
- entity_type = entity_type.replace(' ', '').lower()
+ entity_type = entity_type.replace(" ", "").lower()
# Process entity description with same cleaning pipeline
entity_description = sanitize_and_normalize_extracted_text(record_attributes[3])
@@ -676,20 +427,24 @@ async def _handle_single_entity_extraction(
)
return None
- return {
- 'entity_name': entity_name,
- 'entity_type': entity_type,
- 'description': entity_description,
- 'source_id': chunk_key,
- 'file_path': file_path,
- 'timestamp': timestamp,
- }
+ return dict(
+ entity_name=entity_name,
+ entity_type=entity_type,
+ description=entity_description,
+ source_id=chunk_key,
+ file_path=file_path,
+ timestamp=timestamp,
+ )
except ValueError as e:
- logger.error(f'Entity extraction failed due to encoding issues in chunk {chunk_key}: {e}')
+ logger.error(
+ f"Entity extraction failed due to encoding issues in chunk {chunk_key}: {e}"
+ )
return None
except Exception as e:
- logger.error(f'Entity extraction failed with unexpected error in chunk {chunk_key}: {e}')
+ logger.error(
+ f"Entity extraction failed with unexpected error in chunk {chunk_key}: {e}"
+ )
return None
@@ -697,41 +452,50 @@ async def _handle_single_relationship_extraction(
record_attributes: list[str],
chunk_key: str,
timestamp: int,
- file_path: str = 'unknown_source',
+ file_path: str = "unknown_source",
):
if (
- len(record_attributes) != 5 or 'relation' not in record_attributes[0]
+ len(record_attributes) != 5 or "relation" not in record_attributes[0]
): # treat "relationship" and "relation" interchangeable
- if len(record_attributes) > 1 and 'relation' in record_attributes[0]:
+ if len(record_attributes) > 1 and "relation" in record_attributes[0]:
logger.warning(
- f'{chunk_key}: LLM output format error; found {len(record_attributes)}/5 fields on RELATION `{record_attributes[1]}`~`{record_attributes[2] if len(record_attributes) > 2 else "N/A"}`'
+ f"{chunk_key}: LLM output format error; found {len(record_attributes)}/5 fields on REALTION `{record_attributes[1]}`~`{record_attributes[2] if len(record_attributes) > 2 else 'N/A'}`"
)
logger.debug(record_attributes)
return None
try:
- source = sanitize_and_normalize_extracted_text(record_attributes[1], remove_inner_quotes=True)
- target = sanitize_and_normalize_extracted_text(record_attributes[2], remove_inner_quotes=True)
+ source = sanitize_and_normalize_extracted_text(
+ record_attributes[1], remove_inner_quotes=True
+ )
+ target = sanitize_and_normalize_extracted_text(
+ record_attributes[2], remove_inner_quotes=True
+ )
# Validate entity names after all cleaning steps
if not source:
- logger.info(f"Empty source entity found after sanitization. Original: '{record_attributes[1]}'")
+ logger.info(
+ f"Empty source entity found after sanitization. Original: '{record_attributes[1]}'"
+ )
return None
if not target:
- logger.info(f"Empty target entity found after sanitization. Original: '{record_attributes[2]}'")
+ logger.info(
+ f"Empty target entity found after sanitization. Original: '{record_attributes[2]}'"
+ )
return None
if source == target:
- logger.debug(f'Relationship source and target are the same in: {record_attributes}')
+ logger.debug(
+ f"Relationship source and target are the same in: {record_attributes}"
+ )
return None
# Process keywords with same cleaning pipeline
- edge_keywords = sanitize_and_normalize_extracted_text(record_attributes[3], remove_inner_quotes=True)
- edge_keywords = edge_keywords.replace('，', ',')
-
- # Derive a relationship label from the first keyword (fallback to description later)
- relationship_label = edge_keywords.split(',')[0].strip() if edge_keywords else ''
+ edge_keywords = sanitize_and_normalize_extracted_text(
+ record_attributes[3], remove_inner_quotes=True
+ )
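+ # Normalize full-width (CJK) commas in the keyword list to ASCII commas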
+ edge_keywords = edge_keywords.replace("，", ",")
# Process relationship description with same cleaning pipeline
edge_description = sanitize_and_normalize_extracted_text(record_attributes[4])
@@ -743,24 +507,26 @@ async def _handle_single_relationship_extraction(
else 1.0
)
- return {
- 'src_id': source,
- 'tgt_id': target,
- 'weight': weight,
- 'description': edge_description,
- 'keywords': edge_keywords,
- 'relationship': relationship_label,
- 'type': relationship_label,
- 'source_id': edge_source_id,
- 'file_path': file_path,
- 'timestamp': timestamp,
- }
+ return dict(
+ src_id=source,
+ tgt_id=target,
+ weight=weight,
+ description=edge_description,
+ keywords=edge_keywords,
+ source_id=edge_source_id,
+ file_path=file_path,
+ timestamp=timestamp,
+ )
except ValueError as e:
- logger.warning(f'Relationship extraction failed due to encoding issues in chunk {chunk_key}: {e}')
+ logger.warning(
+ f"Relationship extraction failed due to encoding issues in chunk {chunk_key}: {e}"
+ )
return None
except Exception as e:
- logger.warning(f'Relationship extraction failed with unexpected error in chunk {chunk_key}: {e}')
+ logger.warning(
+ f"Relationship extraction failed with unexpected error in chunk {chunk_key}: {e}"
+ )
return None
@@ -772,7 +538,7 @@ async def rebuild_knowledge_from_chunks(
relationships_vdb: BaseVectorStorage,
text_chunks_storage: BaseKVStorage,
llm_response_cache: BaseKVStorage,
- global_config: dict[str, Any],
+ global_config: dict[str, str],
pipeline_status: dict | None = None,
pipeline_status_lock=None,
entity_chunks_storage: BaseKVStorage | None = None,
@@ -808,14 +574,12 @@ async def rebuild_knowledge_from_chunks(
for chunk_ids in relationships_to_rebuild.values():
all_referenced_chunk_ids.update(chunk_ids)
- status_message = (
- f'Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions (parallel processing)'
- )
+ status_message = f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions (parallel processing)"
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
- pipeline_status['latest_message'] = status_message
- pipeline_status['history_messages'].append(status_message)
+ pipeline_status["latest_message"] = status_message
+ pipeline_status["history_messages"].append(status_message)
# Get cached extraction results for these chunks using storage
# cached_results: chunk_id -> [list of (extraction_result, create_time) from LLM cache sorted by create_time of the first extraction_result]
@@ -826,12 +590,12 @@ async def rebuild_knowledge_from_chunks(
)
if not cached_results:
- status_message = 'No cached extraction results found, cannot rebuild'
+ status_message = "No cached extraction results found, cannot rebuild"
logger.warning(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
- pipeline_status['latest_message'] = status_message
- pipeline_status['history_messages'].append(status_message)
+ pipeline_status["latest_message"] = status_message
+ pipeline_status["history_messages"].append(status_message)
return
# Process cached results to get entities and relationships for each chunk
@@ -850,7 +614,7 @@ async def rebuild_knowledge_from_chunks(
text_chunks_storage=text_chunks_storage,
chunk_id=chunk_id,
extraction_result=result[0],
- timestamp=int(result[1]) if result[1] is not None else int(time.time()),
+ timestamp=result[1],
)
# Merge entities and relationships from this extraction result
@@ -864,8 +628,13 @@ async def rebuild_knowledge_from_chunks(
chunk_entities[chunk_id][entity_name].extend(entity_list)
else:
# Compare description lengths and keep the better one
- existing_desc_len = len(chunk_entities[chunk_id][entity_name][0].get('description', '') or '')
- new_desc_len = len(entity_list[0].get('description', '') or '')
+ existing_desc_len = len(
+ chunk_entities[chunk_id][entity_name][0].get(
+ "description", ""
+ )
+ or ""
+ )
+ new_desc_len = len(entity_list[0].get("description", "") or "")
if new_desc_len > existing_desc_len:
# Replace with the new entity that has longer description
@@ -882,8 +651,13 @@ async def rebuild_knowledge_from_chunks(
chunk_relationships[chunk_id][rel_key].extend(rel_list)
else:
# Compare description lengths and keep the better one
- existing_desc_len = len(chunk_relationships[chunk_id][rel_key][0].get('description', '') or '')
- new_desc_len = len(rel_list[0].get('description', '') or '')
+ existing_desc_len = len(
+ chunk_relationships[chunk_id][rel_key][0].get(
+ "description", ""
+ )
+ or ""
+ )
+ new_desc_len = len(rel_list[0].get("description", "") or "")
if new_desc_len > existing_desc_len:
# Replace with the new relationship that has longer description
@@ -891,16 +665,18 @@ async def rebuild_knowledge_from_chunks(
# Otherwise keep existing version
except Exception as e:
- status_message = f'Failed to parse cached extraction result for chunk {chunk_id}: {e}'
+ status_message = (
+ f"Failed to parse cached extraction result for chunk {chunk_id}: {e}"
+ )
logger.info(status_message) # Per requirement, change to info
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
- pipeline_status['latest_message'] = status_message
- pipeline_status['history_messages'].append(status_message)
+ pipeline_status["latest_message"] = status_message
+ pipeline_status["history_messages"].append(status_message)
continue
# Get max async tasks limit from global_config for semaphore control
- graph_max_async = global_config.get('llm_model_max_async', 4) * 2
+ graph_max_async = global_config.get("llm_model_max_async", 4) * 2
semaphore = asyncio.Semaphore(graph_max_async)
# Counters for tracking progress
@@ -912,9 +688,11 @@ async def rebuild_knowledge_from_chunks(
async def _locked_rebuild_entity(entity_name, chunk_ids):
nonlocal rebuilt_entities_count, failed_entities_count
async with semaphore:
- workspace = global_config.get('workspace', '')
- namespace = f'{workspace}:GraphDB' if workspace else 'GraphDB'
- async with get_storage_keyed_lock([entity_name], namespace=namespace, enable_logging=False):
+ workspace = global_config.get("workspace", "")
+ namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
+ async with get_storage_keyed_lock(
+ [entity_name], namespace=namespace, enable_logging=False
+ ):
try:
await _rebuild_single_entity(
knowledge_graph_inst=knowledge_graph_inst,
@@ -929,18 +707,18 @@ async def rebuild_knowledge_from_chunks(
rebuilt_entities_count += 1
except Exception as e:
failed_entities_count += 1
- status_message = f'Failed to rebuild `{entity_name}`: {e}'
+ status_message = f"Failed to rebuild `{entity_name}`: {e}"
logger.info(status_message) # Per requirement, change to info
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
- pipeline_status['latest_message'] = status_message
- pipeline_status['history_messages'].append(status_message)
+ pipeline_status["latest_message"] = status_message
+ pipeline_status["history_messages"].append(status_message)
async def _locked_rebuild_relationship(src, tgt, chunk_ids):
nonlocal rebuilt_relationships_count, failed_relationships_count
async with semaphore:
- workspace = global_config.get('workspace', '')
- namespace = f'{workspace}:GraphDB' if workspace else 'GraphDB'
+ workspace = global_config.get("workspace", "")
+ namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
# Sort src and tgt to ensure order-independent lock key generation
sorted_key_parts = sorted([src, tgt])
async with get_storage_keyed_lock(
@@ -967,12 +745,12 @@ async def rebuild_knowledge_from_chunks(
rebuilt_relationships_count += 1
except Exception as e:
failed_relationships_count += 1
- status_message = f'Failed to rebuild `{src}`~`{tgt}`: {e}'
+ status_message = f"Failed to rebuild `{src}`~`{tgt}`: {e}"
logger.info(status_message) # Per requirement, change to info
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
- pipeline_status['latest_message'] = status_message
- pipeline_status['history_messages'].append(status_message)
+ pipeline_status["latest_message"] = status_message
+ pipeline_status["history_messages"].append(status_message)
# Create tasks for parallel processing
tasks = []
@@ -988,12 +766,12 @@ async def rebuild_knowledge_from_chunks(
tasks.append(task)
# Log parallel processing start
- status_message = f'Starting parallel rebuild of {len(entities_to_rebuild)} entities and {len(relationships_to_rebuild)} relationships (async: {graph_max_async})'
+ status_message = f"Starting parallel rebuild of {len(entities_to_rebuild)} entities and {len(relationships_to_rebuild)} relationships (async: {graph_max_async})"
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
- pipeline_status['latest_message'] = status_message
- pipeline_status['history_messages'].append(status_message)
+ pipeline_status["latest_message"] = status_message
+ pipeline_status["history_messages"].append(status_message)
# Execute all tasks in parallel with semaphore control and early failure detection
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
@@ -1028,15 +806,15 @@ async def rebuild_knowledge_from_chunks(
raise first_exception
# Final status report
- status_message = f'KG rebuild completed: {rebuilt_entities_count} entities and {rebuilt_relationships_count} relationships rebuilt successfully.'
+ status_message = f"KG rebuild completed: {rebuilt_entities_count} entities and {rebuilt_relationships_count} relationships rebuilt successfully."
if failed_entities_count > 0 or failed_relationships_count > 0:
- status_message += f' Failed: {failed_entities_count} entities, {failed_relationships_count} relationships.'
+ status_message += f" Failed: {failed_entities_count} entities, {failed_relationships_count} relationships."
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
- pipeline_status['latest_message'] = status_message
- pipeline_status['history_messages'].append(status_message)
+ pipeline_status["latest_message"] = status_message
+ pipeline_status["history_messages"].append(status_message)
async def _get_cached_extraction_results(
@@ -1070,14 +848,14 @@ async def _get_cached_extraction_results(
chunk_data_list = await text_chunks_storage.get_by_ids(list(chunk_ids))
for chunk_data in chunk_data_list:
if chunk_data and isinstance(chunk_data, dict):
- llm_cache_list = chunk_data.get('llm_cache_list', [])
+ llm_cache_list = chunk_data.get("llm_cache_list", [])
if llm_cache_list:
all_cache_ids.update(llm_cache_list)
else:
- logger.warning(f'Chunk data is invalid or None: {chunk_data}')
+ logger.warning(f"Chunk data is invalid or None: {chunk_data}")
if not all_cache_ids:
- logger.warning(f'No LLM cache IDs found for {len(chunk_ids)} chunk IDs')
+ logger.warning(f"No LLM cache IDs found for {len(chunk_ids)} chunk IDs")
return cached_results
# Batch get LLM cache entries
@@ -1089,12 +867,14 @@ async def _get_cached_extraction_results(
if (
cache_entry is not None
and isinstance(cache_entry, dict)
- and cache_entry.get('cache_type') == 'extract'
- and cache_entry.get('chunk_id') in chunk_ids
+ and cache_entry.get("cache_type") == "extract"
+ and cache_entry.get("chunk_id") in chunk_ids
):
- chunk_id = cache_entry['chunk_id']
- extraction_result = cache_entry['return']
- create_time = cache_entry.get('create_time', 0) # Get creation time, default to 0
+ chunk_id = cache_entry["chunk_id"]
+ extraction_result = cache_entry["return"]
+ create_time = cache_entry.get(
+ "create_time", 0
+ ) # Get creation time, default to 0
valid_entries += 1
# Support multiple LLM caches per chunk
@@ -1112,14 +892,18 @@ async def _get_cached_extraction_results(
chunk_earliest_times[chunk_id] = cached_results[chunk_id][0][1]
# Sort cached_results by the earliest create_time of each chunk
- sorted_chunk_ids = sorted(chunk_earliest_times.keys(), key=lambda chunk_id: chunk_earliest_times[chunk_id])
+ sorted_chunk_ids = sorted(
+ chunk_earliest_times.keys(), key=lambda chunk_id: chunk_earliest_times[chunk_id]
+ )
# Rebuild cached_results in sorted order
sorted_cached_results = {}
for chunk_id in sorted_chunk_ids:
sorted_cached_results[chunk_id] = cached_results[chunk_id]
- logger.info(f'Found {valid_entries} valid cache entries, {len(sorted_cached_results)} chunks with results')
+ logger.info(
+ f"Found {valid_entries} valid cache entries, {len(sorted_cached_results)} chunks with results"
+ )
return sorted_cached_results # each item: list(extraction_result, create_time)
@@ -1127,9 +911,9 @@ async def _process_extraction_result(
result: str,
chunk_key: str,
timestamp: int,
- file_path: str = 'unknown_source',
- tuple_delimiter: str = '<|#|>',
- completion_delimiter: str = '<|COMPLETE|>',
+ file_path: str = "unknown_source",
+ tuple_delimiter: str = "<|#|>",
+ completion_delimiter: str = "<|COMPLETE|>",
) -> tuple[dict, dict]:
"""Process a single extraction result (either initial or gleaning)
Args:
@@ -1146,95 +930,50 @@ async def _process_extraction_result(
maybe_edges = defaultdict(list)
if completion_delimiter not in result:
- logger.warning(f'{chunk_key}: Complete delimiter can not be found in extraction result')
+ logger.warning(
+ f"{chunk_key}: Complete delimiter can not be found in extraction result"
+ )
- # Split LLM output result to records by "\n" or other common separators
- # Some models use <|#|> between records instead of newlines
+ # Split LLM output result into records by "\n"
records = split_string_by_multi_markers(
result,
- ['\n', completion_delimiter, completion_delimiter.lower()],
+ ["\n", completion_delimiter, completion_delimiter.lower()],
)
- # Additional split: handle models that output records separated by tuple_delimiter + "entity" or "relation"
- # e.g., "entity<|#|>A<|#|>type<|#|>desc<|#|>entity<|#|>B<|#|>type<|#|>desc"
- expanded_records = []
- for record in records:
- record = record.strip()
- if not record:
- continue
- # Split by patterns that indicate a new entity/relation record starting
- # Pattern: <|#|>entity<|#|> or <|#|>relation<|#|>
- sub_records = split_string_by_multi_markers(
- record,
- [
- f'{tuple_delimiter}entity{tuple_delimiter}',
- f'{tuple_delimiter}relation{tuple_delimiter}',
- f'{tuple_delimiter}relationship{tuple_delimiter}',
- ],
- )
- for i, sub in enumerate(sub_records):
- sub = sub.strip()
- if not sub:
- continue
- # First sub-record: check if it already starts with entity/relation
- if i == 0:
- if sub.lower().startswith(('entity', 'relation')):
- expanded_records.append(sub)
- else:
- # Might be partial, try to recover by checking content
- # If it looks like entity fields (has enough delimiters), prefix with 'entity'
- if sub.count(tuple_delimiter) >= 2:
- expanded_records.append(f'entity{tuple_delimiter}{sub}')
- else:
- expanded_records.append(sub)
- else:
- # Subsequent sub-records lost their prefix during split, restore it
- # Determine if it's entity or relation based on field count
- # entity: name, type, desc (3 fields after split = 2 delimiters)
- # relation: source, target, keywords, desc (4 fields = 3 delimiters)
- delimiter_count = sub.count(tuple_delimiter)
- if delimiter_count >= 3:
- expanded_records.append(f'relation{tuple_delimiter}{sub}')
- else:
- expanded_records.append(f'entity{tuple_delimiter}{sub}')
-
- records = expanded_records if expanded_records else records
-
# Fix LLM output format error where the LLM uses tuple_delimiter to separate records instead of "\n"
fixed_records = []
for record in records:
record = record.strip()
- if not record:
+ if record is None:
continue
- # If record already starts with entity/relation, keep it as-is
- if record.lower().startswith(('entity', 'relation')):
- fixed_records.append(record)
- continue
- # Otherwise try to recover malformed records
- entity_records = split_string_by_multi_markers(record, [f'{tuple_delimiter}entity{tuple_delimiter}'])
+ entity_records = split_string_by_multi_markers(
+ record, [f"{tuple_delimiter}entity{tuple_delimiter}"]
+ )
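+ # Sub-records produced by this split lose their leading "entity"/"relation" tag; it is restored below before parsing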
for entity_record in entity_records:
- if not entity_record.startswith('entity') and not entity_record.startswith('relation'):
- entity_record = f'entity<|{entity_record}'
- entity_relation_records = split_string_by_multi_markers(
- # treat "relationship" and "relation" interchangeable
- entity_record,
- [
- f'{tuple_delimiter}relationship{tuple_delimiter}',
- f'{tuple_delimiter}relation{tuple_delimiter}',
- ],
- )
- for entity_relation_record in entity_relation_records:
- if not entity_relation_record.startswith('entity') and not entity_relation_record.startswith(
- 'relation'
- ):
- entity_relation_record = f'relation{tuple_delimiter}{entity_relation_record}'
- fixed_records.append(entity_relation_record)
- else:
- fixed_records.append(entity_record)
+ if not entity_record.startswith("entity") and not entity_record.startswith(
+ "relation"
+ ):
+ entity_record = f"entity<|{entity_record}"
+ entity_relation_records = split_string_by_multi_markers(
+ # treat "relationship" and "relation" interchangeable
+ entity_record,
+ [
+ f"{tuple_delimiter}relationship{tuple_delimiter}",
+ f"{tuple_delimiter}relation{tuple_delimiter}",
+ ],
+ )
+ for entity_relation_record in entity_relation_records:
+ if not entity_relation_record.startswith(
+ "entity"
+ ) and not entity_relation_record.startswith("relation"):
+ entity_relation_record = (
+ f"relation{tuple_delimiter}{entity_relation_record}"
+ )
+ fixed_records = fixed_records + [entity_relation_record]
if len(fixed_records) != len(records):
- logger.debug(
- f'{chunk_key}: Recovered {len(fixed_records)} records from {len(records)} raw records'
+ logger.warning(
+ f"{chunk_key}: LLM output format error; find LLM use {tuple_delimiter} as record seperators instead new-line"
)
for record in fixed_records:
@@ -1248,20 +987,24 @@ async def _process_extraction_result(
if delimiter_core != delimiter_core.lower():
# change delimiter_core to lower case, and fix again
delimiter_core = delimiter_core.lower()
- record = fix_tuple_delimiter_corruption(record, delimiter_core, tuple_delimiter)
+ record = fix_tuple_delimiter_corruption(
+ record, delimiter_core, tuple_delimiter
+ )
record_attributes = split_string_by_multi_markers(record, [tuple_delimiter])
# Try to parse as entity
- entity_data = await _handle_single_entity_extraction(record_attributes, chunk_key, timestamp, file_path)
+ entity_data = await _handle_single_entity_extraction(
+ record_attributes, chunk_key, timestamp, file_path
+ )
if entity_data is not None:
truncated_name = _truncate_entity_identifier(
- entity_data['entity_name'],
+ entity_data["entity_name"],
DEFAULT_ENTITY_NAME_MAX_LENGTH,
chunk_key,
- 'Entity name',
+ "Entity name",
)
- entity_data['entity_name'] = truncated_name
+ entity_data["entity_name"] = truncated_name
maybe_nodes[truncated_name].append(entity_data)
continue
@@ -1271,19 +1014,19 @@ async def _process_extraction_result(
)
if relationship_data is not None:
truncated_source = _truncate_entity_identifier(
- relationship_data['src_id'],
+ relationship_data["src_id"],
DEFAULT_ENTITY_NAME_MAX_LENGTH,
chunk_key,
- 'Relation entity',
+ "Relation entity",
)
truncated_target = _truncate_entity_identifier(
- relationship_data['tgt_id'],
+ relationship_data["tgt_id"],
DEFAULT_ENTITY_NAME_MAX_LENGTH,
chunk_key,
- 'Relation entity',
+ "Relation entity",
)
- relationship_data['src_id'] = truncated_source
- relationship_data['tgt_id'] = truncated_target
+ relationship_data["src_id"] = truncated_source
+ relationship_data["tgt_id"] = truncated_target
maybe_edges[(truncated_source, truncated_target)].append(relationship_data)
return dict(maybe_nodes), dict(maybe_edges)
@@ -1308,7 +1051,11 @@ async def _rebuild_from_extraction_result(
# Get chunk data for file_path from storage
chunk_data = await text_chunks_storage.get_by_id(chunk_id)
- file_path = chunk_data.get('file_path', 'unknown_source') if chunk_data else 'unknown_source'
+ file_path = (
+ chunk_data.get("file_path", "unknown_source")
+ if chunk_data
+ else "unknown_source"
+ )
# Call the shared processing function
return await _process_extraction_result(
@@ -1316,8 +1063,8 @@ async def _rebuild_from_extraction_result(
chunk_id,
timestamp,
file_path,
- tuple_delimiter=PROMPTS['DEFAULT_TUPLE_DELIMITER'],
- completion_delimiter=PROMPTS['DEFAULT_COMPLETION_DELIMITER'],
+ tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
+ completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
)
@@ -1328,7 +1075,7 @@ async def _rebuild_single_entity(
chunk_ids: list[str],
chunk_entities: dict,
llm_response_cache: BaseKVStorage,
- global_config: dict[str, Any],
+ global_config: dict[str, str],
entity_chunks_storage: BaseKVStorage | None = None,
pipeline_status: dict | None = None,
pipeline_status_lock=None,
@@ -1346,49 +1093,49 @@ async def _rebuild_single_entity(
entity_type: str,
file_paths: list[str],
source_chunk_ids: list[str],
- truncation_info: str = '',
+ truncation_info: str = "",
):
try:
# Update entity in graph storage (critical path)
updated_entity_data = {
**current_entity,
- 'description': final_description,
- 'entity_type': entity_type,
- 'source_id': GRAPH_FIELD_SEP.join(source_chunk_ids),
- 'file_path': GRAPH_FIELD_SEP.join(file_paths)
+ "description": final_description,
+ "entity_type": entity_type,
+ "source_id": GRAPH_FIELD_SEP.join(source_chunk_ids),
+ "file_path": GRAPH_FIELD_SEP.join(file_paths)
if file_paths
- else current_entity.get('file_path', 'unknown_source'),
- 'created_at': int(time.time()),
- 'truncate': truncation_info,
+ else current_entity.get("file_path", "unknown_source"),
+ "created_at": int(time.time()),
+ "truncate": truncation_info,
}
await knowledge_graph_inst.upsert_node(entity_name, updated_entity_data)
# Update entity in vector database (equally critical)
- entity_vdb_id = compute_mdhash_id(entity_name, prefix='ent-')
- entity_content = f'{entity_name}\n{final_description}'
+ entity_vdb_id = compute_mdhash_id(entity_name, prefix="ent-")
+ entity_content = f"{entity_name}\n{final_description}"
vdb_data = {
entity_vdb_id: {
- 'content': entity_content,
- 'entity_name': entity_name,
- 'source_id': updated_entity_data['source_id'],
- 'description': final_description,
- 'entity_type': entity_type,
- 'file_path': updated_entity_data['file_path'],
+ "content": entity_content,
+ "entity_name": entity_name,
+ "source_id": updated_entity_data["source_id"],
+ "description": final_description,
+ "entity_type": entity_type,
+ "file_path": updated_entity_data["file_path"],
}
}
# Use safe operation wrapper - VDB failure must throw exception
await safe_vdb_operation_with_exception(
operation=lambda: entities_vdb.upsert(vdb_data),
- operation_name='rebuild_entity_upsert',
+ operation_name="rebuild_entity_upsert",
entity_name=entity_name,
max_retries=3,
retry_delay=0.1,
)
except Exception as e:
- error_msg = f'Failed to update entity storage for `{entity_name}`: {e}'
+ error_msg = f"Failed to update entity storage for `{entity_name}`: {e}"
logger.error(error_msg)
raise # Re-raise exception
@@ -1399,19 +1146,21 @@ async def _rebuild_single_entity(
await entity_chunks_storage.upsert(
{
entity_name: {
- 'chunk_ids': normalized_chunk_ids,
- 'count': len(normalized_chunk_ids),
+ "chunk_ids": normalized_chunk_ids,
+ "count": len(normalized_chunk_ids),
}
}
)
- limit_method = global_config.get('source_ids_limit_method') or SOURCE_IDS_LIMIT_METHOD_KEEP
+ limit_method = (
+ global_config.get("source_ids_limit_method") or SOURCE_IDS_LIMIT_METHOD_KEEP
+ )
limited_chunk_ids = apply_source_ids_limit(
normalized_chunk_ids,
- global_config['max_source_ids_per_entity'],
+ global_config["max_source_ids_per_entity"],
limit_method,
- identifier=f'`{entity_name}`',
+ identifier=f"`{entity_name}`",
)
# Collect all entity data from relevant (limited) chunks
@@ -1421,12 +1170,14 @@ async def _rebuild_single_entity(
all_entity_data.extend(chunk_entities[chunk_id][entity_name])
if not all_entity_data:
- logger.warning(f'No entity data found for `{entity_name}`, trying to rebuild from relationships')
+ logger.warning(
+ f"No entity data found for `{entity_name}`, trying to rebuild from relationships"
+ )
# Get all edges connected to this entity
edges = await knowledge_graph_inst.get_node_edges(entity_name)
if not edges:
- logger.warning(f'No relations attached to entity `{entity_name}`')
+ logger.warning(f"No relations attached to entity `{entity_name}`")
return
# Collect relationship data to extract entity information
@@ -1437,11 +1188,11 @@ async def _rebuild_single_entity(
for src_id, tgt_id in edges:
edge_data = await knowledge_graph_inst.get_edge(src_id, tgt_id)
if edge_data:
- if edge_data.get('description'):
- relationship_descriptions.append(edge_data['description'])
+ if edge_data.get("description"):
+ relationship_descriptions.append(edge_data["description"])
- if edge_data.get('file_path'):
- edge_file_paths = edge_data['file_path'].split(GRAPH_FIELD_SEP)
+ if edge_data.get("file_path"):
+ edge_file_paths = edge_data["file_path"].split(GRAPH_FIELD_SEP)
file_paths.update(edge_file_paths)
# deduplicate descriptions
@@ -1450,7 +1201,7 @@ async def _rebuild_single_entity(
# Generate final description from relationships or fallback to current
if description_list:
final_description, _ = await _handle_entity_relation_summary(
- 'Entity',
+ "Entity",
entity_name,
description_list,
GRAPH_FIELD_SEP,
@@ -1458,13 +1209,13 @@ async def _rebuild_single_entity(
llm_response_cache=llm_response_cache,
)
else:
- final_description = current_entity.get('description', '')
+ final_description = current_entity.get("description", "")
- entity_type = current_entity.get('entity_type', 'UNKNOWN')
+ entity_type = current_entity.get("entity_type", "UNKNOWN")
await _update_entity_storage(
final_description,
entity_type,
- list(file_paths),
+ file_paths,
limited_chunk_ids,
)
return
@@ -1476,20 +1227,22 @@ async def _rebuild_single_entity(
seen_paths = set()
for entity_data in all_entity_data:
- if entity_data.get('description'):
- descriptions.append(entity_data['description'])
- if entity_data.get('entity_type'):
- entity_types.append(entity_data['entity_type'])
- if entity_data.get('file_path'):
- file_path = entity_data['file_path']
+ if entity_data.get("description"):
+ descriptions.append(entity_data["description"])
+ if entity_data.get("entity_type"):
+ entity_types.append(entity_data["entity_type"])
+ if entity_data.get("file_path"):
+ file_path = entity_data["file_path"]
if file_path and file_path not in seen_paths:
file_paths_list.append(file_path)
seen_paths.add(file_path)
# Apply MAX_FILE_PATHS limit
- max_file_paths = int(global_config.get('max_file_paths', DEFAULT_MAX_FILE_PATHS) or DEFAULT_MAX_FILE_PATHS)
- file_path_placeholder = global_config.get('file_path_more_placeholder', DEFAULT_FILE_PATH_MORE_PLACEHOLDER)
- limit_method = global_config.get('source_ids_limit_method', DEFAULT_SOURCE_IDS_LIMIT_METHOD)
+ max_file_paths = global_config.get("max_file_paths")
+ file_path_placeholder = global_config.get(
+ "file_path_more_placeholder", DEFAULT_FILE_PATH_MORE_PLACEHOLDER
+ )
+ limit_method = global_config.get("source_ids_limit_method")
original_count = len(file_paths_list)
if original_count > max_file_paths:
@@ -1500,8 +1253,12 @@ async def _rebuild_single_entity(
# KEEP: keep head (earliest), discard tail
file_paths_list = file_paths_list[:max_file_paths]
- file_paths_list.append(f'...{file_path_placeholder}...({limit_method} {max_file_paths}/{original_count})')
- logger.info(f'Limited `{entity_name}`: file_path {original_count} -> {max_file_paths} ({limit_method})')
+ file_paths_list.append(
+ f"...{file_path_placeholder}...({limit_method} {max_file_paths}/{original_count})"
+ )
+ logger.info(
+ f"Limited `{entity_name}`: file_path {original_count} -> {max_file_paths} ({limit_method})"
+ )
# Remove duplicates while preserving order
description_list = list(dict.fromkeys(descriptions))
@@ -1509,13 +1266,15 @@ async def _rebuild_single_entity(
# Get most common entity type
entity_type = (
- max(set(entity_types), key=entity_types.count) if entity_types else current_entity.get('entity_type', 'UNKNOWN')
+ max(set(entity_types), key=entity_types.count)
+ if entity_types
+ else current_entity.get("entity_type", "UNKNOWN")
)
# Generate final description from entities or fallback to current
if description_list:
final_description, _ = await _handle_entity_relation_summary(
- 'Entity',
+ "Entity",
entity_name,
description_list,
GRAPH_FIELD_SEP,
@@ -1523,12 +1282,14 @@ async def _rebuild_single_entity(
llm_response_cache=llm_response_cache,
)
else:
- final_description = current_entity.get('description', '')
+ final_description = current_entity.get("description", "")
if len(limited_chunk_ids) < len(normalized_chunk_ids):
- truncation_info = f'{limit_method} {len(limited_chunk_ids)}/{len(normalized_chunk_ids)}'
+ truncation_info = (
+ f"{limit_method} {len(limited_chunk_ids)}/{len(normalized_chunk_ids)}"
+ )
else:
- truncation_info = ''
+ truncation_info = ""
await _update_entity_storage(
final_description,
@@ -1539,15 +1300,15 @@ async def _rebuild_single_entity(
)
# Log rebuild completion with truncation info
- status_message = f'Rebuild `{entity_name}` from {len(chunk_ids)} chunks'
+ status_message = f"Rebuild `{entity_name}` from {len(chunk_ids)} chunks"
if truncation_info:
- status_message += f' ({truncation_info})'
+ status_message += f" ({truncation_info})"
logger.info(status_message)
# Update pipeline status
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
- pipeline_status['latest_message'] = status_message
- pipeline_status['history_messages'].append(status_message)
+ pipeline_status["latest_message"] = status_message
+ pipeline_status["history_messages"].append(status_message)
async def _rebuild_single_relationship(
@@ -1559,7 +1320,7 @@ async def _rebuild_single_relationship(
chunk_ids: list[str],
chunk_relationships: dict,
llm_response_cache: BaseKVStorage,
- global_config: dict[str, Any],
+ global_config: dict[str, str],
relation_chunks_storage: BaseKVStorage | None = None,
entity_chunks_storage: BaseKVStorage | None = None,
pipeline_status: dict | None = None,
@@ -1584,18 +1345,20 @@ async def _rebuild_single_relationship(
await relation_chunks_storage.upsert(
{
storage_key: {
- 'chunk_ids': normalized_chunk_ids,
- 'count': len(normalized_chunk_ids),
+ "chunk_ids": normalized_chunk_ids,
+ "count": len(normalized_chunk_ids),
}
}
)
- limit_method = global_config.get('source_ids_limit_method') or SOURCE_IDS_LIMIT_METHOD_KEEP
+ limit_method = (
+ global_config.get("source_ids_limit_method") or SOURCE_IDS_LIMIT_METHOD_KEEP
+ )
limited_chunk_ids = apply_source_ids_limit(
normalized_chunk_ids,
- global_config['max_source_ids_per_relation'],
+ global_config["max_source_ids_per_relation"],
limit_method,
- identifier=f'`{src}`~`{tgt}`',
+ identifier=f"`{src}`~`{tgt}`",
)
# Collect all relationship data from relevant chunks
@@ -1605,10 +1368,12 @@ async def _rebuild_single_relationship(
# Check both (src, tgt) and (tgt, src) since relationships can be bidirectional
for edge_key in [(src, tgt), (tgt, src)]:
if edge_key in chunk_relationships[chunk_id]:
- all_relationship_data.extend(chunk_relationships[chunk_id][edge_key])
+ all_relationship_data.extend(
+ chunk_relationships[chunk_id][edge_key]
+ )
if not all_relationship_data:
- logger.warning(f'No relation data found for `{src}-{tgt}`')
+ logger.warning(f"No relation data found for `{src}-{tgt}`")
return
# Merge descriptions and keywords
@@ -1619,22 +1384,24 @@ async def _rebuild_single_relationship(
seen_paths = set()
for rel_data in all_relationship_data:
- if rel_data.get('description'):
- descriptions.append(rel_data['description'])
- if rel_data.get('keywords'):
- keywords.append(rel_data['keywords'])
- if rel_data.get('weight'):
- weights.append(rel_data['weight'])
- if rel_data.get('file_path'):
- file_path = rel_data['file_path']
+ if rel_data.get("description"):
+ descriptions.append(rel_data["description"])
+ if rel_data.get("keywords"):
+ keywords.append(rel_data["keywords"])
+ if rel_data.get("weight"):
+ weights.append(rel_data["weight"])
+ if rel_data.get("file_path"):
+ file_path = rel_data["file_path"]
if file_path and file_path not in seen_paths:
file_paths_list.append(file_path)
seen_paths.add(file_path)
# Apply count limit
- max_file_paths = int(global_config.get('max_file_paths', DEFAULT_MAX_FILE_PATHS) or DEFAULT_MAX_FILE_PATHS)
- file_path_placeholder = global_config.get('file_path_more_placeholder', DEFAULT_FILE_PATH_MORE_PLACEHOLDER)
- limit_method = global_config.get('source_ids_limit_method', DEFAULT_SOURCE_IDS_LIMIT_METHOD)
+ max_file_paths = global_config.get("max_file_paths")
+ file_path_placeholder = global_config.get(
+ "file_path_more_placeholder", DEFAULT_FILE_PATH_MORE_PLACEHOLDER
+ )
+ limit_method = global_config.get("source_ids_limit_method")
original_count = len(file_paths_list)
if original_count > max_file_paths:
@@ -1645,22 +1412,30 @@ async def _rebuild_single_relationship(
# KEEP: keep head (earliest), discard tail
file_paths_list = file_paths_list[:max_file_paths]
- file_paths_list.append(f'...{file_path_placeholder}...({limit_method} {max_file_paths}/{original_count})')
- logger.info(f'Limited `{src}`~`{tgt}`: file_path {original_count} -> {max_file_paths} ({limit_method})')
+ file_paths_list.append(
+ f"...{file_path_placeholder}...({limit_method} {max_file_paths}/{original_count})"
+ )
+ logger.info(
+ f"Limited `{src}`~`{tgt}`: file_path {original_count} -> {max_file_paths} ({limit_method})"
+ )
# Remove duplicates while preserving order
description_list = list(dict.fromkeys(descriptions))
keywords = list(dict.fromkeys(keywords))
- combined_keywords = ', '.join(set(keywords)) if keywords else current_relationship.get('keywords', '')
+ combined_keywords = (
+ ", ".join(set(keywords))
+ if keywords
+ else current_relationship.get("keywords", "")
+ )
- weight = sum(weights) if weights else current_relationship.get('weight', 1.0)
+ weight = sum(weights) if weights else current_relationship.get("weight", 1.0)
# Generate final description from relations or fallback to current
if description_list:
final_description, _ = await _handle_entity_relation_summary(
- 'Relation',
- f'{src}-{tgt}',
+ "Relation",
+ f"{src}-{tgt}",
description_list,
GRAPH_FIELD_SEP,
global_config,
@@ -1668,47 +1443,51 @@ async def _rebuild_single_relationship(
)
else:
# fallback to keep current(unchanged)
- final_description = current_relationship.get('description', '')
+ final_description = current_relationship.get("description", "")
if len(limited_chunk_ids) < len(normalized_chunk_ids):
- truncation_info = f'{limit_method} {len(limited_chunk_ids)}/{len(normalized_chunk_ids)}'
+ truncation_info = (
+ f"{limit_method} {len(limited_chunk_ids)}/{len(normalized_chunk_ids)}"
+ )
else:
- truncation_info = ''
+ truncation_info = ""
# Update relationship in graph storage
updated_relationship_data = {
**current_relationship,
- 'description': final_description if final_description else current_relationship.get('description', ''),
- 'keywords': combined_keywords,
- 'weight': weight,
- 'source_id': GRAPH_FIELD_SEP.join(limited_chunk_ids),
- 'file_path': GRAPH_FIELD_SEP.join([fp for fp in file_paths_list if fp])
+ "description": final_description
+ if final_description
+ else current_relationship.get("description", ""),
+ "keywords": combined_keywords,
+ "weight": weight,
+ "source_id": GRAPH_FIELD_SEP.join(limited_chunk_ids),
+ "file_path": GRAPH_FIELD_SEP.join([fp for fp in file_paths_list if fp])
if file_paths_list
- else current_relationship.get('file_path', 'unknown_source'),
- 'truncate': truncation_info,
+ else current_relationship.get("file_path", "unknown_source"),
+ "truncate": truncation_info,
}
# Ensure both endpoint nodes exist before writing the edge back
# (certain storage backends require pre-existing nodes).
node_description = (
- updated_relationship_data['description']
- if updated_relationship_data.get('description')
- else current_relationship.get('description', '')
+ updated_relationship_data["description"]
+ if updated_relationship_data.get("description")
+ else current_relationship.get("description", "")
)
- node_source_id = updated_relationship_data.get('source_id', '')
- node_file_path = updated_relationship_data.get('file_path', 'unknown_source')
+ node_source_id = updated_relationship_data.get("source_id", "")
+ node_file_path = updated_relationship_data.get("file_path", "unknown_source")
for node_id in {src, tgt}:
if not (await knowledge_graph_inst.has_node(node_id)):
node_created_at = int(time.time())
node_data = {
- 'entity_id': node_id,
- 'source_id': node_source_id,
- 'description': node_description,
- 'entity_type': 'UNKNOWN',
- 'file_path': node_file_path,
- 'created_at': node_created_at,
- 'truncate': '',
+ "entity_id": node_id,
+ "source_id": node_source_id,
+ "description": node_description,
+ "entity_type": "UNKNOWN",
+ "file_path": node_file_path,
+ "created_at": node_created_at,
+ "truncate": "",
}
await knowledge_graph_inst.upsert_node(node_id, node_data=node_data)
@@ -1717,28 +1496,28 @@ async def _rebuild_single_relationship(
await entity_chunks_storage.upsert(
{
node_id: {
- 'chunk_ids': limited_chunk_ids,
- 'count': len(limited_chunk_ids),
+ "chunk_ids": limited_chunk_ids,
+ "count": len(limited_chunk_ids),
}
}
)
# Update entity_vdb for the newly created entity
if entities_vdb is not None:
- entity_vdb_id = compute_mdhash_id(node_id, prefix='ent-')
- entity_content = f'{node_id}\n{node_description}'
+ entity_vdb_id = compute_mdhash_id(node_id, prefix="ent-")
+ entity_content = f"{node_id}\n{node_description}"
vdb_data = {
entity_vdb_id: {
- 'content': entity_content,
- 'entity_name': node_id,
- 'source_id': node_source_id,
- 'entity_type': 'UNKNOWN',
- 'file_path': node_file_path,
+ "content": entity_content,
+ "entity_name": node_id,
+ "source_id": node_source_id,
+ "entity_type": "UNKNOWN",
+ "file_path": node_file_path,
}
}
await safe_vdb_operation_with_exception(
operation=lambda payload=vdb_data: entities_vdb.upsert(payload),
- operation_name='rebuild_added_entity_upsert',
+ operation_name="rebuild_added_entity_upsert",
entity_name=node_id,
max_retries=3,
retry_delay=0.1,
@@ -1751,211 +1530,64 @@ async def _rebuild_single_relationship(
if src > tgt:
src, tgt = tgt, src
try:
- rel_vdb_id = compute_mdhash_id(src + tgt, prefix='rel-')
- rel_vdb_id_reverse = compute_mdhash_id(tgt + src, prefix='rel-')
+ rel_vdb_id = compute_mdhash_id(src + tgt, prefix="rel-")
+ rel_vdb_id_reverse = compute_mdhash_id(tgt + src, prefix="rel-")
# Delete old vector records first (both directions to be safe)
try:
await relationships_vdb.delete([rel_vdb_id, rel_vdb_id_reverse])
except Exception as e:
- logger.debug(f'Could not delete old relationship vector records {rel_vdb_id}, {rel_vdb_id_reverse}: {e}')
+ logger.debug(
+ f"Could not delete old relationship vector records {rel_vdb_id}, {rel_vdb_id_reverse}: {e}"
+ )
# Insert new vector record
- rel_content = f'{combined_keywords}\t{src}\n{tgt}\n{final_description}'
+ rel_content = f"{combined_keywords}\t{src}\n{tgt}\n{final_description}"
vdb_data = {
rel_vdb_id: {
- 'src_id': src,
- 'tgt_id': tgt,
- 'source_id': updated_relationship_data['source_id'],
- 'content': rel_content,
- 'keywords': combined_keywords,
- 'description': final_description,
- 'weight': weight,
- 'file_path': updated_relationship_data['file_path'],
+ "src_id": src,
+ "tgt_id": tgt,
+ "source_id": updated_relationship_data["source_id"],
+ "content": rel_content,
+ "keywords": combined_keywords,
+ "description": final_description,
+ "weight": weight,
+ "file_path": updated_relationship_data["file_path"],
}
}
# Use safe operation wrapper - VDB failure must throw exception
await safe_vdb_operation_with_exception(
operation=lambda: relationships_vdb.upsert(vdb_data),
- operation_name='rebuild_relationship_upsert',
- entity_name=f'{src}-{tgt}',
+ operation_name="rebuild_relationship_upsert",
+ entity_name=f"{src}-{tgt}",
max_retries=3,
retry_delay=0.2,
)
except Exception as e:
- error_msg = f'Failed to rebuild relationship storage for `{src}-{tgt}`: {e}'
+ error_msg = f"Failed to rebuild relationship storage for `{src}-{tgt}`: {e}"
logger.error(error_msg)
raise # Re-raise exception
# Log rebuild completion with truncation info
- status_message = f'Rebuild `{src}`~`{tgt}` from {len(chunk_ids)} chunks'
+ status_message = f"Rebuild `{src}`~`{tgt}` from {len(chunk_ids)} chunks"
if truncation_info:
- status_message += f' ({truncation_info})'
- elif len(limited_chunk_ids) < len(normalized_chunk_ids):
- status_message += f' ({limit_method}:{len(limited_chunk_ids)}/{len(normalized_chunk_ids)})'
+ status_message += f" ({truncation_info})"
+ # Add truncation info from apply_source_ids_limit if truncation occurred
+ if len(limited_chunk_ids) < len(normalized_chunk_ids):
+ truncation_info = (
+ f" ({limit_method}:{len(limited_chunk_ids)}/{len(normalized_chunk_ids)})"
+ )
+ status_message += truncation_info
logger.info(status_message)
# Update pipeline status
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
- pipeline_status['latest_message'] = status_message
- pipeline_status['history_messages'].append(status_message)
-
-
-def _has_different_numeric_suffix(name_a: str, name_b: str) -> bool:
- """Check if two names have different numeric components.
-
- This prevents false fuzzy matches between entities that differ only by number,
- such as "Interleukin-4" vs "Interleukin-13" (88.9% similar but semantically distinct).
-
- Scientific/medical entities often use numbers as key identifiers:
- - Interleukins: IL-4, IL-13, IL-17
- - Drug phases: Phase 1, Phase 2, Phase 3
- - Receptor types: Type 1, Type 2
- - Versions: v1.0, v2.0
-
- Args:
- name_a: First entity name
- name_b: Second entity name
-
- Returns:
- True if both names contain numbers but the numbers differ, False otherwise.
- """
- # Extract all numeric patterns (integers and decimals)
- pattern = r'(\d+(?:\.\d+)?)'
- nums_a = re.findall(pattern, name_a)
- nums_b = re.findall(pattern, name_b)
-
- # If both have numbers and they differ, these are likely distinct entities
- return bool(nums_a and nums_b and nums_a != nums_b)
-
-
-async def _build_pre_resolution_map(
- entity_names: list[str],
- entity_types: dict[str, str],
- entity_vdb,
- llm_fn,
- config: EntityResolutionConfig,
-) -> tuple[dict[str, str], dict[str, float]]:
- """Build resolution map before parallel processing to prevent race conditions.
-
- This function resolves entities against each other within the batch (using
- instant fuzzy matching) and against existing VDB entries. The resulting map
- is applied during parallel entity processing.
-
- Args:
- entity_names: List of entity names to resolve
- entity_types: Dict mapping entity names to their types (e.g., "person", "organization").
- Used to prevent fuzzy matching between entities of different types.
- entity_vdb: Entity vector database for checking existing entities
- llm_fn: LLM function for semantic verification
- config: Entity resolution configuration
-
- Returns:
- Tuple of:
- - resolution_map: Dict mapping original entity names to their resolved canonical names.
- Only entities that need remapping are included.
- - confidence_map: Dict mapping alias to confidence score (1.0 for exact, actual
- similarity for fuzzy, result.confidence for VDB matches).
- """
- resolution_map: dict[str, str] = {}
- confidence_map: dict[str, float] = {}
- # Track canonical entities with their types: [(name, type), ...]
- canonical_entities: list[tuple[str, str]] = []
-
- for entity_name in entity_names:
- normalized = entity_name.lower().strip()
- entity_type = entity_types.get(entity_name, '')
-
- # Skip if already resolved to something in this batch
- if entity_name in resolution_map:
- continue
-
- # Layer 1: Case-insensitive exact match within batch
- matched = False
- for canonical, _canonical_type in canonical_entities:
- if canonical.lower().strip() == normalized:
- resolution_map[entity_name] = canonical
- confidence_map[entity_name] = 1.0 # Exact match = perfect confidence
- logger.debug(f"Pre-resolution (case match): '{entity_name}' → '{canonical}'")
- matched = True
- break
-
- if matched:
- continue
-
- # Layer 2: Fuzzy match within batch (catches typos like Dupixant→Dupixent)
- # Only enabled when config.fuzzy_pre_resolution_enabled is True.
- # Requires: similarity >= threshold AND matching types (or unknown).
- if config.fuzzy_pre_resolution_enabled:
- for canonical, canonical_type in canonical_entities:
- similarity = fuzzy_similarity(entity_name, canonical)
- if similarity >= config.fuzzy_threshold:
- # Type compatibility check: skip if types differ and both known.
- # Empty/unknown types are treated as compatible to avoid
- # blocking legitimate matches when type info is incomplete.
- types_compatible = not entity_type or not canonical_type or entity_type == canonical_type
- if not types_compatible:
- logger.debug(
- f'Pre-resolution (fuzzy {similarity:.2f}): SKIPPED '
- f"'{entity_name}' ({entity_type}) → "
- f"'{canonical}' ({canonical_type}) - type mismatch"
- )
- continue
-
- # Numeric suffix check: skip if names have different numbers
- # This prevents false matches like "Interleukin-4" → "Interleukin-13"
- # where fuzzy similarity is high (88.9%) but entities are distinct
- if _has_different_numeric_suffix(entity_name, canonical):
- logger.debug(
- f'Pre-resolution (fuzzy {similarity:.2f}): SKIPPED '
- f"'{entity_name}' → '{canonical}' - different numeric suffix"
- )
- continue
-
- # Accept the fuzzy match - emit warning for review
- resolution_map[entity_name] = canonical
- confidence_map[entity_name] = similarity # Use actual similarity score
- etype_display = entity_type or 'unknown'
- ctype_display = canonical_type or 'unknown'
- logger.warning(
- f"Fuzzy pre-resolution accepted: '{entity_name}' → "
- f"'{canonical}' (similarity={similarity:.3f}, "
- f'types: {etype_display}→{ctype_display}). '
- f'Review for correctness; adjust fuzzy_threshold or '
- f'disable fuzzy_pre_resolution_enabled if needed.'
- )
- matched = True
- break
-
- if matched:
- continue
-
- # Layer 3: Check existing VDB for cross-document deduplication
- if entity_vdb and llm_fn:
- try:
- result = await resolve_entity_with_vdb(entity_name, entity_vdb, llm_fn, config)
- if result.action == 'match' and result.matched_entity:
- resolution_map[entity_name] = result.matched_entity
- confidence_map[entity_name] = result.confidence # Use VDB result confidence
- # Add canonical from VDB so batch entities can match it.
- # VDB matches don't have type info available, use empty.
- canonical_entities.append((result.matched_entity, ''))
- logger.debug(f"Pre-resolution (VDB {result.method}): '{entity_name}' → '{result.matched_entity}'")
- continue
- except Exception as e:
- logger.debug(f"Pre-resolution VDB check failed for '{entity_name}': {e}")
-
- # No match found - this is a new canonical entity
- canonical_entities.append((entity_name, entity_type))
-
- if resolution_map:
- logger.info(f'Pre-resolution: {len(resolution_map)} entities mapped to canonical forms')
-
- return resolution_map, confidence_map
+ pipeline_status["latest_message"] = status_message
+ pipeline_status["history_messages"].append(status_message)
async def _merge_nodes_then_upsert(
@@ -1964,152 +1596,39 @@ async def _merge_nodes_then_upsert(
knowledge_graph_inst: BaseGraphStorage,
entity_vdb: BaseVectorStorage | None,
global_config: dict,
- pipeline_status: dict | None = None,
+ pipeline_status: dict = None,
pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None,
entity_chunks_storage: BaseKVStorage | None = None,
- pre_resolution_map: dict[str, str] | None = None,
- prefetched_nodes: dict[str, dict] | None = None,
-) -> tuple[dict, str | None]:
- """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert.
-
- Args:
- prefetched_nodes: Optional dict mapping entity names to their existing node data.
- If provided, avoids individual get_node() calls for better performance.
-
- Returns:
- Tuple of (node_data, original_entity_name). original_entity_name is set if
- entity resolution changed the name (e.g., "Dupixant" → "Dupixent"),
- otherwise None.
- """
- original_entity_name = entity_name # Track original before resolution
-
- # Apply pre-resolution map immediately (prevents race conditions in parallel processing)
- pre_resolved = False
- if pre_resolution_map and entity_name in pre_resolution_map:
- entity_name = pre_resolution_map[entity_name]
- pre_resolved = True
- logger.debug(f"Applied pre-resolution: '{original_entity_name}' → '{entity_name}'")
-
- # Entity Resolution: Resolve new entity against existing entities
- # Skip if already pre-resolved (to avoid redundant VDB queries)
- entity_resolution_config_raw = global_config.get('entity_resolution_config')
- entity_resolution_config = None
- if entity_resolution_config_raw:
- # Handle both dict (from asdict() serialization) and EntityResolutionConfig instances
- if isinstance(entity_resolution_config_raw, EntityResolutionConfig):
- entity_resolution_config = entity_resolution_config_raw
- elif isinstance(entity_resolution_config_raw, dict):
- try:
- entity_resolution_config = EntityResolutionConfig(**entity_resolution_config_raw)
- except TypeError as e:
- logger.warning(
- f'Invalid entity_resolution_config: {e}. '
- f'Config: {entity_resolution_config_raw}. Skipping resolution.'
- )
-
- # Safely check if entity resolution is enabled, handling both object and dict forms
- def _is_resolution_enabled(config) -> bool:
- if config is None:
- return False
- if isinstance(config, dict):
- return config.get('enabled', False)
- return getattr(config, 'enabled', False)
-
- # Skip VDB resolution if entity was already pre-resolved (prevents redundant queries)
- if _is_resolution_enabled(entity_resolution_config) and entity_vdb and not pre_resolved:
- resolution_config = cast(EntityResolutionConfig, entity_resolution_config)
- original_name = entity_name
- workspace = global_config.get('workspace', '')
- # Try knowledge_graph_inst.db first (more reliable), fallback to entity_vdb.db
- db = getattr(knowledge_graph_inst, 'db', None) or getattr(entity_vdb, 'db', None)
-
- # Layer 0: Check alias cache first (PostgreSQL-only - requires db connection)
- # Note: Alias caching is only available when using PostgreSQL storage backend
- if db is not None:
- try:
- cached = await get_cached_alias(original_name, db, workspace)
- if cached:
- canonical, method, _ = cached
- logger.debug(f"Alias cache hit: '{original_name}' → '{canonical}' (method: {method})")
- entity_name = canonical
- except Exception as e:
- logger.warning(
- f"Entity resolution cache lookup failed for '{original_name}' "
- f'(workspace: {workspace}): {type(e).__name__}: {e}. '
- 'Continuing without cache.'
- )
-
- # Layers 1-3: Full VDB resolution (if not found in cache)
- if entity_name == original_name:
- llm_fn = global_config.get('llm_model_func')
- if llm_fn:
- try:
- resolution = await resolve_entity_with_vdb(
- entity_name,
- entity_vdb,
- llm_fn,
- resolution_config,
- )
- if resolution.action == 'match' and resolution.matched_entity:
- logger.info(
- f"Entity resolution: '{entity_name}' → '{resolution.matched_entity}' "
- f'(method: {resolution.method}, confidence: {resolution.confidence:.2f})'
- )
- entity_name = resolution.matched_entity
-
- # Store in alias cache for next time (PostgreSQL-only)
- # Note: Alias caching requires PostgreSQL storage backend
- if db is not None:
- try:
- await store_alias(
- original_name,
- entity_name,
- resolution.method,
- resolution.confidence,
- db,
- workspace,
- )
- except Exception as e:
- logger.warning(
- f"Failed to store entity alias '{original_name}' → '{entity_name}' "
- f'(workspace: {workspace}): {type(e).__name__}: {e}. '
- 'Resolution succeeded but cache not updated.'
- )
- except Exception as e:
- logger.warning(
- f"Entity resolution failed for '{original_name}' "
- f'(workspace: {workspace}): {type(e).__name__}: {e}. '
- 'Continuing with original entity name.'
- )
-
+):
+ """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert."""
already_entity_types = []
already_source_ids = []
already_description = []
already_file_paths = []
- # 1. Get existing node data from knowledge graph (use prefetched if available)
- if prefetched_nodes is not None and entity_name in prefetched_nodes:
- already_node = prefetched_nodes[entity_name]
- else:
- # Fallback to individual fetch if not prefetched (e.g., after VDB resolution)
- already_node = await knowledge_graph_inst.get_node(entity_name)
+ # 1. Get existing node data from knowledge graph
+ already_node = await knowledge_graph_inst.get_node(entity_name)
if already_node:
- already_entity_types.append(already_node['entity_type'])
- already_source_ids.extend(already_node['source_id'].split(GRAPH_FIELD_SEP))
- already_file_paths.extend(already_node['file_path'].split(GRAPH_FIELD_SEP))
- already_description.extend(already_node['description'].split(GRAPH_FIELD_SEP))
+ already_entity_types.append(already_node["entity_type"])
+ already_source_ids.extend(already_node["source_id"].split(GRAPH_FIELD_SEP))
+ already_file_paths.extend(already_node["file_path"].split(GRAPH_FIELD_SEP))
+ already_description.extend(already_node["description"].split(GRAPH_FIELD_SEP))
- new_source_ids = [dp['source_id'] for dp in nodes_data if dp.get('source_id')]
+ new_source_ids = [dp["source_id"] for dp in nodes_data if dp.get("source_id")]
existing_full_source_ids = []
if entity_chunks_storage is not None:
stored_chunks = await entity_chunks_storage.get_by_id(entity_name)
if stored_chunks and isinstance(stored_chunks, dict):
- existing_full_source_ids = [chunk_id for chunk_id in stored_chunks.get('chunk_ids', []) if chunk_id]
+ existing_full_source_ids = [
+ chunk_id for chunk_id in stored_chunks.get("chunk_ids", []) if chunk_id
+ ]
if not existing_full_source_ids:
- existing_full_source_ids = [chunk_id for chunk_id in already_source_ids if chunk_id]
+ existing_full_source_ids = [
+ chunk_id for chunk_id in already_source_ids if chunk_id
+ ]
# 2. Merging new source ids with existing ones
full_source_ids = merge_source_ids(existing_full_source_ids, new_source_ids)
@@ -2118,23 +1637,20 @@ async def _merge_nodes_then_upsert(
await entity_chunks_storage.upsert(
{
entity_name: {
- 'chunk_ids': full_source_ids,
- 'count': len(full_source_ids),
+ "chunk_ids": full_source_ids,
+ "count": len(full_source_ids),
}
}
)
# 3. Finalize source_id by applying source ids limit
- limit_method = global_config.get('source_ids_limit_method', DEFAULT_SOURCE_IDS_LIMIT_METHOD)
- max_source_limit = int(
- global_config.get('max_source_ids_per_entity', DEFAULT_MAX_SOURCE_IDS_PER_ENTITY)
- or DEFAULT_MAX_SOURCE_IDS_PER_ENTITY
- )
+ limit_method = global_config.get("source_ids_limit_method")
+ max_source_limit = global_config.get("max_source_ids_per_entity")
source_ids = apply_source_ids_limit(
full_source_ids,
max_source_limit,
limit_method,
- identifier=f'`{entity_name}`',
+ identifier=f"`{entity_name}`",
)
# 4. Only keep nodes not filter by apply_source_ids_limit if limit_method is KEEP
@@ -2142,9 +1658,13 @@ async def _merge_nodes_then_upsert(
allowed_source_ids = set(source_ids)
filtered_nodes = []
for dp in nodes_data:
- source_id = dp.get('source_id')
+ source_id = dp.get("source_id")
# Skip descriptions sourced from chunks dropped by the limitation cap
- if source_id and source_id not in allowed_source_ids and source_id not in existing_full_source_ids:
+ if (
+ source_id
+ and source_id not in allowed_source_ids
+ and source_id not in existing_full_source_ids
+ ):
continue
filtered_nodes.append(dp)
nodes_data = filtered_nodes
@@ -2158,19 +1678,25 @@ async def _merge_nodes_then_upsert(
and not nodes_data
):
if already_node:
- logger.info(f'Skipped `{entity_name}`: KEEP old chunks {already_source_ids}/{len(full_source_ids)}')
+ logger.info(
+ f"Skipped `{entity_name}`: KEEP old chunks {already_source_ids}/{len(full_source_ids)}"
+ )
existing_node_data = dict(already_node)
- return existing_node_data, None
+ return existing_node_data
else:
- logger.error(f'Internal Error: already_node missing for `{entity_name}`')
- raise ValueError(f'Internal Error: already_node missing for `{entity_name}`')
+ logger.error(f"Internal Error: already_node missing for `{entity_name}`")
+ raise ValueError(
+ f"Internal Error: already_node missing for `{entity_name}`"
+ )
# 6.1 Finalize source_id
source_id = GRAPH_FIELD_SEP.join(source_ids)
# 6.2 Finalize entity type by highest count
entity_type = sorted(
- Counter([dp['entity_type'] for dp in nodes_data] + already_entity_types).items(),
+ Counter(
+ [dp["entity_type"] for dp in nodes_data] + already_entity_types
+ ).items(),
key=lambda x: x[1],
reverse=True,
)[0][0]
@@ -2178,7 +1704,7 @@ async def _merge_nodes_then_upsert(
# 7. Deduplicate nodes by description, keeping first occurrence in the same document
unique_nodes = {}
for dp in nodes_data:
- desc = dp.get('description')
+ desc = dp.get("description")
if not desc:
continue
if desc not in unique_nodes:
@@ -2187,25 +1713,25 @@ async def _merge_nodes_then_upsert(
# Sort description by timestamp, then by description length when timestamps are the same
sorted_nodes = sorted(
unique_nodes.values(),
- key=lambda x: (x.get('timestamp', 0), -len(x.get('description', ''))),
+ key=lambda x: (x.get("timestamp", 0), -len(x.get("description", ""))),
)
- sorted_descriptions = [dp['description'] for dp in sorted_nodes]
+ sorted_descriptions = [dp["description"] for dp in sorted_nodes]
# Combine already_description with sorted new sorted descriptions
description_list = already_description + sorted_descriptions
if not description_list:
- logger.error(f'Entity {entity_name} has no description')
- raise ValueError(f'Entity {entity_name} has no description')
+ logger.error(f"Entity {entity_name} has no description")
+ raise ValueError(f"Entity {entity_name} has no description")
# Check for cancellation before LLM summary
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
- if pipeline_status.get('cancellation_requested', False):
- raise PipelineCancelledException('User cancelled during entity summary')
+ if pipeline_status.get("cancellation_requested", False):
+ raise PipelineCancelledException("User cancelled during entity summary")
# 8. Get summary description and LLM usage status
description, llm_was_used = await _handle_entity_relation_summary(
- 'Entity',
+ "Entity",
entity_name,
description_list,
GRAPH_FIELD_SEP,
@@ -2218,12 +1744,14 @@ async def _merge_nodes_then_upsert(
seen_paths = set()
has_placeholder = False # Indicating file_path has been truncated before
- max_file_paths = global_config.get('max_file_paths', DEFAULT_MAX_FILE_PATHS)
- file_path_placeholder = global_config.get('file_path_more_placeholder', DEFAULT_FILE_PATH_MORE_PLACEHOLDER)
+ max_file_paths = global_config.get("max_file_paths", DEFAULT_MAX_FILE_PATHS)
+ file_path_placeholder = global_config.get(
+ "file_path_more_placeholder", DEFAULT_FILE_PATH_MORE_PLACEHOLDER
+ )
# Collect from already_file_paths, excluding placeholder
for fp in already_file_paths:
- if fp and fp.startswith(f'...{file_path_placeholder}'): # Skip placeholders
+ if fp and fp.startswith(f"...{file_path_placeholder}"): # Skip placeholders
has_placeholder = True
continue
if fp and fp not in seen_paths:
@@ -2232,28 +1760,36 @@ async def _merge_nodes_then_upsert(
# Collect from new data
for dp in nodes_data:
- file_path_item = dp.get('file_path')
+ file_path_item = dp.get("file_path")
if file_path_item and file_path_item not in seen_paths:
file_paths_list.append(file_path_item)
seen_paths.add(file_path_item)
# Apply count limit
if len(file_paths_list) > max_file_paths:
- limit_method = global_config.get('source_ids_limit_method', SOURCE_IDS_LIMIT_METHOD_KEEP)
- file_path_placeholder = global_config.get('file_path_more_placeholder', DEFAULT_FILE_PATH_MORE_PLACEHOLDER)
+ limit_method = global_config.get(
+ "source_ids_limit_method", SOURCE_IDS_LIMIT_METHOD_KEEP
+ )
+ file_path_placeholder = global_config.get(
+ "file_path_more_placeholder", DEFAULT_FILE_PATH_MORE_PLACEHOLDER
+ )
# Add + sign to indicate actual file count is higher
- original_count_str = f'{len(file_paths_list)}+' if has_placeholder else str(len(file_paths_list))
+ original_count_str = (
+ f"{len(file_paths_list)}+" if has_placeholder else str(len(file_paths_list))
+ )
if limit_method == SOURCE_IDS_LIMIT_METHOD_FIFO:
# FIFO: keep tail (newest), discard head
file_paths_list = file_paths_list[-max_file_paths:]
- file_paths_list.append(f'...{file_path_placeholder}...(FIFO)')
+ file_paths_list.append(f"...{file_path_placeholder}...(FIFO)")
else:
# KEEP: keep head (earliest), discard tail
file_paths_list = file_paths_list[:max_file_paths]
- file_paths_list.append(f'...{file_path_placeholder}...(KEEP Old)')
+ file_paths_list.append(f"...{file_path_placeholder}...(KEEP Old)")
- logger.info(f'Limited `{entity_name}`: file_path {original_count_str} -> {max_file_paths} ({limit_method})')
+ logger.info(
+ f"Limited `{entity_name}`: file_path {original_count_str} -> {max_file_paths} ({limit_method})"
+ )
# Finalize file_path
file_path = GRAPH_FIELD_SEP.join(file_paths_list)
@@ -2261,73 +1797,75 @@ async def _merge_nodes_then_upsert(
num_fragment = len(description_list)
already_fragment = len(already_description)
if llm_was_used:
- status_message = f'LLMmrg: `{entity_name}` | {already_fragment}+{num_fragment - already_fragment}'
+ status_message = f"LLMmrg: `{entity_name}` | {already_fragment}+{num_fragment - already_fragment}"
else:
- status_message = f'Merged: `{entity_name}` | {already_fragment}+{num_fragment - already_fragment}'
+ status_message = f"Merged: `{entity_name}` | {already_fragment}+{num_fragment - already_fragment}"
- truncation_info = truncation_info_log = ''
+ truncation_info = truncation_info_log = ""
if len(source_ids) < len(full_source_ids):
# Add truncation info from apply_source_ids_limit if truncation occurred
- truncation_info_log = f'{limit_method} {len(source_ids)}/{len(full_source_ids)}'
- truncation_info = truncation_info_log if limit_method == SOURCE_IDS_LIMIT_METHOD_FIFO else 'KEEP Old'
+ truncation_info_log = f"{limit_method} {len(source_ids)}/{len(full_source_ids)}"
+ if limit_method == SOURCE_IDS_LIMIT_METHOD_FIFO:
+ truncation_info = truncation_info_log
+ else:
+ truncation_info = "KEEP Old"
deduplicated_num = already_fragment + len(nodes_data) - num_fragment
- dd_message = ''
+ dd_message = ""
if deduplicated_num > 0:
- # Duplicated description detected across multiple chunks for the same entity
- dd_message = f'dd {deduplicated_num}'
+ # Duplicated description detected across multiple chunks for the same entity
+ dd_message = f"dd {deduplicated_num}"
if dd_message or truncation_info_log:
- status_message += f' ({", ".join(filter(None, [truncation_info_log, dd_message]))})'
+ status_message += (
+ f" ({', '.join(filter(None, [truncation_info_log, dd_message]))})"
+ )
# Add message to pipeline status when merge happens
if already_fragment > 0 or llm_was_used:
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
- pipeline_status['latest_message'] = status_message
- pipeline_status['history_messages'].append(status_message)
+ pipeline_status["latest_message"] = status_message
+ pipeline_status["history_messages"].append(status_message)
else:
logger.debug(status_message)
# 11. Update both graph and vector db
- node_data = {
- 'entity_id': entity_name,
- 'entity_type': entity_type,
- 'description': description,
- 'source_id': source_id,
- 'file_path': file_path,
- 'created_at': int(time.time()),
- 'truncate': truncation_info,
- }
+ node_data = dict(
+ entity_id=entity_name,
+ entity_type=entity_type,
+ description=description,
+ source_id=source_id,
+ file_path=file_path,
+ created_at=int(time.time()),
+ truncate=truncation_info,
+ )
await knowledge_graph_inst.upsert_node(
entity_name,
node_data=node_data,
)
- node_data['entity_name'] = entity_name
+ node_data["entity_name"] = entity_name
if entity_vdb is not None:
- entity_vdb_id = compute_mdhash_id(str(entity_name), prefix='ent-')
- entity_content = f'{entity_name}\n{description}'
+ entity_vdb_id = compute_mdhash_id(str(entity_name), prefix="ent-")
+ entity_content = f"{entity_name}\n{description}"
data_for_vdb = {
entity_vdb_id: {
- 'entity_name': entity_name,
- 'entity_type': entity_type,
- 'content': entity_content,
- 'source_id': source_id,
- 'file_path': file_path,
+ "entity_name": entity_name,
+ "entity_type": entity_type,
+ "content": entity_content,
+ "source_id": source_id,
+ "file_path": file_path,
}
}
await safe_vdb_operation_with_exception(
operation=lambda payload=data_for_vdb: entity_vdb.upsert(payload),
- operation_name='entity_upsert',
+ operation_name="entity_upsert",
entity_name=entity_name,
max_retries=3,
retry_delay=0.1,
)
-
- # Return original name if resolution changed it, None otherwise
- resolved_from = original_entity_name if entity_name != original_entity_name else None
- return node_data, resolved_from
+ return node_data
async def _merge_edges_then_upsert(
@@ -2338,19 +1876,13 @@ async def _merge_edges_then_upsert(
relationships_vdb: BaseVectorStorage | None,
entity_vdb: BaseVectorStorage | None,
global_config: dict,
- pipeline_status: dict | None = None,
+ pipeline_status: dict = None,
pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None,
- added_entities: list | None = None, # New parameter to track entities added during edge processing
+ added_entities: list = None, # New parameter to track entities added during edge processing
relation_chunks_storage: BaseKVStorage | None = None,
entity_chunks_storage: BaseKVStorage | None = None,
- entity_resolution_map: dict[str, str] | None = None, # Map original→resolved names
):
- # Apply entity resolution mapping to edge endpoints
- if entity_resolution_map:
- src_id = entity_resolution_map.get(src_id, src_id)
- tgt_id = entity_resolution_map.get(tgt_id, tgt_id)
-
if src_id == tgt_id:
return None
@@ -2359,8 +1891,6 @@ async def _merge_edges_then_upsert(
already_source_ids = []
already_description = []
already_keywords = []
- already_relationships = []
- already_types = []
already_file_paths = []
# 1. Get existing edge data from graph storage
@@ -2368,50 +1898,50 @@ async def _merge_edges_then_upsert(
already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
# Handle the case where get_edge returns None or missing fields
if already_edge:
- # Get weight with default 1.0 if missing, convert from string if needed
- weight_val = already_edge.get('weight', 1.0)
- if isinstance(weight_val, str):
- try:
- weight_val = float(weight_val)
- except (ValueError, TypeError):
- weight_val = 1.0
- already_weights.append(weight_val)
+ # Get weight with default 1.0 if missing
+ already_weights.append(already_edge.get("weight", 1.0))
# Get source_id with empty string default if missing or None
- if already_edge.get('source_id') is not None:
- already_source_ids.extend(already_edge['source_id'].split(GRAPH_FIELD_SEP))
-
- # Get file_path with empty string default if missing or None
- if already_edge.get('file_path') is not None:
- already_file_paths.extend(already_edge['file_path'].split(GRAPH_FIELD_SEP))
-
- # Get description with empty string default if missing or None
- if already_edge.get('description') is not None:
- already_description.extend(already_edge['description'].split(GRAPH_FIELD_SEP))
-
- # Get keywords with empty string default if missing or None
- if already_edge.get('keywords') is not None:
- already_keywords.extend(split_string_by_multi_markers(already_edge['keywords'], [GRAPH_FIELD_SEP]))
-
- if already_edge.get('relationship') is not None:
- already_relationships.extend(
- split_string_by_multi_markers(already_edge['relationship'], [GRAPH_FIELD_SEP, ','])
+ if already_edge.get("source_id") is not None:
+ already_source_ids.extend(
+ already_edge["source_id"].split(GRAPH_FIELD_SEP)
)
- if already_edge.get('type') is not None:
- already_types.extend(split_string_by_multi_markers(already_edge['type'], [GRAPH_FIELD_SEP, ',']))
+ # Get file_path with empty string default if missing or None
+ if already_edge.get("file_path") is not None:
+ already_file_paths.extend(
+ already_edge["file_path"].split(GRAPH_FIELD_SEP)
+ )
- new_source_ids = [dp['source_id'] for dp in edges_data if dp.get('source_id')]
+ # Get description with empty string default if missing or None
+ if already_edge.get("description") is not None:
+ already_description.extend(
+ already_edge["description"].split(GRAPH_FIELD_SEP)
+ )
+
+ # Get keywords with empty string default if missing or None
+ if already_edge.get("keywords") is not None:
+ already_keywords.extend(
+ split_string_by_multi_markers(
+ already_edge["keywords"], [GRAPH_FIELD_SEP]
+ )
+ )
+
+ new_source_ids = [dp["source_id"] for dp in edges_data if dp.get("source_id")]
storage_key = make_relation_chunk_key(src_id, tgt_id)
existing_full_source_ids = []
if relation_chunks_storage is not None:
stored_chunks = await relation_chunks_storage.get_by_id(storage_key)
if stored_chunks and isinstance(stored_chunks, dict):
- existing_full_source_ids = [chunk_id for chunk_id in stored_chunks.get('chunk_ids', []) if chunk_id]
+ existing_full_source_ids = [
+ chunk_id for chunk_id in stored_chunks.get("chunk_ids", []) if chunk_id
+ ]
if not existing_full_source_ids:
- existing_full_source_ids = [chunk_id for chunk_id in already_source_ids if chunk_id]
+ existing_full_source_ids = [
+ chunk_id for chunk_id in already_source_ids if chunk_id
+ ]
# 2. Merge new source ids with existing ones
full_source_ids = merge_source_ids(existing_full_source_ids, new_source_ids)
@@ -2420,32 +1950,37 @@ async def _merge_edges_then_upsert(
await relation_chunks_storage.upsert(
{
storage_key: {
- 'chunk_ids': full_source_ids,
- 'count': len(full_source_ids),
+ "chunk_ids": full_source_ids,
+ "count": len(full_source_ids),
}
}
)
# 3. Finalize source_id by applying source ids limit
- limit_method = global_config.get('source_ids_limit_method', DEFAULT_SOURCE_IDS_LIMIT_METHOD)
- max_source_limit = int(
- global_config.get('max_source_ids_per_relation', DEFAULT_MAX_SOURCE_IDS_PER_RELATION)
- or DEFAULT_MAX_SOURCE_IDS_PER_RELATION
- )
+ limit_method = global_config.get("source_ids_limit_method")
+ max_source_limit = global_config.get("max_source_ids_per_relation")
source_ids = apply_source_ids_limit(
full_source_ids,
max_source_limit,
limit_method,
- identifier=f'`{src_id}`~`{tgt_id}`',
+ identifier=f"`{src_id}`~`{tgt_id}`",
)
+ limit_method = (
+ global_config.get("source_ids_limit_method") or SOURCE_IDS_LIMIT_METHOD_KEEP
+ )
+
# 4. Only keep edges with source_id in the final source_ids list if in KEEP mode
if limit_method == SOURCE_IDS_LIMIT_METHOD_KEEP:
allowed_source_ids = set(source_ids)
filtered_edges = []
for dp in edges_data:
- source_id = dp.get('source_id')
+ source_id = dp.get("source_id")
# Skip relationship fragments sourced from chunks dropped by keep oldest cap
- if source_id and source_id not in allowed_source_ids and source_id not in existing_full_source_ids:
+ if (
+ source_id
+ and source_id not in allowed_source_ids
+ and source_id not in existing_full_source_ids
+ ):
continue
filtered_edges.append(dp)
edges_data = filtered_edges
@@ -2459,53 +1994,44 @@ async def _merge_edges_then_upsert(
and not edges_data
):
if already_edge:
- logger.info(f'Skipped `{src_id}`~`{tgt_id}`: KEEP old chunks {already_source_ids}/{len(full_source_ids)}')
+ logger.info(
+ f"Skipped `{src_id}`~`{tgt_id}`: KEEP old chunks {already_source_ids}/{len(full_source_ids)}"
+ )
existing_edge_data = dict(already_edge)
return existing_edge_data
else:
- logger.error(f'Internal Error: already_node missing for `{src_id}`~`{tgt_id}`')
- raise ValueError(f'Internal Error: already_node missing for `{src_id}`~`{tgt_id}`')
+ logger.error(
+ f"Internal Error: already_node missing for `{src_id}`~`{tgt_id}`"
+ )
+ raise ValueError(
+ f"Internal Error: already_node missing for `{src_id}`~`{tgt_id}`"
+ )
# 6.1 Finalize source_id
source_id = GRAPH_FIELD_SEP.join(source_ids)
# 6.2 Finalize weight by summing new edges and existing weights
- weight = sum([dp['weight'] for dp in edges_data] + already_weights)
+ weight = sum([dp["weight"] for dp in edges_data] + already_weights)
# 6.2 Finalize keywords by merging existing and new keywords
all_keywords = set()
# Process already_keywords (which are comma-separated)
for keyword_str in already_keywords:
if keyword_str: # Skip empty strings
- all_keywords.update(k.strip() for k in keyword_str.split(',') if k.strip())
+ all_keywords.update(k.strip() for k in keyword_str.split(",") if k.strip())
# Process new keywords from edges_data
for edge in edges_data:
- if edge.get('keywords'):
- all_keywords.update(k.strip() for k in edge['keywords'].split(',') if k.strip())
+ if edge.get("keywords"):
+ all_keywords.update(
+ k.strip() for k in edge["keywords"].split(",") if k.strip()
+ )
# Join all unique keywords with commas
- keywords = ','.join(sorted(all_keywords))
-
- # 6.3 Finalize relationship/type labels from explicit field or fallback to keywords
- rel_labels = set()
- for edge in edges_data:
- if edge.get('relationship'):
- rel_labels.update(k.strip() for k in edge['relationship'].split(',') if k.strip())
- rel_labels.update(k.strip() for k in already_relationships if k.strip())
- relationship_label = ','.join(sorted(rel_labels))
- if not relationship_label and keywords:
- relationship_label = keywords.split(',')[0]
-
- type_labels = set()
- for edge in edges_data:
- if edge.get('type'):
- type_labels.update(k.strip() for k in edge['type'].split(',') if k.strip())
- type_labels.update(k.strip() for k in already_types if k.strip())
- type_label = ','.join(sorted(type_labels)) or relationship_label
+ keywords = ",".join(sorted(all_keywords))
# 7. Deduplicate by description, keeping first occurrence in the same document
unique_edges = {}
for dp in edges_data:
- description_value = dp.get('description')
+ description_value = dp.get("description")
if not description_value:
continue
if description_value not in unique_edges:
@@ -2514,26 +2040,28 @@ async def _merge_edges_then_upsert(
# Sort description by timestamp, then by description length (largest to smallest) when timestamps are the same
sorted_edges = sorted(
unique_edges.values(),
- key=lambda x: (x.get('timestamp', 0), -len(x.get('description', ''))),
+ key=lambda x: (x.get("timestamp", 0), -len(x.get("description", ""))),
)
- sorted_descriptions = [dp['description'] for dp in sorted_edges]
+ sorted_descriptions = [dp["description"] for dp in sorted_edges]
# Combine already_description with sorted new descriptions
description_list = already_description + sorted_descriptions
if not description_list:
- logger.error(f'Relation {src_id}~{tgt_id} has no description')
- raise ValueError(f'Relation {src_id}~{tgt_id} has no description')
+ logger.error(f"Relation {src_id}~{tgt_id} has no description")
+ raise ValueError(f"Relation {src_id}~{tgt_id} has no description")
# Check for cancellation before LLM summary
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
- if pipeline_status.get('cancellation_requested', False):
- raise PipelineCancelledException('User cancelled during relation summary')
+ if pipeline_status.get("cancellation_requested", False):
+ raise PipelineCancelledException(
+ "User cancelled during relation summary"
+ )
# 8. Get summary description and LLM usage status
description, llm_was_used = await _handle_entity_relation_summary(
- 'Relation',
- f'({src_id}, {tgt_id})',
+ "Relation",
+ f"({src_id}, {tgt_id})",
description_list,
GRAPH_FIELD_SEP,
global_config,
@@ -2545,13 +2073,15 @@ async def _merge_edges_then_upsert(
seen_paths = set()
has_placeholder = False # Track if already_file_paths contains placeholder
- max_file_paths = global_config.get('max_file_paths', DEFAULT_MAX_FILE_PATHS)
- file_path_placeholder = global_config.get('file_path_more_placeholder', DEFAULT_FILE_PATH_MORE_PLACEHOLDER)
+ max_file_paths = global_config.get("max_file_paths", DEFAULT_MAX_FILE_PATHS)
+ file_path_placeholder = global_config.get(
+ "file_path_more_placeholder", DEFAULT_FILE_PATH_MORE_PLACEHOLDER
+ )
# Collect from already_file_paths, excluding placeholder
for fp in already_file_paths:
# Check if this is a placeholder record
- if fp and fp.startswith(f'...{file_path_placeholder}'): # Skip placeholders
+ if fp and fp.startswith(f"...{file_path_placeholder}"): # Skip placeholders
has_placeholder = True
continue
if fp and fp not in seen_paths:
@@ -2560,32 +2090,38 @@ async def _merge_edges_then_upsert(
# Collect from new data
for dp in edges_data:
- file_path_item = dp.get('file_path')
+ file_path_item = dp.get("file_path")
if file_path_item and file_path_item not in seen_paths:
file_paths_list.append(file_path_item)
seen_paths.add(file_path_item)
# Apply count limit
- max_file_paths = int(global_config.get('max_file_paths', DEFAULT_MAX_FILE_PATHS) or DEFAULT_MAX_FILE_PATHS)
+ max_file_paths = global_config.get("max_file_paths")
if len(file_paths_list) > max_file_paths:
- limit_method = global_config.get('source_ids_limit_method', SOURCE_IDS_LIMIT_METHOD_KEEP)
- file_path_placeholder = global_config.get('file_path_more_placeholder', DEFAULT_FILE_PATH_MORE_PLACEHOLDER)
+ limit_method = global_config.get(
+ "source_ids_limit_method", SOURCE_IDS_LIMIT_METHOD_KEEP
+ )
+ file_path_placeholder = global_config.get(
+ "file_path_more_placeholder", DEFAULT_FILE_PATH_MORE_PLACEHOLDER
+ )
# Add + sign to indicate actual file count is higher
- original_count_str = f'{len(file_paths_list)}+' if has_placeholder else str(len(file_paths_list))
+ original_count_str = (
+ f"{len(file_paths_list)}+" if has_placeholder else str(len(file_paths_list))
+ )
if limit_method == SOURCE_IDS_LIMIT_METHOD_FIFO:
# FIFO: keep tail (newest), discard head
file_paths_list = file_paths_list[-max_file_paths:]
- file_paths_list.append(f'...{file_path_placeholder}...(FIFO)')
+ file_paths_list.append(f"...{file_path_placeholder}...(FIFO)")
else:
# KEEP: keep head (earliest), discard tail
file_paths_list = file_paths_list[:max_file_paths]
- file_paths_list.append(f'...{file_path_placeholder}...(KEEP Old)')
+ file_paths_list.append(f"...{file_path_placeholder}...(KEEP Old)")
logger.info(
- f'Limited `{src_id}`~`{tgt_id}`: file_path {original_count_str} -> {max_file_paths} ({limit_method})'
+ f"Limited `{src_id}`~`{tgt_id}`: file_path {original_count_str} -> {max_file_paths} ({limit_method})"
)
# Finalize file_path
file_path = GRAPH_FIELD_SEP.join(file_paths_list)
@@ -2594,32 +2130,37 @@ async def _merge_edges_then_upsert(
num_fragment = len(description_list)
already_fragment = len(already_description)
if llm_was_used:
- status_message = f'LLMmrg: `{src_id}`~`{tgt_id}` | {already_fragment}+{num_fragment - already_fragment}'
+ status_message = f"LLMmrg: `{src_id}`~`{tgt_id}` | {already_fragment}+{num_fragment - already_fragment}"
else:
- status_message = f'Merged: `{src_id}`~`{tgt_id}` | {already_fragment}+{num_fragment - already_fragment}'
+ status_message = f"Merged: `{src_id}`~`{tgt_id}` | {already_fragment}+{num_fragment - already_fragment}"
- truncation_info = truncation_info_log = ''
+ truncation_info = truncation_info_log = ""
if len(source_ids) < len(full_source_ids):
# Add truncation info from apply_source_ids_limit if truncation occurred
- truncation_info_log = f'{limit_method} {len(source_ids)}/{len(full_source_ids)}'
- truncation_info = truncation_info_log if limit_method == SOURCE_IDS_LIMIT_METHOD_FIFO else 'KEEP Old'
+ truncation_info_log = f"{limit_method} {len(source_ids)}/{len(full_source_ids)}"
+ if limit_method == SOURCE_IDS_LIMIT_METHOD_FIFO:
+ truncation_info = truncation_info_log
+ else:
+ truncation_info = "KEEP Old"
deduplicated_num = already_fragment + len(edges_data) - num_fragment
- dd_message = ''
+ dd_message = ""
if deduplicated_num > 0:
- # Duplicated description detected across multiple chunks for the same entity
- dd_message = f'dd {deduplicated_num}'
+ # Duplicated description detected across multiple chunks for the same entity
+ dd_message = f"dd {deduplicated_num}"
if dd_message or truncation_info_log:
- status_message += f' ({", ".join(filter(None, [truncation_info_log, dd_message]))})'
+ status_message += (
+ f" ({', '.join(filter(None, [truncation_info_log, dd_message]))})"
+ )
# Add message to pipeline status when merge happens
if already_fragment > 0 or llm_was_used:
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
- pipeline_status['latest_message'] = status_message
- pipeline_status['history_messages'].append(status_message)
+ pipeline_status["latest_message"] = status_message
+ pipeline_status["history_messages"].append(status_message)
else:
logger.debug(status_message)
@@ -2632,13 +2173,13 @@ async def _merge_edges_then_upsert(
# Node doesn't exist - create new node
node_created_at = int(time.time())
node_data = {
- 'entity_id': need_insert_id,
- 'source_id': source_id,
- 'description': description,
- 'entity_type': 'UNKNOWN',
- 'file_path': file_path,
- 'created_at': node_created_at,
- 'truncate': '',
+ "entity_id": need_insert_id,
+ "source_id": source_id,
+ "description": description,
+ "entity_type": "UNKNOWN",
+ "file_path": file_path,
+ "created_at": node_created_at,
+ "truncate": "",
}
await knowledge_graph_inst.upsert_node(need_insert_id, node_data=node_data)
@@ -2649,27 +2190,27 @@ async def _merge_edges_then_upsert(
await entity_chunks_storage.upsert(
{
need_insert_id: {
- 'chunk_ids': chunk_ids,
- 'count': len(chunk_ids),
+ "chunk_ids": chunk_ids,
+ "count": len(chunk_ids),
}
}
)
if entity_vdb is not None:
- entity_vdb_id = compute_mdhash_id(need_insert_id, prefix='ent-')
- entity_content = f'{need_insert_id}\n{description}'
+ entity_vdb_id = compute_mdhash_id(need_insert_id, prefix="ent-")
+ entity_content = f"{need_insert_id}\n{description}"
vdb_data = {
entity_vdb_id: {
- 'content': entity_content,
- 'entity_name': need_insert_id,
- 'source_id': source_id,
- 'entity_type': 'UNKNOWN',
- 'file_path': file_path,
+ "content": entity_content,
+ "entity_name": need_insert_id,
+ "source_id": source_id,
+ "entity_type": "UNKNOWN",
+ "file_path": file_path,
}
}
await safe_vdb_operation_with_exception(
operation=lambda payload=vdb_data: entity_vdb.upsert(payload),
- operation_name='added_entity_upsert',
+ operation_name="added_entity_upsert",
entity_name=need_insert_id,
max_retries=3,
retry_delay=0.1,
@@ -2678,12 +2219,12 @@ async def _merge_edges_then_upsert(
# Track entities added during edge processing
if added_entities is not None:
entity_data = {
- 'entity_name': need_insert_id,
- 'entity_type': 'UNKNOWN',
- 'description': description,
- 'source_id': source_id,
- 'file_path': file_path,
- 'created_at': node_created_at,
+ "entity_name": need_insert_id,
+ "entity_type": "UNKNOWN",
+ "description": description,
+ "source_id": source_id,
+ "file_path": file_path,
+ "created_at": node_created_at,
}
added_entities.append(entity_data)
else:
@@ -2695,68 +2236,87 @@ async def _merge_edges_then_upsert(
if entity_chunks_storage is not None:
stored_chunks = await entity_chunks_storage.get_by_id(need_insert_id)
if stored_chunks and isinstance(stored_chunks, dict):
- existing_full_source_ids = [chunk_id for chunk_id in stored_chunks.get('chunk_ids', []) if chunk_id]
+ existing_full_source_ids = [
+ chunk_id
+ for chunk_id in stored_chunks.get("chunk_ids", [])
+ if chunk_id
+ ]
# If not in entity_chunks_storage, get from graph database
- if not existing_full_source_ids and existing_node.get('source_id'):
- existing_full_source_ids = existing_node['source_id'].split(GRAPH_FIELD_SEP)
+ if not existing_full_source_ids:
+ if existing_node.get("source_id"):
+ existing_full_source_ids = existing_node["source_id"].split(
+ GRAPH_FIELD_SEP
+ )
# 2. Merge with new source_ids from this relationship
- new_source_ids_from_relation = [chunk_id for chunk_id in source_ids if chunk_id]
- merged_full_source_ids = merge_source_ids(existing_full_source_ids, new_source_ids_from_relation)
+ new_source_ids_from_relation = [
+ chunk_id for chunk_id in source_ids if chunk_id
+ ]
+ merged_full_source_ids = merge_source_ids(
+ existing_full_source_ids, new_source_ids_from_relation
+ )
# 3. Save merged full list to entity_chunks_storage (conditional)
- if entity_chunks_storage is not None and merged_full_source_ids != existing_full_source_ids:
+ if (
+ entity_chunks_storage is not None
+ and merged_full_source_ids != existing_full_source_ids
+ ):
updated = True
await entity_chunks_storage.upsert(
{
need_insert_id: {
- 'chunk_ids': merged_full_source_ids,
- 'count': len(merged_full_source_ids),
+ "chunk_ids": merged_full_source_ids,
+ "count": len(merged_full_source_ids),
}
}
)
# 4. Apply source_ids limit for graph and vector db
- limit_method = global_config.get('source_ids_limit_method', SOURCE_IDS_LIMIT_METHOD_KEEP)
- max_source_limit = int(
- global_config.get('max_source_ids_per_entity', DEFAULT_MAX_SOURCE_IDS_PER_ENTITY)
- or DEFAULT_MAX_SOURCE_IDS_PER_ENTITY
+ limit_method = global_config.get(
+ "source_ids_limit_method", SOURCE_IDS_LIMIT_METHOD_KEEP
)
+ max_source_limit = global_config.get("max_source_ids_per_entity")
limited_source_ids = apply_source_ids_limit(
merged_full_source_ids,
max_source_limit,
limit_method,
- identifier=f'`{need_insert_id}`',
+ identifier=f"`{need_insert_id}`",
)
# 5. Update graph database and vector database with limited source_ids (conditional)
limited_source_id_str = GRAPH_FIELD_SEP.join(limited_source_ids)
- if limited_source_id_str != existing_node.get('source_id', ''):
+ if limited_source_id_str != existing_node.get("source_id", ""):
updated = True
updated_node_data = {
**existing_node,
- 'source_id': limited_source_id_str,
+ "source_id": limited_source_id_str,
}
- await knowledge_graph_inst.upsert_node(need_insert_id, node_data=updated_node_data)
+ await knowledge_graph_inst.upsert_node(
+ need_insert_id, node_data=updated_node_data
+ )
# Update vector database
if entity_vdb is not None:
- entity_vdb_id = compute_mdhash_id(need_insert_id, prefix='ent-')
- entity_content = f'{need_insert_id}\n{existing_node.get("description", "")}'
+ entity_vdb_id = compute_mdhash_id(need_insert_id, prefix="ent-")
+ entity_content = (
+ f"{need_insert_id}\n{existing_node.get('description', '')}"
+ )
vdb_data = {
entity_vdb_id: {
- 'content': entity_content,
- 'entity_name': need_insert_id,
- 'source_id': limited_source_id_str,
- 'entity_type': existing_node.get('entity_type', 'UNKNOWN'),
- 'file_path': existing_node.get('file_path', 'unknown_source'),
+ "content": entity_content,
+ "entity_name": need_insert_id,
+ "source_id": limited_source_id_str,
+ "entity_type": existing_node.get("entity_type", "UNKNOWN"),
+ "file_path": existing_node.get(
+ "file_path", "unknown_source"
+ ),
}
}
await safe_vdb_operation_with_exception(
operation=lambda payload=vdb_data: entity_vdb.upsert(payload),
- operation_name='existing_entity_update',
+ operation_name="existing_entity_update",
entity_name=need_insert_id,
max_retries=3,
retry_delay=0.1,
@@ -2764,74 +2324,70 @@ async def _merge_edges_then_upsert(
# 6. Log once at the end if any update occurred
if updated:
- status_message = f'Chunks appended from relation: `{need_insert_id}`'
+ status_message = f"Chunks appended from relation: `{need_insert_id}`"
logger.info(status_message)
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
- pipeline_status['latest_message'] = status_message
- pipeline_status['history_messages'].append(status_message)
+ pipeline_status["latest_message"] = status_message
+ pipeline_status["history_messages"].append(status_message)
edge_created_at = int(time.time())
await knowledge_graph_inst.upsert_edge(
src_id,
tgt_id,
- edge_data={
- 'weight': str(weight),
- 'description': description,
- 'keywords': keywords,
- 'relationship': relationship_label,
- 'type': type_label,
- 'source_id': source_id,
- 'file_path': file_path,
- 'created_at': str(edge_created_at),
- 'truncate': truncation_info,
- },
+ edge_data=dict(
+ weight=weight,
+ description=description,
+ keywords=keywords,
+ source_id=source_id,
+ file_path=file_path,
+ created_at=edge_created_at,
+ truncate=truncation_info,
+ ),
)
- edge_data = {
- 'src_id': src_id,
- 'tgt_id': tgt_id,
- 'description': description,
- 'keywords': keywords,
- 'relationship': relationship_label,
- 'type': type_label,
- 'source_id': source_id,
- 'file_path': file_path,
- 'created_at': edge_created_at,
- 'truncate': truncation_info,
- 'weight': weight,
- }
+ edge_data = dict(
+ src_id=src_id,
+ tgt_id=tgt_id,
+ description=description,
+ keywords=keywords,
+ source_id=source_id,
+ file_path=file_path,
+ created_at=edge_created_at,
+ truncate=truncation_info,
+ weight=weight,
+ )
# Sort src_id and tgt_id to ensure consistent ordering (smaller string first)
if src_id > tgt_id:
src_id, tgt_id = tgt_id, src_id
if relationships_vdb is not None:
- rel_vdb_id = compute_mdhash_id(src_id + tgt_id, prefix='rel-')
- rel_vdb_id_reverse = compute_mdhash_id(tgt_id + src_id, prefix='rel-')
+ rel_vdb_id = compute_mdhash_id(src_id + tgt_id, prefix="rel-")
+ rel_vdb_id_reverse = compute_mdhash_id(tgt_id + src_id, prefix="rel-")
try:
await relationships_vdb.delete([rel_vdb_id, rel_vdb_id_reverse])
except Exception as e:
- logger.debug(f'Could not delete old relationship vector records {rel_vdb_id}, {rel_vdb_id_reverse}: {e}')
- rel_content = f'{keywords}\t{src_id}\n{tgt_id}\n{description}'
+ logger.debug(
+ f"Could not delete old relationship vector records {rel_vdb_id}, {rel_vdb_id_reverse}: {e}"
+ )
+ rel_content = f"{keywords}\t{src_id}\n{tgt_id}\n{description}"
vdb_data = {
rel_vdb_id: {
- 'src_id': src_id,
- 'tgt_id': tgt_id,
- 'relationship': relationship_label,
- 'type': type_label,
- 'source_id': source_id,
- 'content': rel_content,
- 'keywords': keywords,
- 'description': description,
- 'weight': weight,
- 'file_path': file_path,
+ "src_id": src_id,
+ "tgt_id": tgt_id,
+ "source_id": source_id,
+ "content": rel_content,
+ "keywords": keywords,
+ "description": description,
+ "weight": weight,
+ "file_path": file_path,
}
}
await safe_vdb_operation_with_exception(
operation=lambda payload=vdb_data: relationships_vdb.upsert(payload),
- operation_name='relationship_upsert',
- entity_name=f'{src_id}-{tgt_id}',
+ operation_name="relationship_upsert",
+ entity_name=f"{src_id}-{tgt_id}",
max_retries=3,
retry_delay=0.2,
)
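A minimal sketch of the keying idea used in the hunk above: the two endpoints are sorted before hashing so that (A, B) and (B, A) map to the same `rel-` record, and both directional ids are deleted first to clear stale rows. The real helper is compute_mdhash_id from lightrag.utils; plain hashlib is used here only for illustration and may differ in detail.

import hashlib

def stable_relation_id(src_id: str, tgt_id: str, prefix: str = "rel-") -> str:
    # Sort the endpoints first so the id is independent of edge direction.
    a, b = sorted((src_id, tgt_id))
    return prefix + hashlib.md5((a + b).encode("utf-8")).hexdigest()

assert stable_relation_id("Alice", "Bob") == stable_relation_id("Bob", "Alice")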
@@ -2844,18 +2400,18 @@ async def merge_nodes_and_edges(
knowledge_graph_inst: BaseGraphStorage,
entity_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
- global_config: dict[str, Any],
- full_entities_storage: BaseKVStorage | None = None,
- full_relations_storage: BaseKVStorage | None = None,
- doc_id: str | None = None,
- pipeline_status: dict | None = None,
+ global_config: dict[str, str],
+ full_entities_storage: BaseKVStorage = None,
+ full_relations_storage: BaseKVStorage = None,
+ doc_id: str = None,
+ pipeline_status: dict = None,
pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None,
entity_chunks_storage: BaseKVStorage | None = None,
relation_chunks_storage: BaseKVStorage | None = None,
current_file_number: int = 0,
total_files: int = 0,
- file_path: str = 'unknown_source',
+ file_path: str = "unknown_source",
) -> None:
"""Two-phase merge: process all entities first, then all relationships
@@ -2883,18 +2439,11 @@ async def merge_nodes_and_edges(
file_path: File path for logging
"""
- if full_entities_storage is None or full_relations_storage is None:
- raise ValueError('full_entities_storage and full_relations_storage are required for merge operations')
- if pipeline_status is None:
- pipeline_status = {}
- if pipeline_status_lock is None:
- pipeline_status_lock = asyncio.Lock()
-
# Check for cancellation at the start of merge
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
- if pipeline_status.get('cancellation_requested', False):
- raise PipelineCancelledException('User cancelled during merge phase')
+ if pipeline_status.get("cancellation_requested", False):
+ raise PipelineCancelledException("User cancelled during merge phase")
# Collect all nodes and edges from all chunks
all_nodes = defaultdict(list)
@@ -2913,126 +2462,41 @@ async def merge_nodes_and_edges(
total_entities_count = len(all_nodes)
total_relations_count = len(all_edges)
- log_message = f'Merging stage {current_file_number}/{total_files}: {file_path}'
+ log_message = f"Merging stage {current_file_number}/{total_files}: {file_path}"
logger.info(log_message)
async with pipeline_status_lock:
- pipeline_status['latest_message'] = log_message
- pipeline_status['history_messages'].append(log_message)
+ pipeline_status["latest_message"] = log_message
+ pipeline_status["history_messages"].append(log_message)
# Get max async tasks limit from global_config for semaphore control
- graph_max_async = global_config.get('llm_model_max_async', 4) * 2
+ graph_max_async = global_config.get("llm_model_max_async", 4) * 2
semaphore = asyncio.Semaphore(graph_max_async)
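A minimal sketch of the concurrency cap set up here, assuming nothing beyond asyncio: a semaphore sized from llm_model_max_async bounds how many merge workers run at once. The real code additionally takes keyed per-entity locks and uses asyncio.wait with FIRST_EXCEPTION; this sketch shows only the semaphore pattern.

import asyncio

async def bounded_gather(items, worker, max_async: int = 8):
    # A semaphore caps the number of concurrently running workers.
    sem = asyncio.Semaphore(max_async)

    async def _run(item):
        async with sem:
            return await worker(item)

    return await asyncio.gather(*(_run(i) for i in items))

async def _double(x):
    await asyncio.sleep(0)
    return x * 2

# asyncio.run(bounded_gather(range(10), _double, max_async=4))  -> [0, 2, 4, ...]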
- # ===== Pre-Resolution Phase: Build entity resolution map =====
- # This prevents race conditions when parallel workers process similar entities
- # IMPORTANT: Include BOTH entity names AND relation endpoints to catch all duplicates
- pre_resolution_map: dict[str, str] = {}
- entity_resolution_config_raw = global_config.get('entity_resolution_config')
- if entity_resolution_config_raw:
- # Handle both dict (from asdict() serialization) and EntityResolutionConfig instances
- config = None
- if isinstance(entity_resolution_config_raw, EntityResolutionConfig):
- config = entity_resolution_config_raw
- elif isinstance(entity_resolution_config_raw, dict):
- try:
- config = EntityResolutionConfig(**entity_resolution_config_raw)
- except TypeError as e:
- logger.warning(
- f'Invalid entity_resolution_config: {e}. '
- f'Config: {entity_resolution_config_raw}. Skipping resolution.'
- )
- if config and config.enabled:
- llm_fn = global_config.get('llm_model_func')
- # Build entity_types map for type-aware fuzzy matching.
- # Use first non-empty type for entities with multiple occurrences.
- entity_types: dict[str, str] = {}
- for entity_name, entities in all_nodes.items():
- for entity_data in entities:
- etype = entity_data.get('entity_type', '')
- if etype:
- entity_types[entity_name] = etype
- break
-
- # Collect ALL entity names: from entities AND from relation endpoints
- # This ensures relation endpoints like "EU Medicines Agency" get resolved
- # against existing entities like "European Medicines Agency"
- all_entity_names = set(all_nodes.keys())
- for src_id, tgt_id in all_edges:
- all_entity_names.add(src_id)
- all_entity_names.add(tgt_id)
-
- pre_resolution_map, confidence_map = await _build_pre_resolution_map(
- list(all_entity_names),
- entity_types,
- entity_vdb,
- llm_fn,
- config,
- )
-
- # Cache pre-resolution aliases for future lookups (PostgreSQL-only)
- # This ensures aliases discovered during batch processing are available
- # for subsequent document ingestion without re-running resolution
- db = getattr(knowledge_graph_inst, 'db', None)
- if db is not None and pre_resolution_map:
- workspace = global_config.get('workspace', '')
- for alias, canonical in pre_resolution_map.items():
- # Don't cache self-references (entity → itself)
- if alias.lower().strip() != canonical.lower().strip():
- try:
- await store_alias(
- alias=alias,
- canonical=canonical,
- method='pre_resolution',
- confidence=confidence_map.get(alias, 1.0),
- db=db,
- workspace=workspace,
- )
- except Exception as e:
- logger.debug(f"Failed to cache pre-resolution alias '{alias}' → '{canonical}': {e}")
-
# ===== Phase 1: Process all entities concurrently =====
- log_message = f'Phase 1: Processing {total_entities_count} entities from {doc_id} (async: {graph_max_async})'
+ log_message = f"Phase 1: Processing {total_entities_count} entities from {doc_id} (async: {graph_max_async})"
logger.info(log_message)
async with pipeline_status_lock:
- pipeline_status['latest_message'] = log_message
- pipeline_status['history_messages'].append(log_message)
-
- # ===== Batch Prefetch: Load existing entity data in single query =====
- # Build list of entity names to prefetch (apply pre-resolution where applicable)
- prefetch_entity_names = []
- for entity_name in all_nodes:
- resolved_name = pre_resolution_map.get(entity_name, entity_name)
- prefetch_entity_names.append(resolved_name)
-
- # Batch fetch existing nodes to avoid N+1 query pattern during parallel processing
- prefetched_nodes: dict[str, dict] = {}
- if prefetch_entity_names:
- try:
- prefetched_nodes = await knowledge_graph_inst.get_nodes_batch(prefetch_entity_names)
- logger.debug(f'Prefetched {len(prefetched_nodes)}/{len(prefetch_entity_names)} existing entities for merge')
- except Exception as e:
- logger.warning(f'Batch entity prefetch failed: {e}. Falling back to individual fetches.')
- prefetched_nodes = {}
-
- # Resolution map to track original→resolved entity names (e.g., "Dupixant"→"Dupixent")
- # This will be used to remap edge endpoints in Phase 2
- entity_resolution_map: dict[str, str] = {}
- resolution_map_lock = asyncio.Lock()
+ pipeline_status["latest_message"] = log_message
+ pipeline_status["history_messages"].append(log_message)
async def _locked_process_entity_name(entity_name, entities):
async with semaphore:
# Check for cancellation before processing entity
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
- if pipeline_status.get('cancellation_requested', False):
- raise PipelineCancelledException('User cancelled during entity merge')
+ if pipeline_status.get("cancellation_requested", False):
+ raise PipelineCancelledException(
+ "User cancelled during entity merge"
+ )
- workspace = global_config.get('workspace', '')
- namespace = f'{workspace}:GraphDB' if workspace else 'GraphDB'
- async with get_storage_keyed_lock([entity_name], namespace=namespace, enable_logging=False):
+ workspace = global_config.get("workspace", "")
+ namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
+ async with get_storage_keyed_lock(
+ [entity_name], namespace=namespace, enable_logging=False
+ ):
try:
- logger.debug(f'Processing entity {entity_name}')
- entity_data, resolved_from = await _merge_nodes_then_upsert(
+ logger.debug(f"Processing entity {entity_name}")
+ entity_data = await _merge_nodes_then_upsert(
entity_name,
entities,
knowledge_graph_inst,
@@ -3042,33 +2506,32 @@ async def merge_nodes_and_edges(
pipeline_status_lock,
llm_response_cache,
entity_chunks_storage,
- pre_resolution_map,
- prefetched_nodes,
)
- # Track resolution mapping for edge remapping in Phase 2
- if resolved_from is not None:
- resolved_to = entity_data.get('entity_name', entity_name)
- async with resolution_map_lock:
- entity_resolution_map[resolved_from] = resolved_to
-
return entity_data
except Exception as e:
- error_msg = f'Error processing entity `{entity_name}`: {e}'
+ error_msg = f"Error processing entity `{entity_name}`: {e}"
logger.error(error_msg)
# Try to update pipeline status, but don't let status update failure affect main exception
try:
- if pipeline_status is not None and pipeline_status_lock is not None:
+ if (
+ pipeline_status is not None
+ and pipeline_status_lock is not None
+ ):
async with pipeline_status_lock:
- pipeline_status['latest_message'] = error_msg
- pipeline_status['history_messages'].append(error_msg)
+ pipeline_status["latest_message"] = error_msg
+ pipeline_status["history_messages"].append(error_msg)
except Exception as status_error:
- logger.error(f'Failed to update pipeline status: {status_error}')
+ logger.error(
+ f"Failed to update pipeline status: {status_error}"
+ )
# Re-raise the original exception with a prefix
- prefixed_exception = create_prefixed_exception(e, f'`{entity_name}`')
+ prefixed_exception = create_prefixed_exception(
+ e, f"`{entity_name}`"
+ )
raise prefixed_exception from e
# Create entity processing tasks
@@ -3080,7 +2543,9 @@ async def merge_nodes_and_edges(
# Execute entity tasks with error handling
processed_entities = []
if entity_tasks:
- done, pending = await asyncio.wait(entity_tasks, return_when=asyncio.FIRST_EXCEPTION)
+ done, pending = await asyncio.wait(
+ entity_tasks, return_when=asyncio.FIRST_EXCEPTION
+ )
first_exception = None
processed_entities = []
@@ -3109,22 +2574,24 @@ async def merge_nodes_and_edges(
raise first_exception
# ===== Phase 2: Process all relationships concurrently =====
- log_message = f'Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})'
+ log_message = f"Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})"
logger.info(log_message)
async with pipeline_status_lock:
- pipeline_status['latest_message'] = log_message
- pipeline_status['history_messages'].append(log_message)
+ pipeline_status["latest_message"] = log_message
+ pipeline_status["history_messages"].append(log_message)
async def _locked_process_edges(edge_key, edges):
async with semaphore:
# Check for cancellation before processing edges
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
- if pipeline_status.get('cancellation_requested', False):
- raise PipelineCancelledException('User cancelled during relation merge')
+ if pipeline_status.get("cancellation_requested", False):
+ raise PipelineCancelledException(
+ "User cancelled during relation merge"
+ )
- workspace = global_config.get('workspace', '')
- namespace = f'{workspace}:GraphDB' if workspace else 'GraphDB'
+ workspace = global_config.get("workspace", "")
+ namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
sorted_edge_key = sorted([edge_key[0], edge_key[1]])
async with get_storage_keyed_lock(
@@ -3135,7 +2602,7 @@ async def merge_nodes_and_edges(
try:
added_entities = [] # Track entities added during edge processing
- logger.debug(f'Processing relation {sorted_edge_key}')
+ logger.debug(f"Processing relation {sorted_edge_key}")
edge_data = await _merge_edges_then_upsert(
edge_key[0],
edge_key[1],
@@ -3150,7 +2617,6 @@ async def merge_nodes_and_edges(
added_entities, # Pass list to collect added entities
relation_chunks_storage,
entity_chunks_storage, # Add entity_chunks_storage parameter
- entity_resolution_map, # Apply entity resolution to edge endpoints
)
if edge_data is None:
@@ -3159,50 +2625,33 @@ async def merge_nodes_and_edges(
return edge_data, added_entities
except Exception as e:
- error_msg = f'Error processing relation `{sorted_edge_key}`: {e}'
+ error_msg = f"Error processing relation `{sorted_edge_key}`: {e}"
logger.error(error_msg)
# Try to update pipeline status, but don't let status update failure affect main exception
try:
- if pipeline_status is not None and pipeline_status_lock is not None:
+ if (
+ pipeline_status is not None
+ and pipeline_status_lock is not None
+ ):
async with pipeline_status_lock:
- pipeline_status['latest_message'] = error_msg
- pipeline_status['history_messages'].append(error_msg)
+ pipeline_status["latest_message"] = error_msg
+ pipeline_status["history_messages"].append(error_msg)
except Exception as status_error:
- logger.error(f'Failed to update pipeline status: {status_error}')
+ logger.error(
+ f"Failed to update pipeline status: {status_error}"
+ )
# Re-raise the original exception with a prefix
- prefixed_exception = create_prefixed_exception(e, f'{sorted_edge_key}')
+ prefixed_exception = create_prefixed_exception(
+ e, f"{sorted_edge_key}"
+ )
raise prefixed_exception from e
# Create relationship processing tasks
- # Apply pre_resolution_map to edge endpoints to prevent duplicates from relation extraction
- # Key fixes: sort for lock ordering, filter self-loops, deduplicate merged edges
- resolved_edges: dict[tuple[str, str], list] = {}
- for edge_key, edges in all_edges.items():
- # Remap edge endpoints using pre-resolution map
- # This catches cases like "EU Medicines Agency" → "European Medicines Agency"
- resolved_src = pre_resolution_map.get(edge_key[0], edge_key[0]) or ''
- resolved_tgt = pre_resolution_map.get(edge_key[1], edge_key[1]) or ''
-
- # Skip self-loops created by resolution (e.g., both endpoints resolve to same entity)
- if resolved_src == resolved_tgt:
- logger.debug(f'Skipping self-loop after resolution: {edge_key} → ({resolved_src}, {resolved_tgt})')
- continue
-
- # Sort for consistent lock ordering (prevents deadlocks)
- sorted_edge = sorted([resolved_src, resolved_tgt])
- resolved_edge_key: tuple[str, str] = (sorted_edge[0], sorted_edge[1])
-
- # Merge edges that resolve to same key (deduplication)
- if resolved_edge_key not in resolved_edges:
- resolved_edges[resolved_edge_key] = []
- resolved_edges[resolved_edge_key].extend(edges)
-
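A compact sketch of the endpoint-resolution step being removed above, under the assumption that alias_map plays the role of pre_resolution_map: endpoints are remapped, self-loops created by resolution are dropped, and edges whose resolved (sorted) endpoints collide are merged into one list.

def resolve_edge_keys(edges: dict, alias_map: dict) -> dict:
    resolved: dict = {}
    for (src, tgt), items in edges.items():
        src = alias_map.get(src, src)
        tgt = alias_map.get(tgt, tgt)
        if src == tgt:
            continue  # both endpoints resolved to the same entity: skip the self-loop
        key = (min(src, tgt), max(src, tgt))  # sorted key keeps lock ordering consistent
        resolved.setdefault(key, []).extend(items)
    return resolved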
- # Create tasks from deduplicated edges
edge_tasks = []
- for resolved_edge_key, merged_edges in resolved_edges.items():
- task = asyncio.create_task(_locked_process_edges(resolved_edge_key, merged_edges))
+ for edge_key, edges in all_edges.items():
+ task = asyncio.create_task(_locked_process_edges(edge_key, edges))
edge_tasks.append(task)
# Execute relationship tasks with error handling
@@ -3210,7 +2659,9 @@ async def merge_nodes_and_edges(
all_added_entities = []
if edge_tasks:
- done, pending = await asyncio.wait(edge_tasks, return_when=asyncio.FIRST_EXCEPTION)
+ done, pending = await asyncio.wait(
+ edge_tasks, return_when=asyncio.FIRST_EXCEPTION
+ )
first_exception = None
@@ -3250,37 +2701,37 @@ async def merge_nodes_and_edges(
# Add original processed entities
for entity_data in processed_entities:
- if entity_data and entity_data.get('entity_name'):
- final_entity_names.add(entity_data['entity_name'])
+ if entity_data and entity_data.get("entity_name"):
+ final_entity_names.add(entity_data["entity_name"])
# Add entities that were added during relationship processing
for added_entity in all_added_entities:
- if added_entity and added_entity.get('entity_name'):
- final_entity_names.add(added_entity['entity_name'])
+ if added_entity and added_entity.get("entity_name"):
+ final_entity_names.add(added_entity["entity_name"])
# Collect all relation pairs
final_relation_pairs = set()
for edge_data in processed_edges:
if edge_data:
- src_id = edge_data.get('src_id')
- tgt_id = edge_data.get('tgt_id')
+ src_id = edge_data.get("src_id")
+ tgt_id = edge_data.get("tgt_id")
if src_id and tgt_id:
relation_pair = tuple(sorted([src_id, tgt_id]))
final_relation_pairs.add(relation_pair)
- log_message = f'Phase 3: Updating final {len(final_entity_names)}({len(processed_entities)}+{len(all_added_entities)}) entities and {len(final_relation_pairs)} relations from {doc_id}'
+ log_message = f"Phase 3: Updating final {len(final_entity_names)}({len(processed_entities)}+{len(all_added_entities)}) entities and {len(final_relation_pairs)} relations from {doc_id}"
logger.info(log_message)
async with pipeline_status_lock:
- pipeline_status['latest_message'] = log_message
- pipeline_status['history_messages'].append(log_message)
+ pipeline_status["latest_message"] = log_message
+ pipeline_status["history_messages"].append(log_message)
# Update storage
if final_entity_names:
await full_entities_storage.upsert(
{
doc_id: {
- 'entity_names': list(final_entity_names),
- 'count': len(final_entity_names),
+ "entity_names": list(final_entity_names),
+ "count": len(final_entity_names),
}
}
)
@@ -3289,71 +2740,75 @@ async def merge_nodes_and_edges(
await full_relations_storage.upsert(
{
doc_id: {
- 'relation_pairs': [list(pair) for pair in final_relation_pairs],
- 'count': len(final_relation_pairs),
+ "relation_pairs": [
+ list(pair) for pair in final_relation_pairs
+ ],
+ "count": len(final_relation_pairs),
}
}
)
logger.debug(
- f'Updated entity-relation index for document {doc_id}: {len(final_entity_names)} entities (original: {len(processed_entities)}, added: {len(all_added_entities)}), {len(final_relation_pairs)} relations'
+ f"Updated entity-relation index for document {doc_id}: {len(final_entity_names)} entities (original: {len(processed_entities)}, added: {len(all_added_entities)}), {len(final_relation_pairs)} relations"
)
except Exception as e:
- logger.error(f'Failed to update entity-relation index for document {doc_id}: {e}')
+ logger.error(
+ f"Failed to update entity-relation index for document {doc_id}: {e}"
+ )
# Don't raise exception to avoid affecting main flow
- log_message = f'Completed merging: {len(processed_entities)} entities, {len(all_added_entities)} extra entities, {len(processed_edges)} relations'
+ log_message = f"Completed merging: {len(processed_entities)} entities, {len(all_added_entities)} extra entities, {len(processed_edges)} relations"
logger.info(log_message)
async with pipeline_status_lock:
- pipeline_status['latest_message'] = log_message
- pipeline_status['history_messages'].append(log_message)
+ pipeline_status["latest_message"] = log_message
+ pipeline_status["history_messages"].append(log_message)
async def extract_entities(
chunks: dict[str, TextChunkSchema],
- global_config: dict[str, Any],
- pipeline_status: dict | None = None,
+ global_config: dict[str, str],
+ pipeline_status: dict = None,
pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None,
text_chunks_storage: BaseKVStorage | None = None,
) -> list:
- if pipeline_status is None:
- pipeline_status = {}
- if pipeline_status_lock is None:
- pipeline_status_lock = asyncio.Lock()
# Check for cancellation at the start of entity extraction
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
- if pipeline_status.get('cancellation_requested', False):
- raise PipelineCancelledException('User cancelled during entity extraction')
+ if pipeline_status.get("cancellation_requested", False):
+ raise PipelineCancelledException(
+ "User cancelled during entity extraction"
+ )
- use_llm_func: Callable[..., Any] = global_config['llm_model_func']
- entity_extract_max_gleaning = global_config['entity_extract_max_gleaning']
+ use_llm_func: callable = global_config["llm_model_func"]
+ entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
ordered_chunks = list(chunks.items())
# add language and example number params to prompt
- language = global_config['addon_params'].get('language', DEFAULT_SUMMARY_LANGUAGE)
- entity_types = global_config['addon_params'].get('entity_types', DEFAULT_ENTITY_TYPES)
+ language = global_config["addon_params"].get("language", DEFAULT_SUMMARY_LANGUAGE)
+ entity_types = global_config["addon_params"].get(
+ "entity_types", DEFAULT_ENTITY_TYPES
+ )
- examples = '\n'.join(PROMPTS['entity_extraction_examples'])
+ examples = "\n".join(PROMPTS["entity_extraction_examples"])
- example_context_base = {
- 'tuple_delimiter': PROMPTS['DEFAULT_TUPLE_DELIMITER'],
- 'completion_delimiter': PROMPTS['DEFAULT_COMPLETION_DELIMITER'],
- 'entity_types': ', '.join(entity_types),
- 'language': language,
- }
+ example_context_base = dict(
+ tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
+ completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
+ entity_types=", ".join(entity_types),
+ language=language,
+ )
# add example's format
examples = examples.format(**example_context_base)
- context_base = {
- 'tuple_delimiter': PROMPTS['DEFAULT_TUPLE_DELIMITER'],
- 'completion_delimiter': PROMPTS['DEFAULT_COMPLETION_DELIMITER'],
- 'entity_types': ','.join(entity_types),
- 'examples': examples,
- 'language': language,
- }
+ context_base = dict(
+ tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
+ completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
+ entity_types=",".join(entity_types),
+ examples=examples,
+ language=language,
+ )
processed_chunks = 0
total_chunks = len(ordered_chunks)
@@ -3369,35 +2824,39 @@ async def extract_entities(
nonlocal processed_chunks
chunk_key = chunk_key_dp[0]
chunk_dp = chunk_key_dp[1]
- content = chunk_dp.get('content', '')
+ content = chunk_dp["content"]
# Get file path from chunk data or use default
- file_path = chunk_dp.get('file_path') or 'unknown_source'
+ file_path = chunk_dp.get("file_path", "unknown_source")
# Create cache keys collector for batch processing
cache_keys_collector = []
# Get initial extraction
- entity_extraction_system_prompt = PROMPTS['entity_extraction_system_prompt'].format(
- **{**context_base, 'input_text': content}
- )
- entity_extraction_user_prompt = PROMPTS['entity_extraction_user_prompt'].format(
- **{**context_base, 'input_text': content}
- )
- entity_continue_extraction_user_prompt = PROMPTS['entity_continue_extraction_user_prompt'].format(
- **{**context_base, 'input_text': content}
+ # Format system prompt without input_text for each chunk (enables OpenAI prompt caching across chunks)
+ entity_extraction_system_prompt = PROMPTS[
+ "entity_extraction_system_prompt"
+ ].format(**context_base)
+ # Format user prompts with input_text for each chunk
+ entity_extraction_user_prompt = PROMPTS["entity_extraction_user_prompt"].format(
+ **{**context_base, "input_text": content}
)
+ entity_continue_extraction_user_prompt = PROMPTS[
+ "entity_continue_extraction_user_prompt"
+ ].format(**{**context_base, "input_text": content})
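A sketch of the prompt-caching idea stated in the comment above, with hypothetical templates standing in for the PROMPTS entries: the system prompt is formatted once without the chunk text so its prefix stays byte-identical across chunks, and only the user message varies per chunk.

SYSTEM_TMPL = "You extract entities. Types: {entity_types}. Language: {language}."  # hypothetical
USER_TMPL = "Text to process:\n{input_text}"                                        # hypothetical

context = {"entity_types": "person,organization", "language": "English"}
system_prompt = SYSTEM_TMPL.format(**context)          # identical for every chunk
for chunk in ["first chunk ...", "second chunk ..."]:
    user_prompt = USER_TMPL.format(input_text=chunk)   # only this part changes
    # llm(user_prompt, system_prompt=system_prompt)    # stable prefix -> provider cache hits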
final_result, timestamp = await use_llm_func_with_cache(
entity_extraction_user_prompt,
use_llm_func,
system_prompt=entity_extraction_system_prompt,
llm_response_cache=llm_response_cache,
- cache_type='extract',
+ cache_type="extract",
chunk_id=chunk_key,
cache_keys_collector=cache_keys_collector,
)
- history = pack_user_ass_to_openai_messages(entity_extraction_user_prompt, final_result)
+ history = pack_user_ass_to_openai_messages(
+ entity_extraction_user_prompt, final_result
+ )
# Process initial extraction with file path
maybe_nodes, maybe_edges = await _process_extraction_result(
@@ -3405,8 +2864,8 @@ async def extract_entities(
chunk_key,
timestamp,
file_path,
- tuple_delimiter=context_base['tuple_delimiter'],
- completion_delimiter=context_base['completion_delimiter'],
+ tuple_delimiter=context_base["tuple_delimiter"],
+ completion_delimiter=context_base["completion_delimiter"],
)
# Process additional gleaning results only 1 time when entity_extract_max_gleaning is greater than zero.
@@ -3417,7 +2876,7 @@ async def extract_entities(
system_prompt=entity_extraction_system_prompt,
llm_response_cache=llm_response_cache,
history_messages=history,
- cache_type='extract',
+ cache_type="extract",
chunk_id=chunk_key,
cache_keys_collector=cache_keys_collector,
)
@@ -3428,16 +2887,18 @@ async def extract_entities(
chunk_key,
timestamp,
file_path,
- tuple_delimiter=context_base['tuple_delimiter'],
- completion_delimiter=context_base['completion_delimiter'],
+ tuple_delimiter=context_base["tuple_delimiter"],
+ completion_delimiter=context_base["completion_delimiter"],
)
# Merge results - compare description lengths to choose better version
for entity_name, glean_entities in glean_nodes.items():
if entity_name in maybe_nodes:
# Compare description lengths and keep the better one
- original_desc_len = len(maybe_nodes[entity_name][0].get('description', '') or '')
- glean_desc_len = len(glean_entities[0].get('description', '') or '')
+ original_desc_len = len(
+ maybe_nodes[entity_name][0].get("description", "") or ""
+ )
+ glean_desc_len = len(glean_entities[0].get("description", "") or "")
if glean_desc_len > original_desc_len:
maybe_nodes[entity_name] = list(glean_entities)
@@ -3446,18 +2907,20 @@ async def extract_entities(
# New entity from gleaning stage
maybe_nodes[entity_name] = list(glean_entities)
- for edge_key, glean_edge_list in glean_edges.items():
+        for edge_key, glean_edge_list in glean_edges.items():
             if edge_key in maybe_edges:
                 # Compare description lengths and keep the better one
-                original_desc_len = len(maybe_edges[edge_key][0].get('description', '') or '')
-                glean_desc_len = len(glean_edge_list[0].get('description', '') or '')
+                original_desc_len = len(
+                    maybe_edges[edge_key][0].get("description", "") or ""
+                )
+                glean_desc_len = len(glean_edge_list[0].get("description", "") or "")
                 if glean_desc_len > original_desc_len:
-                    maybe_edges[edge_key] = list(glean_edge_list)
+                    maybe_edges[edge_key] = list(glean_edge_list)
                 # Otherwise keep original version
             else:
                 # New edge from gleaning stage
-                maybe_edges[edge_key] = list(glean_edge_list)
+                maybe_edges[edge_key] = list(glean_edge_list)
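The merge rule applied above to both gleaned entities and gleaned edges can be stated as a small helper; a sketch only, not the library API.

def keep_longer_description(primary: dict, gleaned: dict) -> dict:
    merged = {k: list(v) for k, v in primary.items()}
    for key, items in gleaned.items():
        if key not in merged:
            merged[key] = list(items)          # brand-new record from the gleaning pass
            continue
        old_len = len(merged[key][0].get("description", "") or "")
        new_len = len(items[0].get("description", "") or "")
        if new_len > old_len:
            merged[key] = list(items)          # gleaning produced the richer description
    return merged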
# Batch update chunk's llm_cache_list with all collected cache keys
if cache_keys_collector and text_chunks_storage:
@@ -3465,24 +2928,24 @@ async def extract_entities(
chunk_key,
text_chunks_storage,
cache_keys_collector,
- 'entity_extraction',
+ "entity_extraction",
)
processed_chunks += 1
entities_count = len(maybe_nodes)
relations_count = len(maybe_edges)
- log_message = f'Chunk {processed_chunks} of {total_chunks} extracted {entities_count} Ent + {relations_count} Rel {chunk_key}'
+ log_message = f"Chunk {processed_chunks} of {total_chunks} extracted {entities_count} Ent + {relations_count} Rel {chunk_key}"
logger.info(log_message)
if pipeline_status is not None:
async with pipeline_status_lock:
- pipeline_status['latest_message'] = log_message
- pipeline_status['history_messages'].append(log_message)
+ pipeline_status["latest_message"] = log_message
+ pipeline_status["history_messages"].append(log_message)
# Return the extracted nodes and edges for centralized processing
return maybe_nodes, maybe_edges
# Get max async tasks limit from global_config
- chunk_max_async = global_config.get('llm_model_max_async', 4)
+ chunk_max_async = global_config.get("llm_model_max_async", 4)
semaphore = asyncio.Semaphore(chunk_max_async)
async def _process_with_semaphore(chunk):
@@ -3490,8 +2953,10 @@ async def extract_entities(
# Check for cancellation before processing chunk
if pipeline_status is not None and pipeline_status_lock is not None:
async with pipeline_status_lock:
- if pipeline_status.get('cancellation_requested', False):
- raise PipelineCancelledException('User cancelled during chunk processing')
+ if pipeline_status.get("cancellation_requested", False):
+ raise PipelineCancelledException(
+ "User cancelled during chunk processing"
+ )
try:
return await _process_single_content(chunk)
@@ -3536,7 +3001,7 @@ async def extract_entities(
await asyncio.wait(pending)
# Add progress prefix to the exception message
- progress_prefix = f'C[{processed_chunks + 1}/{total_chunks}]'
+ progress_prefix = f"C[{processed_chunks + 1}/{total_chunks}]"
# Re-raise the original exception with a prefix
prefixed_exception = create_prefixed_exception(first_exception, progress_prefix)
@@ -3554,10 +3019,10 @@ async def kg_query(
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
- global_config: dict[str, Any],
+ global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
- chunks_vdb: BaseVectorStorage | None = None,
+ chunks_vdb: BaseVectorStorage = None,
) -> QueryResult | None:
"""
Execute knowledge graph query and return unified QueryResult object.
@@ -3590,35 +3055,36 @@ async def kg_query(
Returns None when no relevant context could be constructed for the query.
"""
if not query:
- return QueryResult(content=PROMPTS['fail_response'])
+ return QueryResult(content=PROMPTS["fail_response"])
if query_param.model_func:
use_model_func = query_param.model_func
else:
- use_model_func = global_config['llm_model_func']
+ use_model_func = global_config["llm_model_func"]
# Apply higher priority (5) to query relation LLM function
use_model_func = partial(use_model_func, _priority=5)
- llm_callable = cast(Callable[..., Awaitable[str | AsyncIterator[str]]], use_model_func)
- hl_keywords, ll_keywords = await get_keywords_from_query(query, query_param, global_config, hashing_kv)
+ hl_keywords, ll_keywords = await get_keywords_from_query(
+ query, query_param, global_config, hashing_kv
+ )
- logger.debug(f'High-level keywords: {hl_keywords}')
- logger.debug(f'Low-level keywords: {ll_keywords}')
+ logger.debug(f"High-level keywords: {hl_keywords}")
+ logger.debug(f"Low-level keywords: {ll_keywords}")
# Handle empty keywords
- if ll_keywords == [] and query_param.mode in ['local', 'hybrid', 'mix']:
- logger.warning('low_level_keywords is empty')
- if hl_keywords == [] and query_param.mode in ['global', 'hybrid', 'mix']:
- logger.warning('high_level_keywords is empty')
+ if ll_keywords == [] and query_param.mode in ["local", "hybrid", "mix"]:
+ logger.warning("low_level_keywords is empty")
+ if hl_keywords == [] and query_param.mode in ["global", "hybrid", "mix"]:
+ logger.warning("high_level_keywords is empty")
if hl_keywords == [] and ll_keywords == []:
if len(query) < 50:
- logger.warning(f'Forced low_level_keywords to origin query: {query}')
+ logger.warning(f"Forced low_level_keywords to origin query: {query}")
ll_keywords = [query]
else:
- return QueryResult(content=PROMPTS['fail_response'])
+ return QueryResult(content=PROMPTS["fail_response"])
- ll_keywords_str = ', '.join(ll_keywords) if ll_keywords else ''
- hl_keywords_str = ', '.join(hl_keywords) if hl_keywords else ''
+ ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
+ hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
# Build query context (unified interface)
context_result = await _build_query_context(
@@ -3634,44 +3100,41 @@ async def kg_query(
)
if context_result is None:
- logger.info('[kg_query] No query context could be built; returning no-result.')
+ logger.info("[kg_query] No query context could be built; returning no-result.")
return None
# Return different content based on query parameters
if query_param.only_need_context and not query_param.only_need_prompt:
- return QueryResult(content=context_result.context, raw_data=context_result.raw_data)
+ return QueryResult(
+ content=context_result.context, raw_data=context_result.raw_data
+ )
- user_prompt = f'\n\n{query_param.user_prompt}' if query_param.user_prompt else 'n/a'
- response_type = query_param.response_type if query_param.response_type else 'Multiple Paragraphs'
-
- # Build coverage guidance based on context sparsity
- coverage_guidance = (
- PROMPTS['coverage_guidance_limited']
- if context_result.coverage_level == 'limited'
- else PROMPTS['coverage_guidance_good']
+ user_prompt = f"\n\n{query_param.user_prompt}" if query_param.user_prompt else "n/a"
+ response_type = (
+ query_param.response_type
+ if query_param.response_type
+ else "Multiple Paragraphs"
)
- logger.debug(f'[kg_query] Coverage level: {context_result.coverage_level}')
# Build system prompt
- sys_prompt_temp = system_prompt if system_prompt else PROMPTS['rag_response']
+ sys_prompt_temp = system_prompt if system_prompt else PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
response_type=response_type,
user_prompt=user_prompt,
context_data=context_result.context,
- coverage_guidance=coverage_guidance,
)
user_query = query
if query_param.only_need_prompt:
- prompt_content = '\n\n'.join([sys_prompt, '---User Query---', user_query])
+ prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
return QueryResult(content=prompt_content, raw_data=context_result.raw_data)
# Call LLM
- tokenizer: Tokenizer = global_config['tokenizer']
+ tokenizer: Tokenizer = global_config["tokenizer"]
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug(
- f'[kg_query] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})'
+ f"[kg_query] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})"
)
# Handle cache
@@ -3686,18 +3149,22 @@ async def kg_query(
query_param.max_total_tokens,
hl_keywords_str,
ll_keywords_str,
- query_param.user_prompt or '',
+ query_param.user_prompt or "",
query_param.enable_rerank,
)
- cached_result = await handle_cache(hashing_kv, args_hash, user_query, query_param.mode, cache_type='query')
+ cached_result = await handle_cache(
+ hashing_kv, args_hash, user_query, query_param.mode, cache_type="query"
+ )
if cached_result is not None:
cached_response, _ = cached_result # Extract content, ignore timestamp
- logger.info(' == LLM cache == Query cache hit, using cached response as query result')
+ logger.info(
+ " == LLM cache == Query cache hit, using cached response as query result"
+ )
response = cached_response
else:
- response = await llm_callable(
+ response = await use_model_func(
user_query,
system_prompt=sys_prompt,
history_messages=query_param.conversation_history,
@@ -3705,19 +3172,19 @@ async def kg_query(
stream=query_param.stream,
)
- if isinstance(response, str) and hashing_kv and hashing_kv.global_config.get('enable_llm_cache'):
+ if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
queryparam_dict = {
- 'mode': query_param.mode,
- 'response_type': query_param.response_type,
- 'top_k': query_param.top_k,
- 'chunk_top_k': query_param.chunk_top_k,
- 'max_entity_tokens': query_param.max_entity_tokens,
- 'max_relation_tokens': query_param.max_relation_tokens,
- 'max_total_tokens': query_param.max_total_tokens,
- 'hl_keywords': hl_keywords_str,
- 'll_keywords': ll_keywords_str,
- 'user_prompt': query_param.user_prompt or '',
- 'enable_rerank': query_param.enable_rerank,
+ "mode": query_param.mode,
+ "response_type": query_param.response_type,
+ "top_k": query_param.top_k,
+ "chunk_top_k": query_param.chunk_top_k,
+ "max_entity_tokens": query_param.max_entity_tokens,
+ "max_relation_tokens": query_param.max_relation_tokens,
+ "max_total_tokens": query_param.max_total_tokens,
+ "hl_keywords": hl_keywords_str,
+ "ll_keywords": ll_keywords_str,
+ "user_prompt": query_param.user_prompt or "",
+ "enable_rerank": query_param.enable_rerank,
}
await save_to_cache(
hashing_kv,
@@ -3726,7 +3193,7 @@ async def kg_query(
content=response,
prompt=query,
mode=query_param.mode,
- cache_type='query',
+ cache_type="query",
queryparam=queryparam_dict,
),
)
@@ -3736,12 +3203,12 @@ async def kg_query(
# Non-streaming response (string)
if len(response) > len(sys_prompt):
response = (
- response.replace(sys_prompt, '')
- .replace('user', '')
- .replace('model', '')
- .replace(query, '')
-            .replace('<system>', '')
-            .replace('</system>', '')
+ response.replace(sys_prompt, "")
+ .replace("user", "")
+ .replace("model", "")
+ .replace(query, "")
+            .replace("<system>", "")
+            .replace("</system>", "")
.strip()
)
@@ -3758,7 +3225,7 @@ async def kg_query(
async def get_keywords_from_query(
query: str,
query_param: QueryParam,
- global_config: dict[str, Any],
+ global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
) -> tuple[list[str], list[str]]:
"""
@@ -3781,14 +3248,16 @@ async def get_keywords_from_query(
return query_param.hl_keywords, query_param.ll_keywords
# Extract keywords using extract_keywords_only function which already supports conversation history
- hl_keywords, ll_keywords = await extract_keywords_only(query, query_param, global_config, hashing_kv)
+ hl_keywords, ll_keywords = await extract_keywords_only(
+ query, query_param, global_config, hashing_kv
+ )
return hl_keywords, ll_keywords
async def extract_keywords_only(
text: str,
param: QueryParam,
- global_config: dict[str, Any],
+ global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
) -> tuple[list[str], list[str]]:
"""
@@ -3797,89 +3266,88 @@ async def extract_keywords_only(
It ONLY extracts keywords (hl_keywords, ll_keywords).
"""
- # 1. Handle cache if needed - add cache type for keywords
+ # 1. Build the examples
+ examples = "\n".join(PROMPTS["keywords_extraction_examples"])
+
+ language = global_config["addon_params"].get("language", DEFAULT_SUMMARY_LANGUAGE)
+
+ # 2. Handle cache if needed - add cache type for keywords
args_hash = compute_args_hash(
param.mode,
text,
+ language,
+ )
+ cached_result = await handle_cache(
+ hashing_kv, args_hash, text, param.mode, cache_type="keywords"
)
- cached_result = await handle_cache(hashing_kv, args_hash, text, param.mode, cache_type='keywords')
if cached_result is not None:
cached_response, _ = cached_result # Extract content, ignore timestamp
try:
keywords_data = json_repair.loads(cached_response)
- if isinstance(keywords_data, dict):
- return keywords_data.get('high_level_keywords', []), keywords_data.get('low_level_keywords', [])
+ return keywords_data.get("high_level_keywords", []), keywords_data.get(
+ "low_level_keywords", []
+ )
except (json.JSONDecodeError, KeyError):
- logger.warning('Invalid cache format for keywords, proceeding with extraction')
-
- # 2. Build the examples
- examples = '\n'.join(PROMPTS['keywords_extraction_examples'])
-
- language = global_config['addon_params'].get('language', DEFAULT_SUMMARY_LANGUAGE)
+ logger.warning(
+ "Invalid cache format for keywords, proceeding with extraction"
+ )
# 3. Build the keyword-extraction prompt
- kw_prompt = PROMPTS['keywords_extraction'].format(
+ kw_prompt = PROMPTS["keywords_extraction"].format(
query=text,
examples=examples,
language=language,
)
- tokenizer: Tokenizer = global_config['tokenizer']
+ tokenizer: Tokenizer = global_config["tokenizer"]
len_of_prompts = len(tokenizer.encode(kw_prompt))
- logger.debug(f'[extract_keywords] Sending to LLM: {len_of_prompts:,} tokens (Prompt: {len_of_prompts})')
+ logger.debug(
+ f"[extract_keywords] Sending to LLM: {len_of_prompts:,} tokens (Prompt: {len_of_prompts})"
+ )
# 4. Call the LLM for keyword extraction
if param.model_func:
use_model_func = param.model_func
else:
- use_model_func = global_config['llm_model_func']
+ use_model_func = global_config["llm_model_func"]
# Apply higher priority (5) to query relation LLM function
use_model_func = partial(use_model_func, _priority=5)
- llm_callable = cast(Callable[..., Awaitable[str]], use_model_func)
- result = await llm_callable(kw_prompt, keyword_extraction=True)
+ result = await use_model_func(kw_prompt, keyword_extraction=True)
# 5. Parse out JSON from the LLM response
result = remove_think_tags(result)
try:
keywords_data = json_repair.loads(result)
- if not keywords_data or not isinstance(keywords_data, dict):
-            logger.error('No JSON-like structure found in the LLM response.')
+        if not keywords_data:
+            logger.error("No JSON-like structure found in the LLM response.")
return [], []
except json.JSONDecodeError as e:
- logger.error(f'JSON parsing error: {e}')
-        logger.error(f'LLM response: {result}')
+        logger.error(f"JSON parsing error: {e}")
+        logger.error(f"LLM response: {result}")
return [], []
- hl_keywords = keywords_data.get('high_level_keywords', [])
- ll_keywords = keywords_data.get('low_level_keywords', [])
-
- # 5b. Extract years from query that LLM might have missed (defensive heuristic)
- # This applies to ANY domain - financial reports, historical events, etc.
-    years_in_query = re.findall(r'\b(?:19|20)\d{2}\b', text)
- for year in years_in_query:
- if year not in ll_keywords:
- ll_keywords.append(year)
- logger.debug(f'Added year "{year}" to low_level_keywords from query')
+ hl_keywords = keywords_data.get("high_level_keywords", [])
+ ll_keywords = keywords_data.get("low_level_keywords", [])
# 6. Cache only the processed keywords with cache type
if hl_keywords or ll_keywords:
cache_data = {
- 'high_level_keywords': hl_keywords,
- 'low_level_keywords': ll_keywords,
+ "high_level_keywords": hl_keywords,
+ "low_level_keywords": ll_keywords,
}
- if hashing_kv and hashing_kv.global_config.get('enable_llm_cache'):
+        if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
# Save to cache with query parameters
queryparam_dict = {
- 'mode': param.mode,
- 'response_type': param.response_type,
- 'top_k': param.top_k,
- 'chunk_top_k': param.chunk_top_k,
- 'max_entity_tokens': param.max_entity_tokens,
- 'max_relation_tokens': param.max_relation_tokens,
- 'max_total_tokens': param.max_total_tokens,
- 'user_prompt': param.user_prompt or '',
- 'enable_rerank': param.enable_rerank,
+ "mode": param.mode,
+ "response_type": param.response_type,
+ "top_k": param.top_k,
+ "chunk_top_k": param.chunk_top_k,
+ "max_entity_tokens": param.max_entity_tokens,
+ "max_relation_tokens": param.max_relation_tokens,
+ "max_total_tokens": param.max_total_tokens,
+ "user_prompt": param.user_prompt or "",
+ "enable_rerank": param.enable_rerank,
}
await save_to_cache(
hashing_kv,
@@ -3888,7 +3356,7 @@ async def extract_keywords_only(
content=json.dumps(cache_data),
prompt=text,
mode=param.mode,
- cache_type='keywords',
+ cache_type="keywords",
queryparam=queryparam_dict,
),
)
@@ -3900,8 +3368,7 @@ async def _get_vector_context(
query: str,
chunks_vdb: BaseVectorStorage,
query_param: QueryParam,
- query_embedding: list[float] | None = None,
- entity_keywords: list[str] | None = None,
+ query_embedding: list[float] = None,
) -> list[dict]:
"""
Retrieve text chunks from the vector database without reranking or truncation.
@@ -3909,88 +3376,48 @@ async def _get_vector_context(
This function performs vector search to find relevant text chunks for a query.
Reranking and truncation will be handled later in the unified processing.
- When reranking is enabled, retrieves more candidates (controlled by RETRIEVAL_MULTIPLIER)
- to allow the reranker to surface hidden relevant chunks from beyond the initial top-k.
-
Args:
query: The query string to search for
chunks_vdb: Vector database containing document chunks
query_param: Query parameters including chunk_top_k and ids
query_embedding: Optional pre-computed query embedding to avoid redundant embedding calls
- entity_keywords: Optional list of entity names from query for entity-aware boosting
Returns:
List of text chunks with metadata
"""
try:
- base_top_k = query_param.chunk_top_k or query_param.top_k
+ # Use chunk_top_k if specified, otherwise fall back to top_k
+ search_top_k = query_param.chunk_top_k or query_param.top_k
cosine_threshold = chunks_vdb.cosine_better_than_threshold
- # Two-stage retrieval: when reranking is enabled, retrieve more candidates
- # so the reranker can surface hidden relevant chunks from beyond base_top_k
- if query_param.enable_rerank:
- search_top_k = base_top_k * DEFAULT_RETRIEVAL_MULTIPLIER
- logger.debug(
- f'Two-stage retrieval: {search_top_k} candidates (base={base_top_k}, x{DEFAULT_RETRIEVAL_MULTIPLIER})'
- )
- else:
- search_top_k = base_top_k
-
- # Three-stage retrieval:
- # Stage 1: BM25+vector fusion (if available)
- # Stage 1.5: Entity-aware boosting (if entity_keywords provided)
- # Stage 2: Reranking (happens later in the pipeline)
-
- if hasattr(chunks_vdb, 'hybrid_search_with_entity_boost') and entity_keywords:
- # Use entity-boosted hybrid search when entity keywords are available
- results = await chunks_vdb.hybrid_search_with_entity_boost(
- query,
- top_k=search_top_k,
- entity_keywords=entity_keywords,
- query_embedding=query_embedding,
- )
- search_method = 'hybrid+entity_boost'
- elif hasattr(chunks_vdb, 'hybrid_search'):
- results = await chunks_vdb.hybrid_search(
- query, top_k=search_top_k, query_embedding=query_embedding
- )
- search_method = 'hybrid (BM25+vector)'
- else:
- results = await chunks_vdb.query(
- query, top_k=search_top_k, query_embedding=query_embedding
- )
- search_method = 'vector'
-
+ results = await chunks_vdb.query(
+ query, top_k=search_top_k, query_embedding=query_embedding
+ )
if not results:
- logger.info(f'Naive query: 0 chunks (chunk_top_k:{search_top_k} cosine:{cosine_threshold})')
+ logger.info(
+ f"Naive query: 0 chunks (chunk_top_k:{search_top_k} cosine:{cosine_threshold})"
+ )
return []
valid_chunks = []
- boosted_count = 0
for result in results:
- if 'content' in result:
+ if "content" in result:
chunk_with_metadata = {
- 'content': result['content'],
- 'created_at': result.get('created_at', None),
- 'file_path': result.get('file_path', 'unknown_source'),
- 'source_type': 'vector', # Mark the source type
- 'chunk_id': result.get('id'), # Add chunk_id for deduplication
+ "content": result["content"],
+ "created_at": result.get("created_at", None),
+ "file_path": result.get("file_path", "unknown_source"),
+ "source_type": "vector", # Mark the source type
+ "chunk_id": result.get("id"), # Add chunk_id for deduplication
}
- if result.get('entity_boosted'):
- chunk_with_metadata['entity_boosted'] = True
- boosted_count += 1
valid_chunks.append(chunk_with_metadata)
- boost_info = f', {boosted_count} entity-boosted' if boosted_count > 0 else ''
- two_stage_info = f' [two-stage: {search_top_k}→{base_top_k}]' if query_param.enable_rerank else ''
logger.info(
- f'Vector retrieval: {len(valid_chunks)} chunks via {search_method} '
- f'(top_k:{search_top_k}{boost_info}{two_stage_info})'
+ f"Naive query: {len(valid_chunks)} chunks (chunk_top_k:{search_top_k} cosine:{cosine_threshold})"
)
return valid_chunks
except Exception as e:
- logger.error(f'Error in _get_vector_context: {e}')
+ logger.error(f"Error in _get_vector_context: {e}")
return []
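The branch removed above over-fetched candidates when reranking was enabled so the reranker could trim the pool back down; a sketch of that sizing rule, with an assumed multiplier value (the real constant, DEFAULT_RETRIEVAL_MULTIPLIER, lives in lightrag.constants):

DEFAULT_RETRIEVAL_MULTIPLIER = 4  # assumed value, for illustration only

def candidate_pool_size(base_top_k: int, enable_rerank: bool) -> int:
    # Over-fetch only when a reranker will later cut the pool back to base_top_k.
    return base_top_k * DEFAULT_RETRIEVAL_MULTIPLIER if enable_rerank else base_top_k

assert candidate_pool_size(10, enable_rerank=True) == 40
assert candidate_pool_size(10, enable_rerank=False) == 10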
@@ -4003,7 +3430,7 @@ async def _perform_kg_search(
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
- chunks_vdb: BaseVectorStorage | None = None,
+ chunks_vdb: BaseVectorStorage = None,
) -> dict[str, Any]:
"""
Pure search logic that retrieves raw entities, relations, and vector chunks.
@@ -4023,18 +3450,26 @@ async def _perform_kg_search(
# Track chunk sources and metadata for final logging
chunk_tracking = {} # chunk_id -> {source, frequency, order}
- # Pre-compute query embedding once for all vector operations (with caching)
- kg_chunk_pick_method = text_chunks_db.global_config.get('kg_chunk_pick_method', DEFAULT_KG_CHUNK_PICK_METHOD)
+ # Pre-compute query embedding once for all vector operations
+ kg_chunk_pick_method = text_chunks_db.global_config.get(
+ "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD
+ )
query_embedding = None
- if query and (kg_chunk_pick_method == 'VECTOR' or chunks_vdb):
+ if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb):
actual_embedding_func = text_chunks_db.embedding_func
if actual_embedding_func:
- query_embedding = await get_cached_query_embedding(query, actual_embedding_func)
- if query_embedding is not None:
- logger.debug('Pre-computed query embedding for all vector operations')
+ try:
+ query_embedding = await actual_embedding_func([query])
+ query_embedding = query_embedding[
+ 0
+ ] # Extract first embedding from batch result
+ logger.debug("Pre-computed query embedding for all vector operations")
+ except Exception as e:
+ logger.warning(f"Failed to pre-compute query embedding: {e}")
+ query_embedding = None
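A sketch of the embed-once pattern set up here, assuming only what the code above shows (the embedding function takes a batch of texts and returns a batch of vectors): the query is embedded a single time and the resulting vector is handed to every subsequent lookup instead of re-embedding the text.

async def precompute_query_embedding(embedding_func, query: str):
    # Embedding functions here take a list of texts and return a list of vectors.
    try:
        batch = await embedding_func([query])
        return batch[0]
    except Exception:
        return None  # callers fall back to letting each lookup embed the query itself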
# Handle local and global modes
- if query_param.mode == 'local' and len(ll_keywords) > 0:
+ if query_param.mode == "local" and len(ll_keywords) > 0:
local_entities, local_relations = await _get_node_data(
ll_keywords,
knowledge_graph_inst,
@@ -4042,7 +3477,7 @@ async def _perform_kg_search(
query_param,
)
- elif query_param.mode == 'global' and len(hl_keywords) > 0:
+ elif query_param.mode == "global" and len(hl_keywords) > 0:
global_relations, global_entities = await _get_edge_data(
hl_keywords,
knowledge_graph_inst,
@@ -4067,33 +3502,24 @@ async def _perform_kg_search(
)
# Get vector chunks for mix mode
- if query_param.mode == 'mix' and chunks_vdb:
- logger.info(f'[MIX DEBUG] Starting vector search for query: {query[:60]}...')
- # Parse ll_keywords into list for entity-aware boosting
- entity_keywords_list = [kw.strip() for kw in ll_keywords.split(',') if kw.strip()] if ll_keywords else None
+ if query_param.mode == "mix" and chunks_vdb:
vector_chunks = await _get_vector_context(
query,
chunks_vdb,
query_param,
query_embedding,
- entity_keywords=entity_keywords_list,
)
- logger.info(f'[MIX DEBUG] Vector search returned {len(vector_chunks)} chunks')
- if not vector_chunks:
- logger.warning(
- f'[MIX DEBUG] ⚠️ NO VECTOR CHUNKS! chunk_top_k={query_param.chunk_top_k}, top_k={query_param.top_k}'
- )
# Track vector chunks with source metadata
for i, chunk in enumerate(vector_chunks):
- chunk_id = chunk.get('chunk_id') or chunk.get('id')
+ chunk_id = chunk.get("chunk_id") or chunk.get("id")
if chunk_id:
chunk_tracking[chunk_id] = {
- 'source': 'C',
- 'frequency': 1, # Vector chunks always have frequency 1
- 'order': i + 1, # 1-based order in vector search results
+ "source": "C",
+ "frequency": 1, # Vector chunks always have frequency 1
+ "order": i + 1, # 1-based order in vector search results
}
else:
- logger.warning(f'Vector chunk missing chunk_id: {chunk}')
+ logger.warning(f"Vector chunk missing chunk_id: {chunk}")
# Round-robin merge entities
final_entities = []
@@ -4103,7 +3529,7 @@ async def _perform_kg_search(
# First from local
if i < len(local_entities):
entity = local_entities[i]
- entity_name = entity.get('entity_name')
+ entity_name = entity.get("entity_name")
if entity_name and entity_name not in seen_entities:
final_entities.append(entity)
seen_entities.add(entity_name)
@@ -4111,7 +3537,7 @@ async def _perform_kg_search(
# Then from global
if i < len(global_entities):
entity = global_entities[i]
- entity_name = entity.get('entity_name')
+ entity_name = entity.get("entity_name")
if entity_name and entity_name not in seen_entities:
final_entities.append(entity)
seen_entities.add(entity_name)
@@ -4125,10 +3551,12 @@ async def _perform_kg_search(
if i < len(local_relations):
relation = local_relations[i]
# Build relation unique identifier
- if 'src_tgt' in relation:
- rel_key = tuple(sorted(relation['src_tgt']))
+ if "src_tgt" in relation:
+ rel_key = tuple(sorted(relation["src_tgt"]))
else:
- rel_key = tuple(sorted([relation.get('src_id'), relation.get('tgt_id')]))
+ rel_key = tuple(
+ sorted([relation.get("src_id"), relation.get("tgt_id")])
+ )
if rel_key not in seen_relations:
final_relations.append(relation)
@@ -4138,211 +3566,64 @@ async def _perform_kg_search(
if i < len(global_relations):
relation = global_relations[i]
# Build relation unique identifier
- if 'src_tgt' in relation:
- rel_key = tuple(sorted(relation['src_tgt']))
+ if "src_tgt" in relation:
+ rel_key = tuple(sorted(relation["src_tgt"]))
else:
- rel_key = tuple(sorted([relation.get('src_id'), relation.get('tgt_id')]))
+ rel_key = tuple(
+ sorted([relation.get("src_id"), relation.get("tgt_id")])
+ )
if rel_key not in seen_relations:
final_relations.append(relation)
seen_relations.add(rel_key)
logger.info(
- f'Raw search results: {len(final_entities)} entities, {len(final_relations)} relations, {len(vector_chunks)} vector chunks'
+ f"Raw search results: {len(final_entities)} entities, {len(final_relations)} relations, {len(vector_chunks)} vector chunks"
)
return {
- 'final_entities': final_entities,
- 'final_relations': final_relations,
- 'vector_chunks': vector_chunks,
- 'chunk_tracking': chunk_tracking,
- 'query_embedding': query_embedding,
+ "final_entities": final_entities,
+ "final_relations": final_relations,
+ "vector_chunks": vector_chunks,
+ "chunk_tracking": chunk_tracking,
+ "query_embedding": query_embedding,
}
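
The entity and relation merging above, and the chunk merging later in _merge_all_chunks, follow the same round-robin pattern: take the i-th item from each ranked list in turn and keep only the first occurrence of each key. A minimal, self-contained sketch of that pattern (the helper name and toy data below are illustrative, not part of this codebase):

from itertools import zip_longest

def round_robin_merge(ranked_lists, key):
    """Interleave ranked lists one item at a time; earlier lists win ties."""
    merged, seen = [], set()
    for group in zip_longest(*ranked_lists):  # i-th element of every list
        for item in group:
            if item is None:
                continue
            k = key(item)
            if k not in seen:
                seen.add(k)
                merged.append(item)
    return merged

local_hits = [{"entity_name": "A"}, {"entity_name": "B"}]
global_hits = [{"entity_name": "B"}, {"entity_name": "C"}]
print(round_robin_merge([local_hits, global_hits], key=lambda e: e["entity_name"]))
# -> [{'entity_name': 'A'}, {'entity_name': 'B'}, {'entity_name': 'C'}]
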
-def _check_topic_connectivity(
- entities_context: list[dict],
- relations_context: list[dict],
- min_relationship_density: float = 0.3,
- min_entity_coverage: float = 0.5,
- min_doc_coverage: float = 0.15, # Lower threshold for document connectivity
-) -> tuple[bool, str]:
- """
- Check if retrieved entities are topically connected.
-
- Connectivity is established if EITHER:
- 1. Entities come from the same source document (file_path) - uses min_doc_coverage threshold
- 2. Entities are connected via graph relationships (BFS) - uses min_entity_coverage threshold
-
- The document coverage threshold (min_doc_coverage) is lower than graph coverage because
- vector search often pulls semantically similar entities from multiple documents.
- Even 15% from a single document indicates strong topical relevance.
-
- Args:
- entities_context: List of entity dictionaries from retrieval
- relations_context: List of relationship dictionaries from retrieval
- min_relationship_density: Minimum ratio of relationships to entities (0.0-1.0)
- min_entity_coverage: Minimum ratio of entities in largest graph component (0.0-1.0)
- min_doc_coverage: Minimum ratio of entities from a single document (0.0-1.0)
-
- Returns:
- (is_connected, reason) where:
- - is_connected: True if topics form a connected graph or share document
- - reason: Explanation if connectivity check failed
- """
- if not entities_context:
- return True, '' # No entities = let existing empty-context logic handle it
-
- # Build set of entity names and track file_path for each (case-insensitive)
- entity_names = set()
- entity_files: dict[str, str] = {} # entity_name -> primary file_path
-
- for e in entities_context:
- name = e.get('entity', e.get('entity_name', '')).lower()
- if not name:
- continue
- entity_names.add(name)
-
- # Extract primary file_path (before GRAPH_FIELD_SEP if multiple)
- file_path = e.get('file_path', '')
- if file_path:
- primary_file = file_path.split(GRAPH_FIELD_SEP)[0].strip()
- if primary_file:
- entity_files[name] = primary_file
-
- if not entity_names:
- return True, '' # Can't check without entity names
-
- # Check 1: Document connectivity
- # If entities from the same document meet the coverage threshold, pass
- file_to_entities: dict[str, set[str]] = {}
- for name, file_path in entity_files.items():
- file_to_entities.setdefault(file_path, set()).add(name)
-
- # Debug: Log document distribution
- if file_to_entities:
- doc_summary = ', '.join(
- f'"{fp}": {len(ents)}' for fp, ents in sorted(file_to_entities.items(), key=lambda x: -len(x[1]))[:5]
- )
- logger.info(
- f'Topic connectivity: Document distribution (top 5): [{doc_summary}] total_entities={len(entity_names)}'
- )
- else:
- logger.info(f'Topic connectivity: No file_path data available for {len(entity_names)} entities')
-
- for file_path, file_entities in file_to_entities.items():
- doc_coverage = len(file_entities) / len(entity_names)
- if doc_coverage >= min_doc_coverage:
- logger.info(
- f'Topic connectivity: PASS (document) - {len(file_entities)}/{len(entity_names)} '
- f'entities from "{file_path}" ({doc_coverage:.0%} coverage, threshold={min_doc_coverage:.0%})'
- )
- return True, ''
-
- # Check 2: Graph connectivity (original logic)
- if not relations_context:
- # Log document distribution for debugging
- doc_summary = ', '.join(f'{fp}: {len(ents)}' for fp, ents in file_to_entities.items())
- logger.info(f'Topic connectivity: FAIL - no relations, docs=[{doc_summary}]')
- return False, 'Retrieved entities have no connecting relationships or common document'
-
- # Build adjacency list from relationships
- # Only include edges where BOTH endpoints are in our entity set
- adjacency: dict[str, set[str]] = {name: set() for name in entity_names}
- edges_in_subgraph = 0
-
- for rel in relations_context:
- src = (rel.get('src_id') or rel.get('source') or rel.get('entity1') or '').lower()
- tgt = (rel.get('tgt_id') or rel.get('target') or rel.get('entity2') or '').lower()
-
- # Only count edges where BOTH endpoints are in our retrieved entities
- if src in entity_names and tgt in entity_names and src != tgt:
- adjacency[src].add(tgt)
- adjacency[tgt].add(src)
- edges_in_subgraph += 1
-
- # Check if we have enough edges connecting entities to each other
- if edges_in_subgraph == 0:
- doc_summary = ', '.join(f'{fp}: {len(ents)}' for fp, ents in file_to_entities.items())
- logger.info(
- f'Topic connectivity: FAIL - {len(relations_context)} relations but 0 connect '
- f'retrieved entities to each other, docs=[{doc_summary}]'
- )
- return False, 'Retrieved entities are not connected to each other by relationships or common document'
-
- # Use BFS to find connected components
- visited = set()
- components = []
-
- for start_node in entity_names:
- if start_node in visited:
- continue
-
- # BFS from this node
- component = set()
- queue = [start_node]
- while queue:
- node = queue.pop(0)
- if node in visited:
- continue
- visited.add(node)
- component.add(node)
- for neighbor in adjacency[node]:
- if neighbor not in visited:
- queue.append(neighbor)
-
- components.append(component)
-
- # Check component coverage
- largest_component = max(components, key=len) if components else set()
- coverage = len(largest_component) / len(entity_names) if entity_names else 0
-
- logger.info(
- f'Topic connectivity: {len(entity_names)} entities, {edges_in_subgraph} connecting edges, '
- f'{len(components)} components, largest={len(largest_component)} ({coverage:.0%} coverage)'
- )
-
- if coverage < min_entity_coverage:
- logger.info(f'Topic connectivity: FAIL - graph coverage {coverage:.0%} < threshold {min_entity_coverage:.0%}')
- return False, f'Entities form {len(components)} disconnected clusters (largest covers {coverage:.0%})'
-
- return True, ''
-
-
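
For reference, the graph half of the _check_topic_connectivity function removed above reduces to: build an undirected adjacency over the retrieved entity names, find the largest connected component via BFS, and compare its share of the entities against a coverage threshold. A self-contained sketch of that calculation on toy data (the entity names and threshold below are made up for illustration):

from collections import deque

def largest_component_coverage(entities: set[str], edges: list[tuple[str, str]]) -> float:
    """Share of entities contained in the largest connected component."""
    adjacency = {name: set() for name in entities}
    for src, tgt in edges:
        if src in entities and tgt in entities and src != tgt:
            adjacency[src].add(tgt)
            adjacency[tgt].add(src)

    visited, best = set(), 0
    for start in entities:
        if start in visited:
            continue
        size, queue = 0, deque([start])
        while queue:
            node = queue.popleft()
            if node in visited:
                continue
            visited.add(node)
            size += 1
            queue.extend(n for n in adjacency[node] if n not in visited)
        best = max(best, size)
    return best / len(entities) if entities else 0.0

coverage = largest_component_coverage(
    {"insulin", "diabetes", "solar panel"},
    [("insulin", "diabetes")],
)
print(coverage >= 0.5)  # True: 2 of 3 entities form one component (~67% coverage)
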
async def _apply_token_truncation(
search_result: dict[str, Any],
query_param: QueryParam,
- global_config: dict[str, Any],
+ global_config: dict[str, str],
) -> dict[str, Any]:
"""
Apply token-based truncation to entities and relations for LLM efficiency.
"""
- tokenizer = global_config.get('tokenizer')
+ tokenizer = global_config.get("tokenizer")
if not tokenizer:
- logger.warning('No tokenizer found, skipping truncation')
+ logger.warning("No tokenizer found, skipping truncation")
return {
- 'entities_context': [],
- 'relations_context': [],
- 'filtered_entities': search_result['final_entities'],
- 'filtered_relations': search_result['final_relations'],
- 'entity_id_to_original': {},
- 'relation_id_to_original': {},
+ "entities_context": [],
+ "relations_context": [],
+ "filtered_entities": search_result["final_entities"],
+ "filtered_relations": search_result["final_relations"],
+ "entity_id_to_original": {},
+ "relation_id_to_original": {},
}
# Get token limits from query_param with fallbacks
max_entity_tokens = getattr(
query_param,
- 'max_entity_tokens',
- global_config.get('max_entity_tokens', DEFAULT_MAX_ENTITY_TOKENS),
+ "max_entity_tokens",
+ global_config.get("max_entity_tokens", DEFAULT_MAX_ENTITY_TOKENS),
)
max_relation_tokens = getattr(
query_param,
- 'max_relation_tokens',
- global_config.get('max_relation_tokens', DEFAULT_MAX_RELATION_TOKENS),
+ "max_relation_tokens",
+ global_config.get("max_relation_tokens", DEFAULT_MAX_RELATION_TOKENS),
)
- final_entities = search_result['final_entities']
- final_relations = search_result['final_relations']
+ final_entities = search_result["final_entities"]
+ final_relations = search_result["final_relations"]
# Create mappings from entity/relation identifiers to original data
entity_id_to_original = {}
@@ -4350,37 +3631,37 @@ async def _apply_token_truncation(
# Generate entities context for truncation
entities_context = []
- for _i, entity in enumerate(final_entities):
- entity_name = entity['entity_name']
- created_at = entity.get('created_at', 'UNKNOWN')
+ for i, entity in enumerate(final_entities):
+ entity_name = entity["entity_name"]
+ created_at = entity.get("created_at", "UNKNOWN")
if isinstance(created_at, (int, float)):
- created_at = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(created_at))
+ created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
# Store mapping from entity name to original data
entity_id_to_original[entity_name] = entity
entities_context.append(
{
- 'entity': entity_name,
- 'type': entity.get('entity_type', 'UNKNOWN'),
- 'description': entity.get('description', 'UNKNOWN'),
- 'created_at': created_at,
- 'file_path': entity.get('file_path', 'unknown_source'),
+ "entity": entity_name,
+ "type": entity.get("entity_type", "UNKNOWN"),
+ "description": entity.get("description", "UNKNOWN"),
+ "created_at": created_at,
+ "file_path": entity.get("file_path", "unknown_source"),
}
)
# Generate relations context for truncation
relations_context = []
- for _i, relation in enumerate(final_relations):
- created_at = relation.get('created_at', 'UNKNOWN')
+ for i, relation in enumerate(final_relations):
+ created_at = relation.get("created_at", "UNKNOWN")
if isinstance(created_at, (int, float)):
- created_at = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(created_at))
+ created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at))
# Handle different relation data formats
- if 'src_tgt' in relation:
- entity1, entity2 = relation['src_tgt']
+ if "src_tgt" in relation:
+ entity1, entity2 = relation["src_tgt"]
else:
- entity1, entity2 = relation.get('src_id'), relation.get('tgt_id')
+ entity1, entity2 = relation.get("src_id"), relation.get("tgt_id")
# Store mapping from relation pair to original data
relation_key = (entity1, entity2)
@@ -4388,103 +3669,67 @@ async def _apply_token_truncation(
relations_context.append(
{
- 'entity1': entity1,
- 'entity2': entity2,
- 'description': relation.get('description', 'UNKNOWN'),
- 'created_at': created_at,
- 'file_path': relation.get('file_path', 'unknown_source'),
+ "entity1": entity1,
+ "entity2": entity2,
+ "description": relation.get("description", "UNKNOWN"),
+ "created_at": created_at,
+ "file_path": relation.get("file_path", "unknown_source"),
}
)
- original_entity_count = len(entities_context)
- original_relation_count = len(relations_context)
+ logger.debug(
+ f"Before truncation: {len(entities_context)} entities, {len(relations_context)} relations"
+ )
- # Check if truncation is disabled
- disable_truncation = getattr(query_param, 'disable_truncation', False)
-
- # Track truncated counts for logging
- truncated_entity_count = original_entity_count
- truncated_relation_count = original_relation_count
-
- # Apply token-based truncation (safety ceiling, not aggressive limiting)
- # Top-k already limits retrieval count; this is a fallback for very long descriptions
+ # Apply token-based truncation
if entities_context:
- # Remove file_path and created_at for token calculation (metadata not sent to LLM)
+ # Remove file_path and created_at for token calculation
entities_context_for_truncation = []
for entity in entities_context:
entity_copy = entity.copy()
- entity_copy.pop('file_path', None)
- entity_copy.pop('created_at', None)
+ entity_copy.pop("file_path", None)
+ entity_copy.pop("created_at", None)
entities_context_for_truncation.append(entity_copy)
- truncated_entities = truncate_list_by_token_size(
+ entities_context = truncate_list_by_token_size(
entities_context_for_truncation,
- key=lambda x: '\n'.join(json.dumps(item, ensure_ascii=False) for item in [x]),
+ key=lambda x: "\n".join(
+ json.dumps(item, ensure_ascii=False) for item in [x]
+ ),
max_token_size=max_entity_tokens,
tokenizer=tokenizer,
)
- # Restore file_path from original entities_context (needed for connectivity check)
- entity_name_to_file_path = {e['entity']: e.get('file_path', '') for e in entities_context}
- for entity in truncated_entities:
- entity['file_path'] = entity_name_to_file_path.get(entity['entity'], '')
-
- truncated_entity_count = len(truncated_entities)
- if not disable_truncation:
- entities_context = truncated_entities
if relations_context:
# Remove file_path and created_at for token calculation
relations_context_for_truncation = []
for relation in relations_context:
relation_copy = relation.copy()
- relation_copy.pop('file_path', None)
- relation_copy.pop('created_at', None)
+ relation_copy.pop("file_path", None)
+ relation_copy.pop("created_at", None)
relations_context_for_truncation.append(relation_copy)
- truncated_relations = truncate_list_by_token_size(
+ relations_context = truncate_list_by_token_size(
relations_context_for_truncation,
- key=lambda x: '\n'.join(json.dumps(item, ensure_ascii=False) for item in [x]),
+ key=lambda x: "\n".join(
+ json.dumps(item, ensure_ascii=False) for item in [x]
+ ),
max_token_size=max_relation_tokens,
tokenizer=tokenizer,
)
- # Restore file_path from original relations_context (for consistency)
- relation_key_to_file_path = {(r['entity1'], r['entity2']): r.get('file_path', '') for r in relations_context}
- for relation in truncated_relations:
- relation['file_path'] = relation_key_to_file_path.get((relation['entity1'], relation['entity2']), '')
- truncated_relation_count = len(truncated_relations)
- if not disable_truncation:
- relations_context = truncated_relations
-
- # Calculate how many would be dropped
- entities_would_drop = original_entity_count - truncated_entity_count
- relations_would_drop = original_relation_count - truncated_relation_count
-
- # Log if token ceiling was/would be hit
- if entities_would_drop > 0 or relations_would_drop > 0:
- if disable_truncation:
- logger.warning(
- f'Token ceiling exceeded (truncation disabled): would have dropped '
- f'{entities_would_drop} entities, {relations_would_drop} relations. '
- f'Keeping all {original_entity_count} entities, {original_relation_count} relations.'
- )
- else:
- logger.warning(
- f'Token ceiling hit: dropped {entities_would_drop} entities (limit={max_entity_tokens} tokens), '
- f'{relations_would_drop} relations (limit={max_relation_tokens} tokens). '
- f'Increase token limits or set disable_truncation=True.'
- )
-
- logger.info(f'Context: {len(entities_context)} entities, {len(relations_context)} relations')
+ logger.info(
+ f"After truncation: {len(entities_context)} entities, {len(relations_context)} relations"
+ )
# Create filtered original data based on truncated context
filtered_entities = []
filtered_entity_id_to_original = {}
if entities_context:
- final_entity_names = {e['entity'] for e in entities_context}
+ final_entity_names = {e["entity"] for e in entities_context}
seen_nodes = set()
for entity in final_entities:
- name = entity.get('entity_name')
+ name = entity.get("entity_name")
if name in final_entity_names and name not in seen_nodes:
filtered_entities.append(entity)
filtered_entity_id_to_original[name] = entity
@@ -4493,12 +3738,12 @@ async def _apply_token_truncation(
filtered_relations = []
filtered_relation_id_to_original = {}
if relations_context:
- final_relation_pairs = {(r['entity1'], r['entity2']) for r in relations_context}
+ final_relation_pairs = {(r["entity1"], r["entity2"]) for r in relations_context}
seen_edges = set()
for relation in final_relations:
- src, tgt = relation.get('src_id'), relation.get('tgt_id')
+ src, tgt = relation.get("src_id"), relation.get("tgt_id")
if src is None or tgt is None:
- src, tgt = relation.get('src_tgt', (None, None))
+ src, tgt = relation.get("src_tgt", (None, None))
pair = (src, tgt)
if pair in final_relation_pairs and pair not in seen_edges:
@@ -4507,12 +3752,12 @@ async def _apply_token_truncation(
seen_edges.add(pair)
return {
- 'entities_context': entities_context,
- 'relations_context': relations_context,
- 'filtered_entities': filtered_entities,
- 'filtered_relations': filtered_relations,
- 'entity_id_to_original': filtered_entity_id_to_original,
- 'relation_id_to_original': filtered_relation_id_to_original,
+ "entities_context": entities_context,
+ "relations_context": relations_context,
+ "filtered_entities": filtered_entities,
+ "filtered_relations": filtered_relations,
+ "entity_id_to_original": filtered_entity_id_to_original,
+ "relation_id_to_original": filtered_relation_id_to_original,
}
@@ -4520,21 +3765,17 @@ async def _merge_all_chunks(
filtered_entities: list[dict],
filtered_relations: list[dict],
vector_chunks: list[dict],
- query: str = '',
- knowledge_graph_inst: BaseGraphStorage | None = None,
- text_chunks_db: BaseKVStorage | None = None,
- query_param: QueryParam | None = None,
- chunks_vdb: BaseVectorStorage | None = None,
- chunk_tracking: dict | None = None,
- query_embedding: list[float] | None = None,
+ query: str = "",
+ knowledge_graph_inst: BaseGraphStorage = None,
+ text_chunks_db: BaseKVStorage = None,
+ query_param: QueryParam = None,
+ chunks_vdb: BaseVectorStorage = None,
+ chunk_tracking: dict = None,
+ query_embedding: list[float] = None,
) -> list[dict]:
"""
Merge chunks from different sources: vector_chunks + entity_chunks + relation_chunks.
"""
- if query_param is None:
- raise ValueError('query_param is required for merging chunks')
- if knowledge_graph_inst is None or chunks_vdb is None:
- raise ValueError('knowledge_graph_inst and chunks_vdb are required for chunk merging')
if chunk_tracking is None:
chunk_tracking = {}
@@ -4576,47 +3817,47 @@ async def _merge_all_chunks(
# Add from vector chunks first (Naive mode)
if i < len(vector_chunks):
chunk = vector_chunks[i]
- chunk_id = chunk.get('chunk_id') or chunk.get('id')
+ chunk_id = chunk.get("chunk_id") or chunk.get("id")
if chunk_id and chunk_id not in seen_chunk_ids:
seen_chunk_ids.add(chunk_id)
merged_chunks.append(
{
- 'content': chunk['content'],
- 'file_path': chunk.get('file_path', 'unknown_source'),
- 'chunk_id': chunk_id,
+ "content": chunk["content"],
+ "file_path": chunk.get("file_path", "unknown_source"),
+ "chunk_id": chunk_id,
}
)
# Add from entity chunks (Local mode)
if i < len(entity_chunks):
chunk = entity_chunks[i]
- chunk_id = chunk.get('chunk_id') or chunk.get('id')
+ chunk_id = chunk.get("chunk_id") or chunk.get("id")
if chunk_id and chunk_id not in seen_chunk_ids:
seen_chunk_ids.add(chunk_id)
merged_chunks.append(
{
- 'content': chunk['content'],
- 'file_path': chunk.get('file_path', 'unknown_source'),
- 'chunk_id': chunk_id,
+ "content": chunk["content"],
+ "file_path": chunk.get("file_path", "unknown_source"),
+ "chunk_id": chunk_id,
}
)
# Add from relation chunks (Global mode)
if i < len(relation_chunks):
chunk = relation_chunks[i]
- chunk_id = chunk.get('chunk_id') or chunk.get('id')
+ chunk_id = chunk.get("chunk_id") or chunk.get("id")
if chunk_id and chunk_id not in seen_chunk_ids:
seen_chunk_ids.add(chunk_id)
merged_chunks.append(
{
- 'content': chunk['content'],
- 'file_path': chunk.get('file_path', 'unknown_source'),
- 'chunk_id': chunk_id,
+ "content": chunk["content"],
+ "file_path": chunk.get("file_path", "unknown_source"),
+ "chunk_id": chunk_id,
}
)
logger.info(
- f'Round-robin merged chunks: {origin_len} -> {len(merged_chunks)} (deduplicated {origin_len - len(merged_chunks)})'
+ f"Round-robin merged chunks: {origin_len} -> {len(merged_chunks)} (deduplicated {origin_len - len(merged_chunks)})"
)
return merged_chunks
@@ -4628,18 +3869,18 @@ async def _build_context_str(
merged_chunks: list[dict],
query: str,
query_param: QueryParam,
- global_config: dict[str, Any],
- chunk_tracking: dict | None = None,
- entity_id_to_original: dict | None = None,
- relation_id_to_original: dict | None = None,
+ global_config: dict[str, str],
+ chunk_tracking: dict = None,
+ entity_id_to_original: dict = None,
+ relation_id_to_original: dict = None,
) -> tuple[str, dict[str, Any]]:
"""
Build the final LLM context string with token processing.
This includes dynamic token calculation and final chunk truncation.
"""
- tokenizer = global_config.get('tokenizer')
+ tokenizer = global_config.get("tokenizer")
if not tokenizer:
- logger.error('Missing tokenizer, cannot build LLM context')
+ logger.error("Missing tokenizer, cannot build LLM context")
# Return empty raw data structure when no tokenizer
empty_raw_data = convert_to_user_format(
[],
@@ -4648,49 +3889,64 @@ async def _build_context_str(
[],
query_param.mode,
)
- empty_raw_data['status'] = 'failure'
- empty_raw_data['message'] = 'Missing tokenizer, cannot build LLM context.'
- return '', empty_raw_data
+ empty_raw_data["status"] = "failure"
+ empty_raw_data["message"] = "Missing tokenizer, cannot build LLM context."
+ return "", empty_raw_data
# Get token limits
max_total_tokens = getattr(
query_param,
- 'max_total_tokens',
- global_config.get('max_total_tokens', DEFAULT_MAX_TOTAL_TOKENS),
+ "max_total_tokens",
+ global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS),
)
# Get the system prompt template from PROMPTS or global_config
- sys_prompt_template = global_config.get('system_prompt_template', PROMPTS['rag_response'])
+ sys_prompt_template = global_config.get(
+ "system_prompt_template", PROMPTS["rag_response"]
+ )
- kg_context_template = PROMPTS['kg_query_context']
- user_prompt = query_param.user_prompt if query_param.user_prompt else ''
- response_type = query_param.response_type if query_param.response_type else 'Multiple Paragraphs'
+ kg_context_template = PROMPTS["kg_query_context"]
+ user_prompt = query_param.user_prompt if query_param.user_prompt else ""
+ response_type = (
+ query_param.response_type
+ if query_param.response_type
+ else "Multiple Paragraphs"
+ )
- entities_str = '\n'.join(json.dumps(entity, ensure_ascii=False) for entity in entities_context)
- relations_str = '\n'.join(json.dumps(relation, ensure_ascii=False) for relation in relations_context)
+ entities_str = "\n".join(
+ json.dumps(entity, ensure_ascii=False) for entity in entities_context
+ )
+ relations_str = "\n".join(
+ json.dumps(relation, ensure_ascii=False) for relation in relations_context
+ )
# Calculate preliminary kg context tokens
pre_kg_context = kg_context_template.format(
entities_str=entities_str,
relations_str=relations_str,
- text_chunks_str='',
- reference_list_str='',
+ text_chunks_str="",
+ reference_list_str="",
)
kg_context_tokens = len(tokenizer.encode(pre_kg_context))
# Calculate preliminary system prompt tokens
pre_sys_prompt = sys_prompt_template.format(
- context_data='', # Empty for overhead calculation
+ context_data="", # Empty for overhead calculation
response_type=response_type,
user_prompt=user_prompt,
- coverage_guidance='', # Empty for overhead calculation
)
sys_prompt_tokens = len(tokenizer.encode(pre_sys_prompt))
# Calculate available tokens for text chunks
query_tokens = len(tokenizer.encode(query))
buffer_tokens = 200 # reserved for reference list and safety buffer
- available_chunk_tokens = max_total_tokens - (sys_prompt_tokens + kg_context_tokens + query_tokens + buffer_tokens)
+ available_chunk_tokens = max_total_tokens - (
+ sys_prompt_tokens + kg_context_tokens + query_tokens + buffer_tokens
+ )
+
+ logger.debug(
+ f"Token allocation - Total: {max_total_tokens}, SysPrompt: {sys_prompt_tokens}, Query: {query_tokens}, KG: {kg_context_tokens}, Buffer: {buffer_tokens}, Available for chunks: {available_chunk_tokens}"
+ )
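
The chunk budget computed here is simply the total token limit minus the fixed overheads measured above; with made-up numbers (these are not the library defaults), the arithmetic looks like this:

# Hypothetical figures, for illustration only
max_total_tokens = 32_000
sys_prompt_tokens = 1_800   # system prompt template with empty context_data
kg_context_tokens = 6_500   # entities_str + relations_str rendered into the KG template
query_tokens = 40
buffer_tokens = 200         # reserved for the reference list and safety margin

available_chunk_tokens = max_total_tokens - (
    sys_prompt_tokens + kg_context_tokens + query_tokens + buffer_tokens
)
print(available_chunk_tokens)  # 23460
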
# Apply token truncation to chunks using the dynamic limit
truncated_chunks = await process_chunks_unified(
@@ -4703,26 +3959,32 @@ async def _build_context_str(
)
# Generate reference list from truncated chunks using the new common function
- reference_list, truncated_chunks = generate_reference_list_from_chunks(truncated_chunks)
+ reference_list, truncated_chunks = generate_reference_list_from_chunks(
+ truncated_chunks
+ )
# Rebuild chunks_context with truncated chunks
# The actual tokens may be slightly less than available_chunk_tokens due to deduplication logic
chunks_context = []
- for _i, chunk in enumerate(truncated_chunks):
+ for i, chunk in enumerate(truncated_chunks):
chunks_context.append(
{
- 'reference_id': chunk['reference_id'],
- 'content': chunk['content'],
+ "reference_id": chunk["reference_id"],
+ "content": chunk["content"],
}
)
- text_units_str = '\n'.join(json.dumps(text_unit, ensure_ascii=False) for text_unit in chunks_context)
- reference_list_str = '\n'.join(
- f'[{ref["reference_id"]}] {ref["file_path"]}' for ref in reference_list if ref['reference_id']
+ text_units_str = "\n".join(
+ json.dumps(text_unit, ensure_ascii=False) for text_unit in chunks_context
+ )
+ reference_list_str = "\n".join(
+ f"[{ref['reference_id']}] {ref['file_path']}"
+ for ref in reference_list
+ if ref["reference_id"]
)
logger.info(
- f'Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(chunks_context)} chunks'
+ f"Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(chunks_context)} chunks"
)
# No need to use the LLM to generate a response
@@ -4735,27 +3997,27 @@ async def _build_context_str(
[],
query_param.mode,
)
- empty_raw_data['status'] = 'failure'
- empty_raw_data['message'] = 'Query returned empty dataset.'
- return '', empty_raw_data
+ empty_raw_data["status"] = "failure"
+ empty_raw_data["message"] = "Query returned empty dataset."
+ return "", empty_raw_data
# output chunk tracking information
# format: <source><frequency>/<order> (e.g., E5/2 R2/1 C1/1)
if truncated_chunks and chunk_tracking:
chunk_tracking_log = []
for chunk in truncated_chunks:
- chunk_id = chunk.get('chunk_id')
+ chunk_id = chunk.get("chunk_id")
if chunk_id and chunk_id in chunk_tracking:
tracking_info = chunk_tracking[chunk_id]
- source = tracking_info['source']
- frequency = tracking_info['frequency']
- order = tracking_info['order']
- chunk_tracking_log.append(f'{source}{frequency}/{order}')
+ source = tracking_info["source"]
+ frequency = tracking_info["frequency"]
+ order = tracking_info["order"]
+ chunk_tracking_log.append(f"{source}{frequency}/{order}")
else:
- chunk_tracking_log.append('?0/0')
+ chunk_tracking_log.append("?0/0")
if chunk_tracking_log:
- logger.info(f'Final chunks S+F/O: {" ".join(chunk_tracking_log)}')
+ logger.info(f"Final chunks S+F/O: {' '.join(chunk_tracking_log)}")
result = kg_context_template.format(
entities_str=entities_str,
@@ -4766,7 +4028,7 @@ async def _build_context_str(
# Always return both context and complete data structure (unified approach)
logger.debug(
- f'[_build_context_str] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks'
+ f"[_build_context_str] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks"
)
final_data = convert_to_user_format(
entities_context,
@@ -4778,7 +4040,7 @@ async def _build_context_str(
relation_id_to_original,
)
logger.debug(
- f'[_build_context_str] Final data after conversion: {len(final_data.get("entities", []))} entities, {len(final_data.get("relationships", []))} relationships, {len(final_data.get("chunks", []))} chunks'
+ f"[_build_context_str] Final data after conversion: {len(final_data.get('entities', []))} entities, {len(final_data.get('relationships', []))} relationships, {len(final_data.get('chunks', []))} chunks"
)
return result, final_data
@@ -4793,7 +4055,7 @@ async def _build_query_context(
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
- chunks_vdb: BaseVectorStorage | None = None,
+ chunks_vdb: BaseVectorStorage = None,
) -> QueryContextResult | None:
"""
Main query context building function using the new 4-stage architecture:
@@ -4803,7 +4065,7 @@ async def _build_query_context(
"""
if not query:
- logger.warning('Query is empty, skipping context building')
+ logger.warning("Query is empty, skipping context building")
return None
# Stage 1: Pure search
@@ -4819,11 +4081,11 @@ async def _build_query_context(
chunks_vdb,
)
- if not search_result['final_entities'] and not search_result['final_relations']:
- if query_param.mode != 'mix':
+ if not search_result["final_entities"] and not search_result["final_relations"]:
+ if query_param.mode != "mix":
return None
else:
- if not search_result['chunk_tracking']:
+ if not search_result["chunk_tracking"]:
return None
# Stage 2: Apply token truncation for LLM efficiency
@@ -4833,91 +4095,75 @@ async def _build_query_context(
text_chunks_db.global_config,
)
- # Stage 2.5: Check topic connectivity
- # Skip for: naive (no graph), bypass (no retrieval)
- # This prevents hallucination when querying unrelated topics like "diabetes + renewable energy"
- if query_param.check_topic_connectivity and query_param.mode not in ('naive', 'bypass'):
- is_connected, reason = _check_topic_connectivity(
- truncation_result['entities_context'],
- truncation_result['relations_context'],
- min_relationship_density=query_param.min_relationship_density,
- min_entity_coverage=query_param.min_entity_coverage,
- )
- if not is_connected:
- logger.info(f'Topic connectivity check failed: {reason}')
- return None # Return None to trigger "no context" response
-
# Stage 3: Merge chunks using filtered entities/relations
merged_chunks = await _merge_all_chunks(
- filtered_entities=truncation_result['filtered_entities'],
- filtered_relations=truncation_result['filtered_relations'],
- vector_chunks=search_result['vector_chunks'],
+ filtered_entities=truncation_result["filtered_entities"],
+ filtered_relations=truncation_result["filtered_relations"],
+ vector_chunks=search_result["vector_chunks"],
query=query,
knowledge_graph_inst=knowledge_graph_inst,
text_chunks_db=text_chunks_db,
query_param=query_param,
chunks_vdb=chunks_vdb,
- chunk_tracking=search_result['chunk_tracking'],
- query_embedding=search_result['query_embedding'],
+ chunk_tracking=search_result["chunk_tracking"],
+ query_embedding=search_result["query_embedding"],
)
- if not merged_chunks and not truncation_result['entities_context'] and not truncation_result['relations_context']:
+ if (
+ not merged_chunks
+ and not truncation_result["entities_context"]
+ and not truncation_result["relations_context"]
+ ):
return None
# Stage 4: Build final LLM context with dynamic token processing
# _build_context_str now always returns tuple[str, dict]
context, raw_data = await _build_context_str(
- entities_context=truncation_result['entities_context'],
- relations_context=truncation_result['relations_context'],
+ entities_context=truncation_result["entities_context"],
+ relations_context=truncation_result["relations_context"],
merged_chunks=merged_chunks,
query=query,
query_param=query_param,
global_config=text_chunks_db.global_config,
- chunk_tracking=search_result['chunk_tracking'],
- entity_id_to_original=truncation_result['entity_id_to_original'],
- relation_id_to_original=truncation_result['relation_id_to_original'],
+ chunk_tracking=search_result["chunk_tracking"],
+ entity_id_to_original=truncation_result["entity_id_to_original"],
+ relation_id_to_original=truncation_result["relation_id_to_original"],
)
# Convert keywords strings to lists and add complete metadata to raw_data
- hl_keywords_list = hl_keywords.split(', ') if hl_keywords else []
- ll_keywords_list = ll_keywords.split(', ') if ll_keywords else []
+ hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
+ ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []
# Add complete metadata to raw_data (preserve existing metadata including query_mode)
- if 'metadata' not in raw_data:
- raw_data['metadata'] = {}
+ if "metadata" not in raw_data:
+ raw_data["metadata"] = {}
# Update keywords while preserving existing metadata
- raw_data['metadata']['keywords'] = {
- 'high_level': hl_keywords_list,
- 'low_level': ll_keywords_list,
+ raw_data["metadata"]["keywords"] = {
+ "high_level": hl_keywords_list,
+ "low_level": ll_keywords_list,
}
- raw_data['metadata']['processing_info'] = {
- 'total_entities_found': len(search_result.get('final_entities', [])),
- 'total_relations_found': len(search_result.get('final_relations', [])),
- 'entities_after_truncation': len(truncation_result.get('filtered_entities', [])),
- 'relations_after_truncation': len(truncation_result.get('filtered_relations', [])),
- 'merged_chunks_count': len(merged_chunks),
- 'final_chunks_count': len(raw_data.get('data', {}).get('chunks', [])),
+ raw_data["metadata"]["processing_info"] = {
+ "total_entities_found": len(search_result.get("final_entities", [])),
+ "total_relations_found": len(search_result.get("final_relations", [])),
+ "entities_after_truncation": len(
+ truncation_result.get("filtered_entities", [])
+ ),
+ "relations_after_truncation": len(
+ truncation_result.get("filtered_relations", [])
+ ),
+ "merged_chunks_count": len(merged_chunks),
+ "final_chunks_count": len(raw_data.get("data", {}).get("chunks", [])),
}
- # Calculate context sparsity to guide LLM response
- # Sparse context = LLM should be more conservative about inferences
- entity_count = len(truncation_result.get('filtered_entities', []))
- relation_count = len(truncation_result.get('filtered_relations', []))
- chunk_count = len(raw_data.get('data', {}).get('chunks', []))
- is_sparse = entity_count < 3 or relation_count < 2 or chunk_count < 2
- coverage_level = 'limited' if is_sparse else 'good'
-
logger.debug(
- f'[_build_query_context] Context sparsity: entities={entity_count}, relations={relation_count}, chunks={chunk_count} -> coverage={coverage_level}'
+ f"[_build_query_context] Context length: {len(context) if context else 0}"
+ )
+ logger.debug(
+ f"[_build_query_context] Raw data entities: {len(raw_data.get('data', {}).get('entities', []))}, relationships: {len(raw_data.get('data', {}).get('relationships', []))}, chunks: {len(raw_data.get('data', {}).get('chunks', []))}"
)
- logger.debug(f'[_build_query_context] Context length: {len(context) if context else 0}')
- logger.debug(
- f'[_build_query_context] Raw data entities: {len(raw_data.get("data", {}).get("entities", []))}, relationships: {len(raw_data.get("data", {}).get("relationships", []))}, chunks: {len(raw_data.get("data", {}).get("chunks", []))}'
- )
-
- return QueryContextResult(context=context, raw_data=raw_data, coverage_level=coverage_level)
+ return QueryContextResult(context=context, raw_data=raw_data)
async def _get_node_data(
@@ -4927,7 +4173,9 @@ async def _get_node_data(
query_param: QueryParam,
):
# get similar entities
- logger.info(f'Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})')
+ logger.info(
+ f"Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})"
+ )
results = await entities_vdb.query(query, top_k=query_param.top_k)
@@ -4935,7 +4183,7 @@ async def _get_node_data(
return [], []
# Extract all entity IDs from your results list
- node_ids = [r['entity_name'] for r in results]
+ node_ids = [r["entity_name"] for r in results]
# Call the batch node retrieval and degree functions concurrently.
nodes_dict, degrees_dict = await asyncio.gather(
@@ -4947,17 +4195,17 @@ async def _get_node_data(
node_datas = [nodes_dict.get(nid) for nid in node_ids]
node_degrees = [degrees_dict.get(nid, 0) for nid in node_ids]
- if not all(n is not None for n in node_datas):
- logger.warning('Some nodes are missing, maybe the storage is damaged')
+ if not all([n is not None for n in node_datas]):
+ logger.warning("Some nodes are missing, maybe the storage is damaged")
node_datas = [
{
**n,
- 'entity_name': k['entity_name'],
- 'rank': d,
- 'created_at': k.get('created_at'),
+ "entity_name": k["entity_name"],
+ "rank": d,
+ "created_at": k.get("created_at"),
}
- for k, n, d in zip(results, node_datas, node_degrees, strict=False)
+ for k, n, d in zip(results, node_datas, node_degrees)
if n is not None
]
@@ -4967,7 +4215,9 @@ async def _get_node_data(
knowledge_graph_inst,
)
- logger.info(f'Local query: {len(node_datas)} entites, {len(use_relations)} relations')
+ logger.info(
+ f"Local query: {len(node_datas)} entities, {len(use_relations)} relations"
+ )
# Entities are sorted by cosine similarity
# Relations are sorted by rank + weight
@@ -4979,7 +4229,7 @@ async def _find_most_related_edges_from_entities(
query_param: QueryParam,
knowledge_graph_inst: BaseGraphStorage,
):
- node_names = [dp['entity_name'] for dp in node_datas]
+ node_names = [dp["entity_name"] for dp in node_datas]
batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
all_edges = []
@@ -4995,7 +4245,7 @@ async def _find_most_related_edges_from_entities(
# Prepare edge pairs in two forms:
# For the batch edge properties function, use dicts.
- edge_pairs_dicts = [{'src': e[0], 'tgt': e[1]} for e in all_edges]
+ edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges]
# For edge degrees, use tuples.
edge_pairs_tuples = list(all_edges) # all_edges is already a list of tuples
@@ -5010,18 +4260,22 @@ async def _find_most_related_edges_from_entities(
for pair in all_edges:
edge_props = edge_data_dict.get(pair)
if edge_props is not None:
- if 'weight' not in edge_props:
- logger.warning(f"Edge {pair} missing 'weight' attribute, using default value 1.0")
- edge_props['weight'] = 1.0
+ if "weight" not in edge_props:
+ logger.warning(
+ f"Edge {pair} missing 'weight' attribute, using default value 1.0"
+ )
+ edge_props["weight"] = 1.0
combined = {
- 'src_tgt': pair,
- 'rank': edge_degrees_dict.get(pair, 0),
+ "src_tgt": pair,
+ "rank": edge_degrees_dict.get(pair, 0),
**edge_props,
}
all_edges_data.append(combined)
- all_edges_data = sorted(all_edges_data, key=lambda x: (x['rank'], x['weight']), reverse=True)
+ all_edges_data = sorted(
+ all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
+ )
return all_edges_data
@@ -5031,9 +4285,9 @@ async def _find_related_text_unit_from_entities(
query_param: QueryParam,
text_chunks_db: BaseKVStorage,
knowledge_graph_inst: BaseGraphStorage,
- query: str | None = None,
- chunks_vdb: BaseVectorStorage | None = None,
- chunk_tracking: dict | None = None,
+ query: str = None,
+ chunks_vdb: BaseVectorStorage = None,
+ chunk_tracking: dict = None,
query_embedding=None,
):
"""
@@ -5043,7 +4297,7 @@ async def _find_related_text_unit_from_entities(
1. WEIGHT: Linear gradient weighted polling based on chunk occurrence count
2. VECTOR: Vector similarity-based selection using embedding cosine similarity
"""
- logger.debug(f'Finding text chunks from {len(node_datas)} entities')
+ logger.debug(f"Finding text chunks from {len(node_datas)} entities")
if not node_datas:
return []
@@ -5051,30 +4305,38 @@ async def _find_related_text_unit_from_entities(
# Step 1: Collect all text chunks for each entity
entities_with_chunks = []
for entity in node_datas:
- if entity.get('source_id'):
- chunks = split_string_by_multi_markers(entity['source_id'], [GRAPH_FIELD_SEP])
+ if entity.get("source_id"):
+ chunks = split_string_by_multi_markers(
+ entity["source_id"], [GRAPH_FIELD_SEP]
+ )
if chunks:
entities_with_chunks.append(
{
- 'entity_name': entity['entity_name'],
- 'chunks': chunks,
- 'entity_data': entity,
+ "entity_name": entity["entity_name"],
+ "chunks": chunks,
+ "entity_data": entity,
}
)
if not entities_with_chunks:
- logger.warning('No entities with text chunks found')
+ logger.warning("No entities with text chunks found")
return []
- kg_chunk_pick_method = text_chunks_db.global_config.get('kg_chunk_pick_method', DEFAULT_KG_CHUNK_PICK_METHOD)
- max_related_chunks = text_chunks_db.global_config.get('related_chunk_number', DEFAULT_RELATED_CHUNK_NUMBER)
+ kg_chunk_pick_method = text_chunks_db.global_config.get(
+ "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD
+ )
+ max_related_chunks = text_chunks_db.global_config.get(
+ "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
+ )
# Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned entities)
chunk_occurrence_count = {}
for entity_info in entities_with_chunks:
deduplicated_chunks = []
- for chunk_id in entity_info['chunks']:
- chunk_occurrence_count[chunk_id] = chunk_occurrence_count.get(chunk_id, 0) + 1
+ for chunk_id in entity_info["chunks"]:
+ chunk_occurrence_count[chunk_id] = (
+ chunk_occurrence_count.get(chunk_id, 0) + 1
+ )
# If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position)
if chunk_occurrence_count[chunk_id] == 1:
@@ -5082,17 +4344,17 @@ async def _find_related_text_unit_from_entities(
# count > 1 means this chunk appeared in an earlier entity, so skip it
# Update entity's chunks to deduplicated chunks
- entity_info['chunks'] = deduplicated_chunks
+ entity_info["chunks"] = deduplicated_chunks
# Step 3: Sort chunks for each entity by occurrence count (higher count = higher priority)
total_entity_chunks = 0
for entity_info in entities_with_chunks:
sorted_chunks = sorted(
- entity_info['chunks'],
+ entity_info["chunks"],
key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0),
reverse=True,
)
- entity_info['sorted_chunks'] = sorted_chunks
+ entity_info["sorted_chunks"] = sorted_chunks
total_entity_chunks += len(sorted_chunks)
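
Steps 2 and 3 above amount to: count how many of the retrieved entities reference each chunk, keep every chunk only under the first entity that references it, then sort each entity's remaining chunks by that global frequency. A compact sketch with hypothetical chunk ids:

from collections import Counter

entities_with_chunks = [
    {"entity_name": "E1", "chunks": ["c1", "c2", "c3"]},
    {"entity_name": "E2", "chunks": ["c2", "c4"]},   # c2 already seen under E1
]

# Step 2: global occurrence count + first-occurrence deduplication
occurrence = Counter()
for info in entities_with_chunks:
    kept = []
    for chunk_id in info["chunks"]:
        occurrence[chunk_id] += 1
        if occurrence[chunk_id] == 1:   # first time this chunk appears
            kept.append(chunk_id)
    info["chunks"] = kept

# Step 3: within each entity, more frequently shared chunks come first
for info in entities_with_chunks:
    info["sorted_chunks"] = sorted(
        info["chunks"], key=lambda cid: occurrence[cid], reverse=True
    )

print([info["sorted_chunks"] for info in entities_with_chunks])
# -> [['c2', 'c1', 'c3'], ['c4']]
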
selected_chunk_ids = [] # Initialize to avoid UnboundLocalError
@@ -5101,14 +4363,14 @@ async def _find_related_text_unit_from_entities(
# Pick by vector similarity:
# The order of text chunks aligns with that of naive retrieval.
# When reranking is disabled, the text chunks delivered to the LLM tend to favor naive retrieval.
- if kg_chunk_pick_method == 'VECTOR' and query and chunks_vdb:
+ if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb:
num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2)
# Get embedding function from global config
actual_embedding_func = text_chunks_db.embedding_func
if not actual_embedding_func:
- logger.warning('No embedding function found, falling back to WEIGHT method')
- kg_chunk_pick_method = 'WEIGHT'
+ logger.warning("No embedding function found, falling back to WEIGHT method")
+ kg_chunk_pick_method = "WEIGHT"
else:
try:
selected_chunk_ids = await pick_by_vector_similarity(
@@ -5122,50 +4384,56 @@ async def _find_related_text_unit_from_entities(
)
if selected_chunk_ids == []:
- kg_chunk_pick_method = 'WEIGHT'
+ kg_chunk_pick_method = "WEIGHT"
logger.warning(
- 'No entity-related chunks selected by vector similarity, falling back to WEIGHT method'
+ "No entity-related chunks selected by vector similarity, falling back to WEIGHT method"
)
else:
logger.info(
- f'Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by vector similarity'
+ f"Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by vector similarity"
)
except Exception as e:
- logger.error(f'Error in vector similarity sorting: {e}, falling back to WEIGHT method')
- kg_chunk_pick_method = 'WEIGHT'
+ logger.error(
+ f"Error in vector similarity sorting: {e}, falling back to WEIGHT method"
+ )
+ kg_chunk_pick_method = "WEIGHT"
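
The VECTOR branch delegates the actual ranking to pick_by_vector_similarity, whose internals are not shown in this diff; conceptually it keeps the candidate chunks whose embeddings are closest to the query embedding by cosine similarity. A simplified sketch of that idea (not the real implementation or its signature):

import math

def cosine(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

def pick_top_k_by_similarity(query_vec, chunk_vecs: dict[str, list[float]], k: int) -> list[str]:
    """Rank candidate chunk ids by cosine similarity to the query embedding."""
    ranked = sorted(chunk_vecs, key=lambda cid: cosine(query_vec, chunk_vecs[cid]), reverse=True)
    return ranked[:k]

print(pick_top_k_by_similarity(
    [1.0, 0.0],
    {"c1": [0.9, 0.1], "c2": [0.0, 1.0], "c3": [0.7, 0.7]},
    k=2,
))
# -> ['c1', 'c3']
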
- if kg_chunk_pick_method == 'WEIGHT':
+ if kg_chunk_pick_method == "WEIGHT":
# Pick by entity and chunk weight:
# When reranking is disabled, deliver more purely KG-related chunks to the LLM
- selected_chunk_ids = pick_by_weighted_polling(entities_with_chunks, max_related_chunks, min_related_chunks=1)
+ selected_chunk_ids = pick_by_weighted_polling(
+ entities_with_chunks, max_related_chunks, min_related_chunks=1
+ )
logger.info(
- f'Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by weighted polling'
+ f"Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by weighted polling"
)
if not selected_chunk_ids:
return []
# Step 5: Batch retrieve chunk data
- unique_chunk_ids = list(dict.fromkeys(selected_chunk_ids)) # Remove duplicates while preserving order
+ unique_chunk_ids = list(
+ dict.fromkeys(selected_chunk_ids)
+ ) # Remove duplicates while preserving order
chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids)
# Step 6: Build result chunks with valid data and update chunk tracking
result_chunks = []
- for i, (chunk_id, chunk_data) in enumerate(zip(unique_chunk_ids, chunk_data_list, strict=False)):
- if chunk_data is not None and 'content' in chunk_data:
+ for i, (chunk_id, chunk_data) in enumerate(zip(unique_chunk_ids, chunk_data_list)):
+ if chunk_data is not None and "content" in chunk_data:
chunk_data_copy = chunk_data.copy()
- chunk_data_copy['source_type'] = 'entity'
- chunk_data_copy['chunk_id'] = chunk_id # Add chunk_id for deduplication
+ chunk_data_copy["source_type"] = "entity"
+ chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication
result_chunks.append(chunk_data_copy)
# Update chunk tracking if provided
if chunk_tracking is not None:
chunk_tracking[chunk_id] = {
- 'source': 'E',
- 'frequency': chunk_occurrence_count.get(chunk_id, 1),
- 'order': i + 1, # 1-based order in final entity-related results
+ "source": "E",
+ "frequency": chunk_occurrence_count.get(chunk_id, 1),
+ "order": i + 1, # 1-based order in final entity-related results
}
return result_chunks
@@ -5178,7 +4446,7 @@ async def _get_edge_data(
query_param: QueryParam,
):
logger.info(
- f'Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})'
+ f"Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})"
)
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
@@ -5188,24 +4456,26 @@ async def _get_edge_data(
# Prepare edge pairs in two forms:
# For the batch edge properties function, use dicts.
- edge_pairs_dicts = [{'src': r['src_id'], 'tgt': r['tgt_id']} for r in results]
+ edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results]
edge_data_dict = await knowledge_graph_inst.get_edges_batch(edge_pairs_dicts)
# Reconstruct edge_datas list in the same order as results.
edge_datas = []
for k in results:
- pair = (k['src_id'], k['tgt_id'])
+ pair = (k["src_id"], k["tgt_id"])
edge_props = edge_data_dict.get(pair)
if edge_props is not None:
- if 'weight' not in edge_props:
- logger.warning(f"Edge {pair} missing 'weight' attribute, using default value 1.0")
- edge_props['weight'] = 1.0
+ if "weight" not in edge_props:
+ logger.warning(
+ f"Edge {pair} missing 'weight' attribute, using default value 1.0"
+ )
+ edge_props["weight"] = 1.0
# Keep edge data without rank, maintain vector search order
combined = {
- 'src_id': k['src_id'],
- 'tgt_id': k['tgt_id'],
- 'created_at': k.get('created_at', None),
+ "src_id": k["src_id"],
+ "tgt_id": k["tgt_id"],
+ "created_at": k.get("created_at", None),
**edge_props,
}
edge_datas.append(combined)
@@ -5218,7 +4488,9 @@ async def _get_edge_data(
knowledge_graph_inst,
)
- logger.info(f'Global query: {len(use_entities)} entites, {len(edge_datas)} relations')
+ logger.info(
+ f"Global query: {len(use_entities)} entities, {len(edge_datas)} relations"
+ )
return edge_datas, use_entities
@@ -5232,12 +4504,12 @@ async def _find_most_related_entities_from_relationships(
seen = set()
for e in edge_datas:
- if e['src_id'] not in seen:
- entity_names.append(e['src_id'])
- seen.add(e['src_id'])
- if e['tgt_id'] not in seen:
- entity_names.append(e['tgt_id'])
- seen.add(e['tgt_id'])
+ if e["src_id"] not in seen:
+ entity_names.append(e["src_id"])
+ seen.add(e["src_id"])
+ if e["tgt_id"] not in seen:
+ entity_names.append(e["tgt_id"])
+ seen.add(e["tgt_id"])
# Only get nodes data, no need for node degrees
nodes_dict = await knowledge_graph_inst.get_nodes_batch(entity_names)
@@ -5250,7 +4522,7 @@ async def _find_most_related_entities_from_relationships(
logger.warning(f"Node '{entity_name}' not found in batch retrieval.")
continue
# Combine the node data with the entity name, no rank needed
- combined = {**node, 'entity_name': entity_name}
+ combined = {**node, "entity_name": entity_name}
node_datas.append(combined)
return node_datas
@@ -5260,10 +4532,10 @@ async def _find_related_text_unit_from_relations(
edge_datas: list[dict],
query_param: QueryParam,
text_chunks_db: BaseKVStorage,
- entity_chunks: list[dict] | None = None,
- query: str | None = None,
- chunks_vdb: BaseVectorStorage | None = None,
- chunk_tracking: dict | None = None,
+ entity_chunks: list[dict] = None,
+ query: str = None,
+ chunks_vdb: BaseVectorStorage = None,
+ chunk_tracking: dict = None,
query_embedding=None,
):
"""
@@ -5273,7 +4545,7 @@ async def _find_related_text_unit_from_relations(
1. WEIGHT: Linear gradient weighted polling based on chunk occurrence count
2. VECTOR: Vector similarity-based selection using embedding cosine similarity
"""
- logger.debug(f'Finding text chunks from {len(edge_datas)} relations')
+ logger.debug(f"Finding text chunks from {len(edge_datas)} relations")
if not edge_datas:
return []
@@ -5281,29 +4553,37 @@ async def _find_related_text_unit_from_relations(
# Step 1: Collect all text chunks for each relationship
relations_with_chunks = []
for relation in edge_datas:
- if relation.get('source_id'):
- chunks = split_string_by_multi_markers(relation['source_id'], [GRAPH_FIELD_SEP])
+ if relation.get("source_id"):
+ chunks = split_string_by_multi_markers(
+ relation["source_id"], [GRAPH_FIELD_SEP]
+ )
if chunks:
# Build relation identifier
- if 'src_tgt' in relation:
- rel_key = tuple(sorted(relation['src_tgt']))
+ if "src_tgt" in relation:
+ rel_key = tuple(sorted(relation["src_tgt"]))
else:
- rel_key = tuple(sorted([relation.get('src_id', ''), relation.get('tgt_id', '')]))
+ rel_key = tuple(
+ sorted([relation.get("src_id"), relation.get("tgt_id")])
+ )
relations_with_chunks.append(
{
- 'relation_key': rel_key,
- 'chunks': chunks,
- 'relation_data': relation,
+ "relation_key": rel_key,
+ "chunks": chunks,
+ "relation_data": relation,
}
)
if not relations_with_chunks:
- logger.warning('No relation-related chunks found')
+ logger.warning("No relation-related chunks found")
return []
- kg_chunk_pick_method = text_chunks_db.global_config.get('kg_chunk_pick_method', DEFAULT_KG_CHUNK_PICK_METHOD)
- max_related_chunks = text_chunks_db.global_config.get('related_chunk_number', DEFAULT_RELATED_CHUNK_NUMBER)
+ kg_chunk_pick_method = text_chunks_db.global_config.get(
+ "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD
+ )
+ max_related_chunks = text_chunks_db.global_config.get(
+ "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
+ )
# Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned relationships)
# Also remove duplicates with entity_chunks
@@ -5312,7 +4592,7 @@ async def _find_related_text_unit_from_relations(
entity_chunk_ids = set()
if entity_chunks:
for chunk in entity_chunks:
- chunk_id = chunk.get('chunk_id')
+ chunk_id = chunk.get("chunk_id")
if chunk_id:
entity_chunk_ids.add(chunk_id)
@@ -5322,14 +4602,16 @@ async def _find_related_text_unit_from_relations(
for relation_info in relations_with_chunks:
deduplicated_chunks = []
- for chunk_id in relation_info['chunks']:
+ for chunk_id in relation_info["chunks"]:
# Skip chunks that already exist in entity_chunks
if chunk_id in entity_chunk_ids:
# Only count each unique chunk_id once
removed_entity_chunk_ids.add(chunk_id)
continue
- chunk_occurrence_count[chunk_id] = chunk_occurrence_count.get(chunk_id, 0) + 1
+ chunk_occurrence_count[chunk_id] = (
+ chunk_occurrence_count.get(chunk_id, 0) + 1
+ )
# If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position)
if chunk_occurrence_count[chunk_id] == 1:
@@ -5337,41 +4619,47 @@ async def _find_related_text_unit_from_relations(
# count > 1 means this chunk appeared in an earlier relationship, so skip it
# Update relationship's chunks to deduplicated chunks
- relation_info['chunks'] = deduplicated_chunks
+ relation_info["chunks"] = deduplicated_chunks
# Check if any relations still have chunks after deduplication
- relations_with_chunks = [relation_info for relation_info in relations_with_chunks if relation_info['chunks']]
+ relations_with_chunks = [
+ relation_info
+ for relation_info in relations_with_chunks
+ if relation_info["chunks"]
+ ]
if not relations_with_chunks:
- logger.info(f'Find no additional relations-related chunks from {len(edge_datas)} relations')
+ logger.info(
+ f"Found no additional relation-related chunks from {len(edge_datas)} relations"
+ )
return []
# Step 3: Sort chunks for each relationship by occurrence count (higher count = higher priority)
total_relation_chunks = 0
for relation_info in relations_with_chunks:
sorted_chunks = sorted(
- relation_info['chunks'],
+ relation_info["chunks"],
key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0),
reverse=True,
)
- relation_info['sorted_chunks'] = sorted_chunks
+ relation_info["sorted_chunks"] = sorted_chunks
total_relation_chunks += len(sorted_chunks)
logger.info(
- f'Find {total_relation_chunks} additional chunks in {len(relations_with_chunks)} relations (deduplicated {len(removed_entity_chunk_ids)})'
+ f"Found {total_relation_chunks} additional chunks in {len(relations_with_chunks)} relations (deduplicated {len(removed_entity_chunk_ids)})"
)
# Step 4: Apply the selected chunk selection algorithm
selected_chunk_ids = [] # Initialize to avoid UnboundLocalError
- if kg_chunk_pick_method == 'VECTOR' and query and chunks_vdb:
+ if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb:
num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2)
# Get embedding function from global config
actual_embedding_func = text_chunks_db.embedding_func
if not actual_embedding_func:
- logger.warning('No embedding function found, falling back to WEIGHT method')
- kg_chunk_pick_method = 'WEIGHT'
+ logger.warning("No embedding function found, falling back to WEIGHT method")
+ kg_chunk_pick_method = "WEIGHT"
else:
try:
selected_chunk_ids = await pick_by_vector_similarity(
@@ -5385,63 +4673,93 @@ async def _find_related_text_unit_from_relations(
)
if selected_chunk_ids == []:
- kg_chunk_pick_method = 'WEIGHT'
+ kg_chunk_pick_method = "WEIGHT"
logger.warning(
- 'No relation-related chunks selected by vector similarity, falling back to WEIGHT method'
+ "No relation-related chunks selected by vector similarity, falling back to WEIGHT method"
)
else:
logger.info(
- f'Selecting {len(selected_chunk_ids)} from {total_relation_chunks} relation-related chunks by vector similarity'
+ f"Selecting {len(selected_chunk_ids)} from {total_relation_chunks} relation-related chunks by vector similarity"
)
except Exception as e:
- logger.error(f'Error in vector similarity sorting: {e}, falling back to WEIGHT method')
- kg_chunk_pick_method = 'WEIGHT'
+ logger.error(
+ f"Error in vector similarity sorting: {e}, falling back to WEIGHT method"
+ )
+ kg_chunk_pick_method = "WEIGHT"
- if kg_chunk_pick_method == 'WEIGHT':
+ if kg_chunk_pick_method == "WEIGHT":
# Apply linear gradient weighted polling algorithm
- selected_chunk_ids = pick_by_weighted_polling(relations_with_chunks, max_related_chunks, min_related_chunks=1)
+ selected_chunk_ids = pick_by_weighted_polling(
+ relations_with_chunks, max_related_chunks, min_related_chunks=1
+ )
logger.info(
- f'Selecting {len(selected_chunk_ids)} from {total_relation_chunks} relation-related chunks by weighted polling'
+ f"Selecting {len(selected_chunk_ids)} from {total_relation_chunks} relation-related chunks by weighted polling"
)
logger.debug(
- f'KG related chunks: {len(entity_chunks or [])} from entitys, {len(selected_chunk_ids)} from relations'
+ f"KG related chunks: {len(entity_chunks)} from entities, {len(selected_chunk_ids)} from relations"
)
if not selected_chunk_ids:
return []
# Step 5: Batch retrieve chunk data
- unique_chunk_ids = list(dict.fromkeys(selected_chunk_ids)) # Remove duplicates while preserving order
+ unique_chunk_ids = list(
+ dict.fromkeys(selected_chunk_ids)
+ ) # Remove duplicates while preserving order
chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids)
# Step 6: Build result chunks with valid data and update chunk tracking
result_chunks = []
- for i, (chunk_id, chunk_data) in enumerate(zip(unique_chunk_ids, chunk_data_list, strict=False)):
- if chunk_data is not None and 'content' in chunk_data:
+ for i, (chunk_id, chunk_data) in enumerate(zip(unique_chunk_ids, chunk_data_list)):
+ if chunk_data is not None and "content" in chunk_data:
chunk_data_copy = chunk_data.copy()
- chunk_data_copy['source_type'] = 'relationship'
- chunk_data_copy['chunk_id'] = chunk_id # Add chunk_id for deduplication
+ chunk_data_copy["source_type"] = "relationship"
+ chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication
result_chunks.append(chunk_data_copy)
# Update chunk tracking if provided
if chunk_tracking is not None:
chunk_tracking[chunk_id] = {
- 'source': 'R',
- 'frequency': chunk_occurrence_count.get(chunk_id, 1),
- 'order': i + 1, # 1-based order in final relation-related results
+ "source": "R",
+ "frequency": chunk_occurrence_count.get(chunk_id, 1),
+ "order": i + 1, # 1-based order in final relation-related results
}
return result_chunks
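
A minimal, self-contained sketch of the selection cascade the hunk above implements: try vector similarity first, fall back to weighted polling when it raises or returns nothing, then deduplicate while preserving order. The two picker functions here are simplified stand-ins, not LightRAG's `pick_by_vector_similarity` / `pick_by_weighted_polling` (whose full signatures are not shown in this diff).

```python
import asyncio


async def _vector_pick_stub(query: str, chunk_ids: list[str]) -> list[str]:
    # Stand-in for pick_by_vector_similarity: pretend nothing was selected.
    return []


def _weighted_pick_stub(chunk_ids: list[str], limit: int) -> list[str]:
    # Stand-in for pick_by_weighted_polling: just take the head of the list.
    return chunk_ids[:limit]


async def select_chunks(query: str, chunk_ids: list[str], limit: int) -> list[str]:
    method = "VECTOR"
    selected: list[str] = []
    try:
        selected = await _vector_pick_stub(query, chunk_ids)
    except Exception:
        selected = []  # on error, fall back exactly as the hunk above logs
    if not selected:
        method = "WEIGHT"  # also fall back when vector selection returns nothing
    if method == "WEIGHT":
        selected = _weighted_pick_stub(chunk_ids, limit)
    return list(dict.fromkeys(selected))  # dedupe while preserving order


print(asyncio.run(select_chunks("q", ["c1", "c2", "c2", "c3"], 4)))  # ['c1', 'c2', 'c3']
```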
+@overload
+async def naive_query(
+ query: str,
+ chunks_vdb: BaseVectorStorage,
+ query_param: QueryParam,
+ global_config: dict[str, str],
+ hashing_kv: BaseKVStorage | None = None,
+ system_prompt: str | None = None,
+ return_raw_data: Literal[True] = True,
+) -> dict[str, Any]: ...
+
+
+@overload
+async def naive_query(
+ query: str,
+ chunks_vdb: BaseVectorStorage,
+ query_param: QueryParam,
+ global_config: dict[str, str],
+ hashing_kv: BaseKVStorage | None = None,
+ system_prompt: str | None = None,
+ return_raw_data: Literal[False] = False,
+) -> str | AsyncIterator[str]: ...
+
+
async def naive_query(
query: str,
chunks_vdb: BaseVectorStorage,
query_param: QueryParam,
- global_config: dict[str, Any],
+ global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
) -> QueryResult | None:
@@ -5467,64 +4785,65 @@ async def naive_query(
"""
if not query:
- return QueryResult(content=PROMPTS['fail_response'])
+ return QueryResult(content=PROMPTS["fail_response"])
if query_param.model_func:
use_model_func = query_param.model_func
else:
- use_model_func = global_config['llm_model_func']
+ use_model_func = global_config["llm_model_func"]
# Apply higher priority (5) to query relation LLM function
use_model_func = partial(use_model_func, _priority=5)
- llm_callable = cast(Callable[..., Awaitable[str | AsyncIterator[str]]], use_model_func)
- tokenizer: Tokenizer = global_config['tokenizer']
+ tokenizer: Tokenizer = global_config["tokenizer"]
if not tokenizer:
- logger.error('Tokenizer not found in global configuration.')
- return QueryResult(content=PROMPTS['fail_response'])
+ logger.error("Tokenizer not found in global configuration.")
+ return QueryResult(content=PROMPTS["fail_response"])
chunks = await _get_vector_context(query, chunks_vdb, query_param, None)
if chunks is None or len(chunks) == 0:
- logger.info('[naive_query] No relevant document chunks found; returning no-result.')
+ logger.info(
+ "[naive_query] No relevant document chunks found; returning no-result."
+ )
return None
# Calculate dynamic token limit for chunks
max_total_tokens = getattr(
query_param,
- 'max_total_tokens',
- global_config.get('max_total_tokens', DEFAULT_MAX_TOTAL_TOKENS),
+ "max_total_tokens",
+ global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS),
)
# Calculate system prompt template tokens (excluding content_data)
- user_prompt = f'\n\n{query_param.user_prompt}' if query_param.user_prompt else 'n/a'
- response_type = query_param.response_type if query_param.response_type else 'Multiple Paragraphs'
+ user_prompt = f"\n\n{query_param.user_prompt}" if query_param.user_prompt else "n/a"
+ response_type = (
+ query_param.response_type
+ if query_param.response_type
+ else "Multiple Paragraphs"
+ )
# Use the provided system prompt or default
- sys_prompt_template = system_prompt if system_prompt else PROMPTS['naive_rag_response']
-
- # Detect sparse context for coverage guidance (adapted from KG mode)
- chunk_count = len(chunks) if chunks else 0
- is_sparse = chunk_count < 5 # Slightly higher threshold for naive mode
- coverage_guidance = (
- PROMPTS['coverage_guidance_limited'] if is_sparse else PROMPTS['coverage_guidance_good']
+ sys_prompt_template = (
+ system_prompt if system_prompt else PROMPTS["naive_rag_response"]
)
# Create a preliminary system prompt with empty content_data to calculate overhead
pre_sys_prompt = sys_prompt_template.format(
response_type=response_type,
user_prompt=user_prompt,
- content_data='', # Empty for overhead calculation
- coverage_guidance=coverage_guidance,
+ content_data="", # Empty for overhead calculation
)
# Calculate available tokens for chunks
sys_prompt_tokens = len(tokenizer.encode(pre_sys_prompt))
query_tokens = len(tokenizer.encode(query))
buffer_tokens = 200 # reserved for reference list and safety buffer
- available_chunk_tokens = max_total_tokens - (sys_prompt_tokens + query_tokens + buffer_tokens)
+ available_chunk_tokens = max_total_tokens - (
+ sys_prompt_tokens + query_tokens + buffer_tokens
+ )
logger.debug(
- f'Naive query token allocation - Total: {max_total_tokens}, SysPrompt: {sys_prompt_tokens}, Query: {query_tokens}, Buffer: {buffer_tokens}, Available for chunks: {available_chunk_tokens}'
+ f"Naive query token allocation - Total: {max_total_tokens}, SysPrompt: {sys_prompt_tokens}, Query: {query_tokens}, Buffer: {buffer_tokens}, Available for chunks: {available_chunk_tokens}"
)
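
The chunk budget in this hunk is plain subtraction; a tiny worked example with invented figures (these are illustrative numbers, not the defaults from `constants.py`):

```python
# Invented figures for illustration only; not LightRAG defaults.
max_total_tokens = 30_000      # overall prompt budget
sys_prompt_tokens = 1_500      # template rendered with empty content_data
query_tokens = 40              # tokens in the user query
buffer_tokens = 200            # reserved for reference list and safety margin

available_chunk_tokens = max_total_tokens - (
    sys_prompt_tokens + query_tokens + buffer_tokens
)
print(available_chunk_tokens)  # 28260
```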
# Process chunks using unified processing with dynamic token limit
@@ -5533,14 +4852,16 @@ async def naive_query(
unique_chunks=chunks,
query_param=query_param,
global_config=global_config,
- source_type='vector',
+ source_type="vector",
chunk_token_limit=available_chunk_tokens, # Pass dynamic limit
)
# Generate reference list from processed chunks using the new common function
- reference_list, processed_chunks_with_ref_ids = generate_reference_list_from_chunks(processed_chunks)
+ reference_list, processed_chunks_with_ref_ids = generate_reference_list_from_chunks(
+ processed_chunks
+ )
- logger.info(f'Final context: {len(processed_chunks_with_ref_ids)} chunks')
+ logger.info(f"Final context: {len(processed_chunks_with_ref_ids)} chunks")
# Build raw data structure for naive mode using processed chunks with reference IDs
raw_data = convert_to_user_format(
@@ -5548,37 +4869,41 @@ async def naive_query(
[], # naive mode has no relationships
processed_chunks_with_ref_ids,
reference_list,
- 'naive',
+ "naive",
)
# Add complete metadata for naive mode
- if 'metadata' not in raw_data:
- raw_data['metadata'] = {}
- raw_data['metadata']['keywords'] = {
- 'high_level': [], # naive mode has no keyword extraction
- 'low_level': [], # naive mode has no keyword extraction
+ if "metadata" not in raw_data:
+ raw_data["metadata"] = {}
+ raw_data["metadata"]["keywords"] = {
+ "high_level": [], # naive mode has no keyword extraction
+ "low_level": [], # naive mode has no keyword extraction
}
- raw_data['metadata']['processing_info'] = {
- 'total_chunks_found': len(chunks),
- 'final_chunks_count': len(processed_chunks_with_ref_ids),
+ raw_data["metadata"]["processing_info"] = {
+ "total_chunks_found": len(chunks),
+ "final_chunks_count": len(processed_chunks_with_ref_ids),
}
# Build chunks_context from processed chunks with reference IDs
chunks_context = []
- for _i, chunk in enumerate(processed_chunks_with_ref_ids):
+ for i, chunk in enumerate(processed_chunks_with_ref_ids):
chunks_context.append(
{
- 'reference_id': chunk['reference_id'],
- 'content': chunk['content'],
+ "reference_id": chunk["reference_id"],
+ "content": chunk["content"],
}
)
- text_units_str = '\n'.join(json.dumps(text_unit, ensure_ascii=False) for text_unit in chunks_context)
- reference_list_str = '\n'.join(
- f'[{ref["reference_id"]}] {ref["file_path"]}' for ref in reference_list if ref['reference_id']
+ text_units_str = "\n".join(
+ json.dumps(text_unit, ensure_ascii=False) for text_unit in chunks_context
+ )
+ reference_list_str = "\n".join(
+ f"[{ref['reference_id']}] {ref['file_path']}"
+ for ref in reference_list
+ if ref["reference_id"]
)
- naive_context_template = PROMPTS['naive_query_context']
+ naive_context_template = PROMPTS["naive_query_context"]
context_content = naive_context_template.format(
text_chunks_str=text_units_str,
reference_list_str=reference_list_str,
@@ -5591,13 +4916,12 @@ async def naive_query(
response_type=query_param.response_type,
user_prompt=user_prompt,
content_data=context_content,
- coverage_guidance=coverage_guidance,
)
user_query = query
if query_param.only_need_prompt:
- prompt_content = '\n\n'.join([sys_prompt, '---User Query---', user_query])
+ prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
return QueryResult(content=prompt_content, raw_data=raw_data)
# Handle cache
@@ -5610,16 +4934,20 @@ async def naive_query(
query_param.max_entity_tokens,
query_param.max_relation_tokens,
query_param.max_total_tokens,
- query_param.user_prompt or '',
+ query_param.user_prompt or "",
query_param.enable_rerank,
)
- cached_result = await handle_cache(hashing_kv, args_hash, user_query, query_param.mode, cache_type='query')
+ cached_result = await handle_cache(
+ hashing_kv, args_hash, user_query, query_param.mode, cache_type="query"
+ )
if cached_result is not None:
cached_response, _ = cached_result # Extract content, ignore timestamp
- logger.info(' == LLM cache == Query cache hit, using cached response as query result')
+ logger.info(
+ " == LLM cache == Query cache hit, using cached response as query result"
+ )
response = cached_response
else:
- response = await llm_callable(
+ response = await use_model_func(
user_query,
system_prompt=sys_prompt,
history_messages=query_param.conversation_history,
@@ -5627,17 +4955,17 @@ async def naive_query(
stream=query_param.stream,
)
- if isinstance(response, str) and hashing_kv and hashing_kv.global_config.get('enable_llm_cache'):
+ if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
queryparam_dict = {
- 'mode': query_param.mode,
- 'response_type': query_param.response_type,
- 'top_k': query_param.top_k,
- 'chunk_top_k': query_param.chunk_top_k,
- 'max_entity_tokens': query_param.max_entity_tokens,
- 'max_relation_tokens': query_param.max_relation_tokens,
- 'max_total_tokens': query_param.max_total_tokens,
- 'user_prompt': query_param.user_prompt or '',
- 'enable_rerank': query_param.enable_rerank,
+ "mode": query_param.mode,
+ "response_type": query_param.response_type,
+ "top_k": query_param.top_k,
+ "chunk_top_k": query_param.chunk_top_k,
+ "max_entity_tokens": query_param.max_entity_tokens,
+ "max_relation_tokens": query_param.max_relation_tokens,
+ "max_total_tokens": query_param.max_total_tokens,
+ "user_prompt": query_param.user_prompt or "",
+ "enable_rerank": query_param.enable_rerank,
}
await save_to_cache(
hashing_kv,
@@ -5646,7 +4974,7 @@ async def naive_query(
content=response,
prompt=query,
mode=query_param.mode,
- cache_type='query',
+ cache_type="query",
queryparam=queryparam_dict,
),
)
@@ -5657,16 +4985,18 @@ async def naive_query(
if len(response) > len(sys_prompt):
response = (
response[len(sys_prompt) :]
- .replace(sys_prompt, '')
- .replace('user', '')
- .replace('model', '')
- .replace(query, '')
- .replace('', '')
- .replace('', '')
+ .replace(sys_prompt, "")
+ .replace("user", "")
+ .replace("model", "")
+ .replace(query, "")
+ .replace("", "")
+ .replace("", "")
.strip()
)
return QueryResult(content=response, raw_data=raw_data)
else:
# Streaming response (AsyncIterator)
- return QueryResult(response_iterator=response, raw_data=raw_data, is_streaming=True)
+ return QueryResult(
+ response_iterator=response, raw_data=raw_data, is_streaming=True
+ )
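
The `@overload` signatures added above use `Literal` so a static type checker can narrow `naive_query`'s return type from the value of `return_raw_data`. A toy illustration of the pattern (not LightRAG code; names are made up):

```python
from __future__ import annotations

from typing import Any, AsyncIterator, Literal, overload


@overload
async def answer(q: str, return_raw_data: Literal[True]) -> dict[str, Any]: ...
@overload
async def answer(q: str, return_raw_data: Literal[False] = False) -> str | AsyncIterator[str]: ...


async def answer(
    q: str, return_raw_data: bool = False
) -> dict[str, Any] | str | AsyncIterator[str]:
    # The checker types `await answer("q", return_raw_data=True)` as dict[str, Any]
    # and `await answer("q")` as str | AsyncIterator[str], with no runtime cost.
    return {"raw": True} if return_raw_data else "rendered answer"
```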
diff --git a/lightrag/prompt.py b/lightrag/prompt.py
index d36a2244..dcd829d4 100644
--- a/lightrag/prompt.py
+++ b/lightrag/prompt.py
@@ -1,260 +1,348 @@
from __future__ import annotations
-
from typing import Any
+
PROMPTS: dict[str, Any] = {}
-# All delimiters must be formatted as "<|TOKEN|>" style markers (e.g., "<|#|>" or "<|COMPLETE|>")
-PROMPTS['DEFAULT_TUPLE_DELIMITER'] = '<|#|>'
-PROMPTS['DEFAULT_COMPLETION_DELIMITER'] = '<|COMPLETE|>'
+# All delimiters must be formatted as "<|UPPER_CASE_STRING|>"
+PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|#|>"
+PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
-PROMPTS['entity_extraction_system_prompt'] = """---Role---
-You are a Knowledge Graph Specialist extracting entities and relationships from text.
+PROMPTS["entity_extraction_system_prompt"] = """---Role---
+You are a Knowledge Graph Specialist responsible for extracting entities and relationships from the input text.
----Output Format---
-Output raw lines only—NO markdown, NO headers, NO backticks.
+---Instructions---
+1. **Entity Extraction & Output:**
+ * **Identification:** Identify clearly defined and meaningful entities in the input text.
+ * **Entity Details:** For each identified entity, extract the following information:
+ * `entity_name`: The name of the entity. If the entity name is case-insensitive, capitalize the first letter of each significant word (title case). Ensure **consistent naming** across the entire extraction process.
+ * `entity_type`: Categorize the entity using one of the following types: `{entity_types}`. If none of the provided entity types apply, do not add a new type; classify the entity as `Other`.
+ * `entity_description`: Provide a concise yet comprehensive description of the entity's attributes and activities, based *solely* on the information present in the input text.
+ * **Output Format - Entities:** Output a total of 4 fields for each entity, delimited by `{tuple_delimiter}`, on a single line. The first field *must* be the literal string `entity`.
+ * Format: `entity{tuple_delimiter}entity_name{tuple_delimiter}entity_type{tuple_delimiter}entity_description`
-Entity: entity{tuple_delimiter}name{tuple_delimiter}type{tuple_delimiter}description
-Relation: relation{tuple_delimiter}source{tuple_delimiter}target{tuple_delimiter}keywords{tuple_delimiter}description
+2. **Relationship Extraction & Output:**
+ * **Identification:** Identify direct, clearly stated, and meaningful relationships between previously extracted entities.
+ * **N-ary Relationship Decomposition:** If a single statement describes a relationship involving more than two entities (an N-ary relationship), decompose it into multiple binary (two-entity) relationship pairs for separate description.
+ * **Example:** For "Alice, Bob, and Carol collaborated on Project X," extract binary relationships such as "Alice collaborated with Project X," "Bob collaborated with Project X," and "Carol collaborated with Project X," or "Alice collaborated with Bob," based on the most reasonable binary interpretations.
+ * **Relationship Details:** For each binary relationship, extract the following fields:
+ * `source_entity`: The name of the source entity. Ensure **consistent naming** with entity extraction. Capitalize the first letter of each significant word (title case) if the name is case-insensitive.
+ * `target_entity`: The name of the target entity. Ensure **consistent naming** with entity extraction. Capitalize the first letter of each significant word (title case) if the name is case-insensitive.
+ * `relationship_keywords`: One or more high-level keywords summarizing the overarching nature, concepts, or themes of the relationship. Multiple keywords within this field must be separated by a comma `,`. **DO NOT use `{tuple_delimiter}` for separating multiple keywords within this field.**
+ * `relationship_description`: A concise explanation of the nature of the relationship between the source and target entities, providing a clear rationale for their connection.
+ * **Output Format - Relationships:** Output a total of 5 fields for each relationship, delimited by `{tuple_delimiter}`, on a single line. The first field *must* be the literal string `relation`.
+ * Format: `relation{tuple_delimiter}source_entity{tuple_delimiter}target_entity{tuple_delimiter}relationship_keywords{tuple_delimiter}relationship_description`
-Use Title Case for names. Separate keywords with commas. Output entities first, then relations. End with {completion_delimiter}.
+3. **Delimiter Usage Protocol:**
+ * The `{tuple_delimiter}` is a complete, atomic marker and **must not be filled with content**. It serves strictly as a field separator.
+ * **Incorrect Example:** `entity{tuple_delimiter}Tokyo<|location|>Tokyo is the capital of Japan.`
+ * **Correct Example:** `entity{tuple_delimiter}Tokyo{tuple_delimiter}location{tuple_delimiter}Tokyo is the capital of Japan.`
----Entity Extraction---
-Extract BOTH concrete and abstract entities:
-- **Concrete:** Named people, organizations, places, products, dates
-- **Abstract:** Concepts, events, categories, processes mentioned in text (e.g., "market selloff", "merger", "pandemic")
+4. **Relationship Direction & Duplication:**
+ * Treat all relationships as **undirected** unless explicitly stated otherwise. Swapping the source and target entities for an undirected relationship does not constitute a new relationship.
+ * Avoid outputting duplicate relationships.
-Types: `{entity_types}` (use `Other` if none fit)
+5. **Output Order & Prioritization:**
+ * Output all extracted entities first, followed by all extracted relationships.
+ * Within the list of relationships, prioritize and output those relationships that are **most significant** to the core meaning of the input text first.
----Relationship Extraction---
-Extract meaningful relationships:
-- **Direct:** explicit interactions, actions, connections
-- **Categorical:** entities sharing group membership or classification
-- **Causal:** cause-effect relationships
-- **Hierarchical:** part-of, member-of, type-of
+6. **Context & Objectivity:**
+ * Ensure all entity names and descriptions are written in the **third person**.
+ * Explicitly name the subject or object; **avoid using pronouns** such as `this article`, `this paper`, `our company`, `I`, `you`, and `he/she`.
-Create intermediate concept entities when they help connect related items (e.g., "Vaccines" connecting Pfizer/Moderna/AstraZeneca).
+7. **Language & Proper Nouns:**
+ * The entire output (entity names, keywords, and descriptions) must be written in `{language}`.
+ * Proper nouns (e.g., personal names, place names, organization names) should be retained in their original language if a proper, widely accepted translation is not available or would cause ambiguity.
-For N-ary relationships, decompose into binary pairs. Avoid duplicates.
-
----Guidelines---
-- Third person only; no pronouns like "this article", "I", "you"
-- Output in `{language}`. Keep proper nouns in original language.
+8. **Completion Signal:** Output the literal string `{completion_delimiter}` only after all entities and relationships, following all criteria, have been completely extracted and outputted.
---Examples---
{examples}
+"""
----Input---
-Entity_types: [{entity_types}]
-Text:
+PROMPTS["entity_extraction_user_prompt"] = """---Task---
+Extract entities and relationships from the input text in the Data to be Processed section below.
+
+---Instructions---
+1. **Strict Adherence to Format:** Strictly adhere to all format requirements for entity and relationship lists, including output order, field delimiters, and proper noun handling, as specified in the system prompt.
+2. **Output Content Only:** Output *only* the extracted list of entities and relationships. Do not include any introductory or concluding remarks, explanations, or additional text before or after the list.
+3. **Completion Signal:** Output `{completion_delimiter}` as the final line after all relevant entities and relationships have been extracted and presented.
+4. **Output Language:** Ensure the output language is {language}. Proper nouns (e.g., personal names, place names, organization names) must be kept in their original language and not translated.
+
+---Data to be Processed---
+
+[{entity_types}]
+
+
```
{input_text}
```
-"""
-
-PROMPTS['entity_extraction_user_prompt'] = """---Task---
-Extract entities and relationships from the text. Include both concrete entities AND abstract concepts/events.
-
-Follow format exactly. Output only extractions—no explanations. End with `{completion_delimiter}`.
-Output in {language}; keep proper nouns in original language.
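
For reference, a made-up sample of the line format the rewritten extraction prompts request, rendered with the default delimiters defined earlier in `prompt.py` (entities first, then relations, ending with the completion marker; the Tokyo/Japan content is illustrative only):

```python
# Hypothetical extraction output assembled with the default delimiters.
TUPLE_DELIMITER = "<|#|>"
COMPLETION_DELIMITER = "<|COMPLETE|>"

sample_output = [
    f"entity{TUPLE_DELIMITER}Tokyo{TUPLE_DELIMITER}location{TUPLE_DELIMITER}Tokyo is the capital of Japan.",
    f"entity{TUPLE_DELIMITER}Japan{TUPLE_DELIMITER}location{TUPLE_DELIMITER}Japan is an island country in East Asia.",
    f"relation{TUPLE_DELIMITER}Tokyo{TUPLE_DELIMITER}Japan{TUPLE_DELIMITER}capital, geography{TUPLE_DELIMITER}"
    f"Tokyo is the capital city of Japan.",
    COMPLETION_DELIMITER,
]
print("\n".join(sample_output))
```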