diff --git a/lightrag/operate.py b/lightrag/operate.py index fb14277f..faab7f26 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1,317 +1,83 @@ from __future__ import annotations - -import asyncio -import hashlib -import json -import os -import re -import time -from collections import Counter, defaultdict -from collections.abc import AsyncIterator, Awaitable, Callable from functools import partial from pathlib import Path -from typing import Any, cast +import asyncio +import json import json_repair -from dotenv import load_dotenv +from typing import Any, AsyncIterator, overload, Literal +from collections import Counter, defaultdict +from lightrag.exceptions import ( + PipelineCancelledException, + ChunkTokenLimitExceededError, +) +from lightrag.utils import ( + logger, + compute_mdhash_id, + Tokenizer, + is_float_regex, + sanitize_and_normalize_extracted_text, + pack_user_ass_to_openai_messages, + split_string_by_multi_markers, + truncate_list_by_token_size, + compute_args_hash, + handle_cache, + save_to_cache, + CacheData, + use_llm_func_with_cache, + update_chunk_cache_list, + remove_think_tags, + pick_by_weighted_polling, + pick_by_vector_similarity, + process_chunks_unified, + safe_vdb_operation_with_exception, + create_prefixed_exception, + fix_tuple_delimiter_corruption, + convert_to_user_format, + generate_reference_list_from_chunks, + apply_source_ids_limit, + merge_source_ids, + make_relation_chunk_key, +) from lightrag.base import ( BaseGraphStorage, BaseKVStorage, BaseVectorStorage, - QueryContextResult, + TextChunkSchema, QueryParam, QueryResult, - TextChunkSchema, + QueryContextResult, ) +from lightrag.prompt import PROMPTS from lightrag.constants import ( - DEFAULT_ENTITY_NAME_MAX_LENGTH, - DEFAULT_ENTITY_TYPES, - DEFAULT_FILE_PATH_MORE_PLACEHOLDER, - DEFAULT_KG_CHUNK_PICK_METHOD, + GRAPH_FIELD_SEP, DEFAULT_MAX_ENTITY_TOKENS, - DEFAULT_MAX_FILE_PATHS, DEFAULT_MAX_RELATION_TOKENS, - DEFAULT_MAX_SOURCE_IDS_PER_ENTITY, - DEFAULT_MAX_SOURCE_IDS_PER_RELATION, DEFAULT_MAX_TOTAL_TOKENS, DEFAULT_RELATED_CHUNK_NUMBER, - DEFAULT_RETRIEVAL_MULTIPLIER, - DEFAULT_SOURCE_IDS_LIMIT_METHOD, + DEFAULT_KG_CHUNK_PICK_METHOD, + DEFAULT_ENTITY_TYPES, DEFAULT_SUMMARY_LANGUAGE, - GRAPH_FIELD_SEP, - SOURCE_IDS_LIMIT_METHOD_FIFO, SOURCE_IDS_LIMIT_METHOD_KEEP, -) -from lightrag.entity_resolution import ( - EntityResolutionConfig, - fuzzy_similarity, - get_cached_alias, - resolve_entity_with_vdb, - store_alias, -) -from lightrag.exceptions import ( - ChunkTokenLimitExceededError, - PipelineCancelledException, + SOURCE_IDS_LIMIT_METHOD_FIFO, + DEFAULT_FILE_PATH_MORE_PLACEHOLDER, + DEFAULT_MAX_FILE_PATHS, + DEFAULT_ENTITY_NAME_MAX_LENGTH, ) from lightrag.kg.shared_storage import get_storage_keyed_lock -from lightrag.prompt import PROMPTS -from lightrag.utils import ( - CacheData, - Tokenizer, - apply_source_ids_limit, - compute_args_hash, - compute_mdhash_id, - convert_to_user_format, - create_prefixed_exception, - fix_tuple_delimiter_corruption, - generate_reference_list_from_chunks, - handle_cache, - is_float_regex, - logger, - make_relation_chunk_key, - merge_source_ids, - pack_user_ass_to_openai_messages, - pick_by_vector_similarity, - pick_by_weighted_polling, - process_chunks_unified, - remove_think_tags, - safe_vdb_operation_with_exception, - sanitize_and_normalize_extracted_text, - save_to_cache, - split_string_by_multi_markers, - truncate_list_by_token_size, - update_chunk_cache_list, - use_llm_func_with_cache, -) - -# Query embedding cache configuration (configurable via 
environment variables) -QUERY_EMBEDDING_CACHE_TTL = int(os.getenv('QUERY_EMBEDDING_CACHE_TTL', '3600')) # 1 hour -QUERY_EMBEDDING_CACHE_MAX_SIZE = int( - os.getenv('QUERY_EMBEDDING_CACHE_MAX_SIZE', os.getenv('QUERY_EMBEDDING_CACHE_SIZE', '10000')) -) - -# Redis cache configuration -REDIS_EMBEDDING_CACHE_ENABLED = os.getenv('REDIS_EMBEDDING_CACHE', 'false').lower() == 'true' -REDIS_URI = os.getenv('REDIS_URI', 'redis://localhost:6379') - -# Local in-memory cache with LRU eviction -# Structure: {query_hash: (embedding, timestamp)} -_query_embedding_cache: dict[str, tuple[list[float], float]] = {} -_query_embedding_cache_locks: dict[int, asyncio.Lock] = {} - - -def _get_query_embedding_cache_lock() -> asyncio.Lock: - """Return an event-loop-local lock for the embedding cache.""" - loop = asyncio.get_running_loop() - lock = _query_embedding_cache_locks.get(id(loop)) - if lock is None: - lock = asyncio.Lock() - _query_embedding_cache_locks[id(loop)] = lock - return lock - - -# Global Redis client (lazy initialized) -_redis_client = None - - -async def _get_redis_client(): - """Lazy initialize Redis client.""" - global _redis_client - if _redis_client is None and REDIS_EMBEDDING_CACHE_ENABLED: - try: - import redis.asyncio as redis - - _redis_client = redis.from_url(REDIS_URI, decode_responses=True) - # Test connection - await _redis_client.ping() - logger.info(f'Redis embedding cache connected: {REDIS_URI}') - except ImportError: - logger.warning('Redis package not installed. Install with: pip install redis') - return None - except Exception as e: - logger.warning(f'Failed to connect to Redis: {e}. Falling back to local cache.') - return None - return _redis_client - - -async def get_cached_query_embedding(query: str, embedding_func) -> list[float] | None: - """Get query embedding with caching to avoid redundant API calls. - - Supports both local in-memory cache and Redis for cross-worker sharing. - Redis is used when REDIS_EMBEDDING_CACHE=true environment variable is set. 
- - Args: - query: The query string to embed - embedding_func: The embedding function to call on cache miss - - Returns: - The embedding vector, or None if embedding fails - """ - query_hash = hashlib.sha256(query.encode()).hexdigest()[:16] - current_time = time.time() - redis_key = f'lightrag:emb:{query_hash}' - - # Try Redis cache first (if enabled) - if REDIS_EMBEDDING_CACHE_ENABLED: - try: - redis_client = await _get_redis_client() - if redis_client: - cached_json = await redis_client.get(redis_key) - if cached_json: - embedding = json.loads(cached_json) - logger.debug(f'Redis embedding cache hit for hash {query_hash[:8]}') - # Also update local cache - _query_embedding_cache[query_hash] = (embedding, current_time) - return embedding - except Exception as e: - logger.debug(f'Redis cache read error: {e}') - - # Check local cache - cached = _query_embedding_cache.get(query_hash) - if cached and (current_time - cached[1]) < QUERY_EMBEDDING_CACHE_TTL: - logger.debug(f'Local embedding cache hit for hash {query_hash[:8]}') - return cached[0] - - # Cache miss - compute embedding - try: - embedding = await embedding_func([query]) - embedding_result = embedding[0] # Extract first from batch - - # Manage local cache size - LRU eviction of oldest entries - async with _get_query_embedding_cache_lock(): - if len(_query_embedding_cache) >= QUERY_EMBEDDING_CACHE_MAX_SIZE: - # Remove oldest 10% of entries - sorted_entries = sorted(_query_embedding_cache.items(), key=lambda x: x[1][1]) - for old_key, _ in sorted_entries[: QUERY_EMBEDDING_CACHE_MAX_SIZE // 10]: - del _query_embedding_cache[old_key] - - # Store in local cache - _query_embedding_cache[query_hash] = (embedding_result, current_time) - - # Store in Redis (if enabled) - if REDIS_EMBEDDING_CACHE_ENABLED: - try: - redis_client = await _get_redis_client() - if redis_client: - await redis_client.setex( - redis_key, - QUERY_EMBEDDING_CACHE_TTL, - json.dumps(embedding_result), - ) - logger.debug(f'Embedding cached in Redis for hash {query_hash[:8]}') - except Exception as e: - logger.debug(f'Redis cache write error: {e}') - - logger.debug(f'Query embedding computed and cached for hash {query_hash[:8]}') - return embedding_result - except Exception as e: - logger.warning(f'Failed to compute query embedding: {e}') - return None - +import time +from dotenv import load_dotenv # use the .env that is inside the current folder # allows to use different .env file for each lightrag instance # the OS environment variables take precedence over the .env file -load_dotenv(dotenv_path=Path(__file__).resolve().parent / '.env', override=False) +load_dotenv(dotenv_path=Path(__file__).resolve().parent / ".env", override=False) -class TimingContext: - """Context manager for collecting timing information during query execution. - - Provides both synchronous and context manager interfaces for measuring - execution time of different query stages. - - Example: - timing = TimingContext() - timing.start('keyword_extraction') - # ... do work ... - timing.stop('keyword_extraction') - - # Or with context manager: - with timing.measure('vector_search'): - # ... do work ... - - print(timing.timings) # {'keyword_extraction': 123.4, 'vector_search': 567.8} - """ - - def __init__(self) -> None: - """Initialize timing context.""" - self.timings: dict[str, float] = {} - self._start_times: dict[str, float] = {} - self._total_start: float = time.perf_counter() - - def start(self, stage: str) -> None: - """Start timing a stage. 
- - Args: - stage: Name of the stage to time - """ - self._start_times[stage] = time.perf_counter() - - def stop(self, stage: str) -> float: - """Stop timing a stage and record the duration. - - Args: - stage: Name of the stage to stop - - Returns: - Duration in milliseconds - """ - if stage in self._start_times: - elapsed = (time.perf_counter() - self._start_times[stage]) * 1000 - self.timings[stage] = elapsed - del self._start_times[stage] - return elapsed - return 0.0 - - def measure(self, stage: str): - """Context manager for timing a stage. - - Args: - stage: Name of the stage to time - - Yields: - None - """ - return _TimingContextManager(self, stage) - - def get_total_ms(self) -> float: - """Get total elapsed time since context creation. - - Returns: - Total time in milliseconds - """ - return (time.perf_counter() - self._total_start) * 1000 - - def finalize(self) -> dict[str, float]: - """Finalize timing and return all recorded durations. - - Sets the 'total' timing to elapsed time since context creation. - - Returns: - Dictionary of stage -> duration_ms - """ - self.timings['total'] = self.get_total_ms() - return self.timings - - -class _TimingContextManager: - """Internal context manager for TimingContext.measure().""" - - def __init__(self, timing_ctx: TimingContext, stage: str) -> None: - self._timing_ctx = timing_ctx - self._stage = stage - - def __enter__(self) -> None: - self._timing_ctx.start(self._stage) - return None - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - self._timing_ctx.stop(self._stage) - - -def _truncate_entity_identifier(identifier: str, limit: int, chunk_key: str, identifier_role: str) -> str: +def _truncate_entity_identifier( + identifier: str, limit: int, chunk_key: str, identifier_role: str +) -> str: """Truncate entity identifiers that exceed the configured length limit.""" if len(identifier) <= limit: @@ -338,25 +104,17 @@ def chunking_by_token_size( chunk_overlap_token_size: int = 100, chunk_token_size: int = 1200, ) -> list[dict[str, Any]]: - """Split content into chunks by token size, tracking character positions. - - Returns chunks with char_start and char_end for citation support. 
- """ tokens = tokenizer.encode(content) results: list[dict[str, Any]] = [] if split_by_character: raw_chunks = content.split(split_by_character) - # Track character positions: (tokens, chunk_text, char_start, char_end) - new_chunks: list[tuple[int, str, int, int]] = [] - char_position = 0 - separator_len = len(split_by_character) - + new_chunks = [] if split_by_character_only: for chunk in raw_chunks: _tokens = tokenizer.encode(chunk) if len(_tokens) > chunk_token_size: logger.warning( - 'Chunk split_by_character exceeds token limit: len=%d limit=%d', + "Chunk split_by_character exceeds token limit: len=%d limit=%d", len(_tokens), chunk_token_size, ) @@ -365,58 +123,40 @@ def chunking_by_token_size( chunk_token_limit=chunk_token_size, chunk_preview=chunk[:120], ) - chunk_start = char_position - chunk_end = char_position + len(chunk) - new_chunks.append((len(_tokens), chunk, chunk_start, chunk_end)) - char_position = chunk_end + separator_len # Skip separator + new_chunks.append((len(_tokens), chunk)) else: for chunk in raw_chunks: - chunk_start = char_position _tokens = tokenizer.encode(chunk) if len(_tokens) > chunk_token_size: - # Sub-chunking: approximate char positions within the chunk - sub_char_position = 0 - for start in range(0, len(_tokens), chunk_token_size - chunk_overlap_token_size): - chunk_content = tokenizer.decode(_tokens[start : start + chunk_token_size]) - # Approximate char position based on content length ratio - sub_start = chunk_start + sub_char_position - sub_end = sub_start + len(chunk_content) - new_chunks.append( - (min(chunk_token_size, len(_tokens) - start), chunk_content, sub_start, sub_end) + for start in range( + 0, len(_tokens), chunk_token_size - chunk_overlap_token_size + ): + chunk_content = tokenizer.decode( + _tokens[start : start + chunk_token_size] + ) + new_chunks.append( + (min(chunk_token_size, len(_tokens) - start), chunk_content) ) - sub_char_position += len(chunk_content) - (chunk_overlap_token_size * 4) # Approx overlap else: - chunk_end = chunk_start + len(chunk) - new_chunks.append((len(_tokens), chunk, chunk_start, chunk_end)) - char_position = chunk_start + len(chunk) + separator_len - - for index, (_len, chunk, char_start, char_end) in enumerate(new_chunks): + new_chunks.append((len(_tokens), chunk)) + for index, (_len, chunk) in enumerate(new_chunks): results.append( { - 'tokens': _len, - 'content': chunk.strip(), - 'chunk_order_index': index, - 'char_start': char_start, - 'char_end': char_end, + "tokens": _len, + "content": chunk.strip(), + "chunk_order_index": index, } ) else: - # Token-based chunking: track character positions through decoded content - char_position = 0 - for index, start in enumerate(range(0, len(tokens), chunk_token_size - chunk_overlap_token_size)): + for index, start in enumerate( + range(0, len(tokens), chunk_token_size - chunk_overlap_token_size) + ): chunk_content = tokenizer.decode(tokens[start : start + chunk_token_size]) - # For overlapping chunks, approximate positions based on previous chunk - char_start = 0 if index == 0 else char_position - char_end = char_start + len(chunk_content) - char_position = char_start + len(chunk_content) - (chunk_overlap_token_size * 4) # Approx char overlap - results.append( { - 'tokens': min(chunk_token_size, len(tokens) - start), - 'content': chunk_content.strip(), - 'chunk_order_index': index, - 'char_start': char_start, - 'char_end': char_end, + "tokens": min(chunk_token_size, len(tokens) - start), + "content": chunk_content.strip(), + "chunk_order_index": index, } ) 
return results @@ -450,17 +190,17 @@ async def _handle_entity_relation_summary( """ # Handle empty input if not description_list: - return '', False + return "", False # If only one description, return it directly (no need for LLM call) if len(description_list) == 1: return description_list[0], False # Get configuration - tokenizer: Tokenizer = global_config['tokenizer'] - summary_context_size = global_config['summary_context_size'] - summary_max_tokens = global_config['summary_max_tokens'] - force_llm_summary_on_merge = global_config['force_llm_summary_on_merge'] + tokenizer: Tokenizer = global_config["tokenizer"] + summary_context_size = global_config["summary_context_size"] + summary_max_tokens = global_config["summary_max_tokens"] + force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"] current_list = description_list[:] # Copy the list to avoid modifying original llm_was_used = False # Track whether LLM was used during the entire process @@ -472,13 +212,18 @@ async def _handle_entity_relation_summary( # If total length is within limits, perform final summarization if total_tokens <= summary_context_size or len(current_list) <= 2: - if len(current_list) < force_llm_summary_on_merge and total_tokens < summary_max_tokens: + if ( + len(current_list) < force_llm_summary_on_merge + and total_tokens < summary_max_tokens + ): # no LLM needed, just join the descriptions final_description = seperator.join(current_list) - return final_description if final_description else '', llm_was_used + return final_description if final_description else "", llm_was_used else: if total_tokens > summary_context_size and len(current_list) <= 2: - logger.warning(f'Summarizing {entity_or_relation_name}: Oversize description found') + logger.warning( + f"Summarizing {entity_or_relation_name}: Oversize description found" + ) # Final summarization of remaining descriptions - LLM will be used final_summary = await _summarize_descriptions( description_type, @@ -496,7 +241,7 @@ async def _handle_entity_relation_summary( current_tokens = 0 # Currently at least 3 descriptions in current_list - for _i, desc in enumerate(current_list): + for i, desc in enumerate(current_list): desc_tokens = len(tokenizer.encode(desc)) # If adding current description would exceed limit, finalize current chunk @@ -506,7 +251,9 @@ async def _handle_entity_relation_summary( # Force add one more description to ensure minimum 2 per chunk current_chunk.append(desc) chunks.append(current_chunk) - logger.warning(f'Summarizing {entity_or_relation_name}: Oversize description found') + logger.warning( + f"Summarizing {entity_or_relation_name}: Oversize description found" + ) current_chunk = [] # next group is empty current_tokens = 0 else: # current_chunk is ready for summary in reduce phase @@ -522,15 +269,15 @@ async def _handle_entity_relation_summary( chunks.append(current_chunk) logger.info( - f' Summarizing {entity_or_relation_name}: Map {len(current_list)} descriptions into {len(chunks)} groups' + f" Summarizing {entity_or_relation_name}: Map {len(current_list)} descriptions into {len(chunks)} groups" ) - # Reduce phase: summarize each group from chunks IN PARALLEL - async def _summarize_single_chunk(chunk: list[str]) -> tuple[str, bool]: - """Summarize a single chunk, returning (summary, used_llm).""" + # Reduce phase: summarize each group from chunks + new_summaries = [] + for chunk in chunks: if len(chunk) == 1: # Optimization: single description chunks don't need LLM summarization - return chunk[0], False + 
new_summaries.append(chunk[0]) else: # Multiple descriptions need LLM summarization summary = await _summarize_descriptions( @@ -540,16 +287,8 @@ async def _handle_entity_relation_summary( global_config, llm_response_cache, ) - return summary, True - - # Create tasks for all chunks and run in parallel - tasks = [asyncio.create_task(_summarize_single_chunk(chunk)) for chunk in chunks] - results = await asyncio.gather(*tasks) - - # Collect results while preserving order - new_summaries = [result[0] for result in results] - if any(result[1] for result in results): - llm_was_used = True # Mark that LLM was used in reduce phase + new_summaries.append(summary) + llm_was_used = True # Mark that LLM was used in reduce phase # Update current list with new summaries for next iteration current_list = new_summaries @@ -573,22 +312,22 @@ async def _summarize_descriptions( Returns: Summarized description string """ - use_llm_func: Callable[..., Any] = global_config['llm_model_func'] + use_llm_func: callable = global_config["llm_model_func"] # Apply higher priority (8) to entity/relation summary tasks use_llm_func = partial(use_llm_func, _priority=8) - language = global_config['addon_params'].get('language', DEFAULT_SUMMARY_LANGUAGE) + language = global_config["addon_params"].get("language", DEFAULT_SUMMARY_LANGUAGE) - summary_length_recommended = global_config['summary_length_recommended'] + summary_length_recommended = global_config["summary_length_recommended"] - prompt_template = PROMPTS['summarize_entity_descriptions'] + prompt_template = PROMPTS["summarize_entity_descriptions"] # Convert descriptions to JSONL format and apply token-based truncation - tokenizer = global_config['tokenizer'] - summary_context_size = global_config['summary_context_size'] + tokenizer = global_config["tokenizer"] + summary_context_size = global_config["summary_context_size"] # Create list of JSON objects with "Description" field - json_descriptions = [{'Description': desc} for desc in description_list] + json_descriptions = [{"Description": desc} for desc in description_list] # Use truncate_list_by_token_size for length truncation truncated_json_descriptions = truncate_list_by_token_size( @@ -599,16 +338,18 @@ async def _summarize_descriptions( ) # Convert to JSONL format (one JSON object per line) - joined_descriptions = '\n'.join(json.dumps(desc, ensure_ascii=False) for desc in truncated_json_descriptions) + joined_descriptions = "\n".join( + json.dumps(desc, ensure_ascii=False) for desc in truncated_json_descriptions + ) # Prepare context for the prompt - context_base = { - 'description_type': description_type, - 'description_name': description_name, - 'description_list': joined_descriptions, - 'summary_length': summary_length_recommended, - 'language': language, - } + context_base = dict( + description_type=description_type, + description_name=description_name, + description_list=joined_descriptions, + summary_length=summary_length_recommended, + language=language, + ) use_prompt = prompt_template.format(**context_base) # Use LLM function with cache (higher priority for summary generation) @@ -616,20 +357,20 @@ async def _summarize_descriptions( use_prompt, use_llm_func, llm_response_cache=llm_response_cache, - cache_type='summary', + cache_type="summary", ) # Check summary token length against embedding limit - embedding_token_limit = global_config.get('embedding_token_limit') + embedding_token_limit = global_config.get("embedding_token_limit") if embedding_token_limit is not None and summary: - tokenizer = 
global_config['tokenizer'] + tokenizer = global_config["tokenizer"] summary_token_count = len(tokenizer.encode(summary)) threshold = int(embedding_token_limit * 0.9) if summary_token_count > threshold: logger.warning( - f'Summary tokens ({summary_token_count}) exceeds 90% of embedding limit ' - f'({embedding_token_limit}) for {description_type}: {description_name}' + f"Summary tokens ({summary_token_count}) exceeds 90% of embedding limit " + f"({embedding_token_limit}) for {description_type}: {description_name}" ) return summary @@ -639,33 +380,43 @@ async def _handle_single_entity_extraction( record_attributes: list[str], chunk_key: str, timestamp: int, - file_path: str = 'unknown_source', + file_path: str = "unknown_source", ): - if len(record_attributes) != 4 or 'entity' not in record_attributes[0]: - if len(record_attributes) > 1 and 'entity' in record_attributes[0]: + if len(record_attributes) != 4 or "entity" not in record_attributes[0]: + if len(record_attributes) > 1 and "entity" in record_attributes[0]: logger.warning( - f'{chunk_key}: LLM output format error; found {len(record_attributes)}/4 fields on ENTITY `{record_attributes[1]}` @ `{record_attributes[2] if len(record_attributes) > 2 else "N/A"}`' + f"{chunk_key}: LLM output format error; found {len(record_attributes)}/4 fields on ENTITY `{record_attributes[1]}` @ `{record_attributes[2] if len(record_attributes) > 2 else 'N/A'}`" ) logger.debug(record_attributes) return None try: - entity_name = sanitize_and_normalize_extracted_text(record_attributes[1], remove_inner_quotes=True) + entity_name = sanitize_and_normalize_extracted_text( + record_attributes[1], remove_inner_quotes=True + ) # Validate entity name after all cleaning steps if not entity_name or not entity_name.strip(): - logger.info(f"Empty entity name found after sanitization. Original: '{record_attributes[1]}'") + logger.info( + f"Empty entity name found after sanitization. 
Original: '{record_attributes[1]}'" + ) return None # Process entity type with same cleaning pipeline - entity_type = sanitize_and_normalize_extracted_text(record_attributes[2], remove_inner_quotes=True) + entity_type = sanitize_and_normalize_extracted_text( + record_attributes[2], remove_inner_quotes=True + ) - if not entity_type.strip() or any(char in entity_type for char in ["'", '(', ')', '<', '>', '|', '/', '\\']): - logger.warning(f'Entity extraction error: invalid entity type in: {record_attributes}') + if not entity_type.strip() or any( + char in entity_type for char in ["'", "(", ")", "<", ">", "|", "/", "\\"] + ): + logger.warning( + f"Entity extraction error: invalid entity type in: {record_attributes}" + ) return None # Remove spaces and convert to lowercase - entity_type = entity_type.replace(' ', '').lower() + entity_type = entity_type.replace(" ", "").lower() # Process entity description with same cleaning pipeline entity_description = sanitize_and_normalize_extracted_text(record_attributes[3]) @@ -676,20 +427,24 @@ async def _handle_single_entity_extraction( ) return None - return { - 'entity_name': entity_name, - 'entity_type': entity_type, - 'description': entity_description, - 'source_id': chunk_key, - 'file_path': file_path, - 'timestamp': timestamp, - } + return dict( + entity_name=entity_name, + entity_type=entity_type, + description=entity_description, + source_id=chunk_key, + file_path=file_path, + timestamp=timestamp, + ) except ValueError as e: - logger.error(f'Entity extraction failed due to encoding issues in chunk {chunk_key}: {e}') + logger.error( + f"Entity extraction failed due to encoding issues in chunk {chunk_key}: {e}" + ) return None except Exception as e: - logger.error(f'Entity extraction failed with unexpected error in chunk {chunk_key}: {e}') + logger.error( + f"Entity extraction failed with unexpected error in chunk {chunk_key}: {e}" + ) return None @@ -697,41 +452,50 @@ async def _handle_single_relationship_extraction( record_attributes: list[str], chunk_key: str, timestamp: int, - file_path: str = 'unknown_source', + file_path: str = "unknown_source", ): if ( - len(record_attributes) != 5 or 'relation' not in record_attributes[0] + len(record_attributes) != 5 or "relation" not in record_attributes[0] ): # treat "relationship" and "relation" interchangeable - if len(record_attributes) > 1 and 'relation' in record_attributes[0]: + if len(record_attributes) > 1 and "relation" in record_attributes[0]: logger.warning( - f'{chunk_key}: LLM output format error; found {len(record_attributes)}/5 fields on RELATION `{record_attributes[1]}`~`{record_attributes[2] if len(record_attributes) > 2 else "N/A"}`' + f"{chunk_key}: LLM output format error; found {len(record_attributes)}/5 fields on RELATION `{record_attributes[1]}`~`{record_attributes[2] if len(record_attributes) > 2 else 'N/A'}`" ) logger.debug(record_attributes) return None try: - source = sanitize_and_normalize_extracted_text(record_attributes[1], remove_inner_quotes=True) - target = sanitize_and_normalize_extracted_text(record_attributes[2], remove_inner_quotes=True) + source = sanitize_and_normalize_extracted_text( + record_attributes[1], remove_inner_quotes=True + ) + target = sanitize_and_normalize_extracted_text( + record_attributes[2], remove_inner_quotes=True + ) # Validate entity names after all cleaning steps if not source: - logger.info(f"Empty source entity found after sanitization. Original: '{record_attributes[1]}'") + logger.info( + f"Empty source entity found after sanitization. 
Original: '{record_attributes[1]}'" + ) return None if not target: - logger.info(f"Empty target entity found after sanitization. Original: '{record_attributes[2]}'") + logger.info( + f"Empty target entity found after sanitization. Original: '{record_attributes[2]}'" + ) return None if source == target: - logger.debug(f'Relationship source and target are the same in: {record_attributes}') + logger.debug( + f"Relationship source and target are the same in: {record_attributes}" + ) return None # Process keywords with same cleaning pipeline - edge_keywords = sanitize_and_normalize_extracted_text(record_attributes[3], remove_inner_quotes=True) - edge_keywords = edge_keywords.replace(',', ',') - - # Derive a relationship label from the first keyword (fallback to description later) - relationship_label = edge_keywords.split(',')[0].strip() if edge_keywords else '' + edge_keywords = sanitize_and_normalize_extracted_text( + record_attributes[3], remove_inner_quotes=True + ) + edge_keywords = edge_keywords.replace(",", ",") # Process relationship description with same cleaning pipeline edge_description = sanitize_and_normalize_extracted_text(record_attributes[4]) @@ -743,24 +507,26 @@ async def _handle_single_relationship_extraction( else 1.0 ) - return { - 'src_id': source, - 'tgt_id': target, - 'weight': weight, - 'description': edge_description, - 'keywords': edge_keywords, - 'relationship': relationship_label, - 'type': relationship_label, - 'source_id': edge_source_id, - 'file_path': file_path, - 'timestamp': timestamp, - } + return dict( + src_id=source, + tgt_id=target, + weight=weight, + description=edge_description, + keywords=edge_keywords, + source_id=edge_source_id, + file_path=file_path, + timestamp=timestamp, + ) except ValueError as e: - logger.warning(f'Relationship extraction failed due to encoding issues in chunk {chunk_key}: {e}') + logger.warning( + f"Relationship extraction failed due to encoding issues in chunk {chunk_key}: {e}" + ) return None except Exception as e: - logger.warning(f'Relationship extraction failed with unexpected error in chunk {chunk_key}: {e}') + logger.warning( + f"Relationship extraction failed with unexpected error in chunk {chunk_key}: {e}" + ) return None @@ -772,7 +538,7 @@ async def rebuild_knowledge_from_chunks( relationships_vdb: BaseVectorStorage, text_chunks_storage: BaseKVStorage, llm_response_cache: BaseKVStorage, - global_config: dict[str, Any], + global_config: dict[str, str], pipeline_status: dict | None = None, pipeline_status_lock=None, entity_chunks_storage: BaseKVStorage | None = None, @@ -808,14 +574,12 @@ async def rebuild_knowledge_from_chunks( for chunk_ids in relationships_to_rebuild.values(): all_referenced_chunk_ids.update(chunk_ids) - status_message = ( - f'Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions (parallel processing)' - ) + status_message = f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions (parallel processing)" logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: - pipeline_status['latest_message'] = status_message - pipeline_status['history_messages'].append(status_message) + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) # Get cached extraction results for these chunks using storage # cached_results: chunk_id -> [list of (extraction_result, create_time) from LLM cache sorted by create_time of the first 
extraction_result] @@ -826,12 +590,12 @@ async def rebuild_knowledge_from_chunks( ) if not cached_results: - status_message = 'No cached extraction results found, cannot rebuild' + status_message = "No cached extraction results found, cannot rebuild" logger.warning(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: - pipeline_status['latest_message'] = status_message - pipeline_status['history_messages'].append(status_message) + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) return # Process cached results to get entities and relationships for each chunk @@ -850,7 +614,7 @@ async def rebuild_knowledge_from_chunks( text_chunks_storage=text_chunks_storage, chunk_id=chunk_id, extraction_result=result[0], - timestamp=int(result[1]) if result[1] is not None else int(time.time()), + timestamp=result[1], ) # Merge entities and relationships from this extraction result @@ -864,8 +628,13 @@ async def rebuild_knowledge_from_chunks( chunk_entities[chunk_id][entity_name].extend(entity_list) else: # Compare description lengths and keep the better one - existing_desc_len = len(chunk_entities[chunk_id][entity_name][0].get('description', '') or '') - new_desc_len = len(entity_list[0].get('description', '') or '') + existing_desc_len = len( + chunk_entities[chunk_id][entity_name][0].get( + "description", "" + ) + or "" + ) + new_desc_len = len(entity_list[0].get("description", "") or "") if new_desc_len > existing_desc_len: # Replace with the new entity that has longer description @@ -882,8 +651,13 @@ async def rebuild_knowledge_from_chunks( chunk_relationships[chunk_id][rel_key].extend(rel_list) else: # Compare description lengths and keep the better one - existing_desc_len = len(chunk_relationships[chunk_id][rel_key][0].get('description', '') or '') - new_desc_len = len(rel_list[0].get('description', '') or '') + existing_desc_len = len( + chunk_relationships[chunk_id][rel_key][0].get( + "description", "" + ) + or "" + ) + new_desc_len = len(rel_list[0].get("description", "") or "") if new_desc_len > existing_desc_len: # Replace with the new relationship that has longer description @@ -891,16 +665,18 @@ async def rebuild_knowledge_from_chunks( # Otherwise keep existing version except Exception as e: - status_message = f'Failed to parse cached extraction result for chunk {chunk_id}: {e}' + status_message = ( + f"Failed to parse cached extraction result for chunk {chunk_id}: {e}" + ) logger.info(status_message) # Per requirement, change to info if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: - pipeline_status['latest_message'] = status_message - pipeline_status['history_messages'].append(status_message) + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) continue # Get max async tasks limit from global_config for semaphore control - graph_max_async = global_config.get('llm_model_max_async', 4) * 2 + graph_max_async = global_config.get("llm_model_max_async", 4) * 2 semaphore = asyncio.Semaphore(graph_max_async) # Counters for tracking progress @@ -912,9 +688,11 @@ async def rebuild_knowledge_from_chunks( async def _locked_rebuild_entity(entity_name, chunk_ids): nonlocal rebuilt_entities_count, failed_entities_count async with semaphore: - workspace = global_config.get('workspace', '') - namespace = f'{workspace}:GraphDB' if workspace else 'GraphDB' - async with 
get_storage_keyed_lock([entity_name], namespace=namespace, enable_logging=False): + workspace = global_config.get("workspace", "") + namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" + async with get_storage_keyed_lock( + [entity_name], namespace=namespace, enable_logging=False + ): try: await _rebuild_single_entity( knowledge_graph_inst=knowledge_graph_inst, @@ -929,18 +707,18 @@ async def rebuild_knowledge_from_chunks( rebuilt_entities_count += 1 except Exception as e: failed_entities_count += 1 - status_message = f'Failed to rebuild `{entity_name}`: {e}' + status_message = f"Failed to rebuild `{entity_name}`: {e}" logger.info(status_message) # Per requirement, change to info if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: - pipeline_status['latest_message'] = status_message - pipeline_status['history_messages'].append(status_message) + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) async def _locked_rebuild_relationship(src, tgt, chunk_ids): nonlocal rebuilt_relationships_count, failed_relationships_count async with semaphore: - workspace = global_config.get('workspace', '') - namespace = f'{workspace}:GraphDB' if workspace else 'GraphDB' + workspace = global_config.get("workspace", "") + namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" # Sort src and tgt to ensure order-independent lock key generation sorted_key_parts = sorted([src, tgt]) async with get_storage_keyed_lock( @@ -967,12 +745,12 @@ async def rebuild_knowledge_from_chunks( rebuilt_relationships_count += 1 except Exception as e: failed_relationships_count += 1 - status_message = f'Failed to rebuild `{src}`~`{tgt}`: {e}' + status_message = f"Failed to rebuild `{src}`~`{tgt}`: {e}" logger.info(status_message) # Per requirement, change to info if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: - pipeline_status['latest_message'] = status_message - pipeline_status['history_messages'].append(status_message) + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) # Create tasks for parallel processing tasks = [] @@ -988,12 +766,12 @@ async def rebuild_knowledge_from_chunks( tasks.append(task) # Log parallel processing start - status_message = f'Starting parallel rebuild of {len(entities_to_rebuild)} entities and {len(relationships_to_rebuild)} relationships (async: {graph_max_async})' + status_message = f"Starting parallel rebuild of {len(entities_to_rebuild)} entities and {len(relationships_to_rebuild)} relationships (async: {graph_max_async})" logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: - pipeline_status['latest_message'] = status_message - pipeline_status['history_messages'].append(status_message) + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) # Execute all tasks in parallel with semaphore control and early failure detection done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) @@ -1028,15 +806,15 @@ async def rebuild_knowledge_from_chunks( raise first_exception # Final status report - status_message = f'KG rebuild completed: {rebuilt_entities_count} entities and {rebuilt_relationships_count} relationships rebuilt successfully.' 
+ status_message = f"KG rebuild completed: {rebuilt_entities_count} entities and {rebuilt_relationships_count} relationships rebuilt successfully." if failed_entities_count > 0 or failed_relationships_count > 0: - status_message += f' Failed: {failed_entities_count} entities, {failed_relationships_count} relationships.' + status_message += f" Failed: {failed_entities_count} entities, {failed_relationships_count} relationships." logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: - pipeline_status['latest_message'] = status_message - pipeline_status['history_messages'].append(status_message) + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) async def _get_cached_extraction_results( @@ -1070,14 +848,14 @@ async def _get_cached_extraction_results( chunk_data_list = await text_chunks_storage.get_by_ids(list(chunk_ids)) for chunk_data in chunk_data_list: if chunk_data and isinstance(chunk_data, dict): - llm_cache_list = chunk_data.get('llm_cache_list', []) + llm_cache_list = chunk_data.get("llm_cache_list", []) if llm_cache_list: all_cache_ids.update(llm_cache_list) else: - logger.warning(f'Chunk data is invalid or None: {chunk_data}') + logger.warning(f"Chunk data is invalid or None: {chunk_data}") if not all_cache_ids: - logger.warning(f'No LLM cache IDs found for {len(chunk_ids)} chunk IDs') + logger.warning(f"No LLM cache IDs found for {len(chunk_ids)} chunk IDs") return cached_results # Batch get LLM cache entries @@ -1089,12 +867,14 @@ async def _get_cached_extraction_results( if ( cache_entry is not None and isinstance(cache_entry, dict) - and cache_entry.get('cache_type') == 'extract' - and cache_entry.get('chunk_id') in chunk_ids + and cache_entry.get("cache_type") == "extract" + and cache_entry.get("chunk_id") in chunk_ids ): - chunk_id = cache_entry['chunk_id'] - extraction_result = cache_entry['return'] - create_time = cache_entry.get('create_time', 0) # Get creation time, default to 0 + chunk_id = cache_entry["chunk_id"] + extraction_result = cache_entry["return"] + create_time = cache_entry.get( + "create_time", 0 + ) # Get creation time, default to 0 valid_entries += 1 # Support multiple LLM caches per chunk @@ -1112,14 +892,18 @@ async def _get_cached_extraction_results( chunk_earliest_times[chunk_id] = cached_results[chunk_id][0][1] # Sort cached_results by the earliest create_time of each chunk - sorted_chunk_ids = sorted(chunk_earliest_times.keys(), key=lambda chunk_id: chunk_earliest_times[chunk_id]) + sorted_chunk_ids = sorted( + chunk_earliest_times.keys(), key=lambda chunk_id: chunk_earliest_times[chunk_id] + ) # Rebuild cached_results in sorted order sorted_cached_results = {} for chunk_id in sorted_chunk_ids: sorted_cached_results[chunk_id] = cached_results[chunk_id] - logger.info(f'Found {valid_entries} valid cache entries, {len(sorted_cached_results)} chunks with results') + logger.info( + f"Found {valid_entries} valid cache entries, {len(sorted_cached_results)} chunks with results" + ) return sorted_cached_results # each item: list(extraction_result, create_time) @@ -1127,9 +911,9 @@ async def _process_extraction_result( result: str, chunk_key: str, timestamp: int, - file_path: str = 'unknown_source', - tuple_delimiter: str = '<|#|>', - completion_delimiter: str = '<|COMPLETE|>', + file_path: str = "unknown_source", + tuple_delimiter: str = "<|#|>", + completion_delimiter: str = "<|COMPLETE|>", ) -> tuple[dict, dict]: 
"""Process a single extraction result (either initial or gleaning) Args: @@ -1146,95 +930,50 @@ async def _process_extraction_result( maybe_edges = defaultdict(list) if completion_delimiter not in result: - logger.warning(f'{chunk_key}: Complete delimiter can not be found in extraction result') + logger.warning( + f"{chunk_key}: Complete delimiter can not be found in extraction result" + ) - # Split LLM output result to records by "\n" or other common separators - # Some models use <|#|> between records instead of newlines + # Split LLL output result to records by "\n" records = split_string_by_multi_markers( result, - ['\n', completion_delimiter, completion_delimiter.lower()], + ["\n", completion_delimiter, completion_delimiter.lower()], ) - # Additional split: handle models that output records separated by tuple_delimiter + "entity" or "relation" - # e.g., "entity<|#|>A<|#|>type<|#|>desc<|#|>entity<|#|>B<|#|>type<|#|>desc" - expanded_records = [] - for record in records: - record = record.strip() - if not record: - continue - # Split by patterns that indicate a new entity/relation record starting - # Pattern: <|#|>entity<|#|> or <|#|>relation<|#|> - sub_records = split_string_by_multi_markers( - record, - [ - f'{tuple_delimiter}entity{tuple_delimiter}', - f'{tuple_delimiter}relation{tuple_delimiter}', - f'{tuple_delimiter}relationship{tuple_delimiter}', - ], - ) - for i, sub in enumerate(sub_records): - sub = sub.strip() - if not sub: - continue - # First sub-record: check if it already starts with entity/relation - if i == 0: - if sub.lower().startswith(('entity', 'relation')): - expanded_records.append(sub) - else: - # Might be partial, try to recover by checking content - # If it looks like entity fields (has enough delimiters), prefix with 'entity' - if sub.count(tuple_delimiter) >= 2: - expanded_records.append(f'entity{tuple_delimiter}{sub}') - else: - expanded_records.append(sub) - else: - # Subsequent sub-records lost their prefix during split, restore it - # Determine if it's entity or relation based on field count - # entity: name, type, desc (3 fields after split = 2 delimiters) - # relation: source, target, keywords, desc (4 fields = 3 delimiters) - delimiter_count = sub.count(tuple_delimiter) - if delimiter_count >= 3: - expanded_records.append(f'relation{tuple_delimiter}{sub}') - else: - expanded_records.append(f'entity{tuple_delimiter}{sub}') - - records = expanded_records if expanded_records else records - # Fix LLM output format error which use tuple_delimiter to seperate record instead of "\n" fixed_records = [] for record in records: record = record.strip() - if not record: + if record is None: continue - # If record already starts with entity/relation, keep it as-is - if record.lower().startswith(('entity', 'relation')): - fixed_records.append(record) - continue - # Otherwise try to recover malformed records - entity_records = split_string_by_multi_markers(record, [f'{tuple_delimiter}entity{tuple_delimiter}']) + entity_records = split_string_by_multi_markers( + record, [f"{tuple_delimiter}entity{tuple_delimiter}"] + ) for entity_record in entity_records: - if not entity_record.startswith('entity') and not entity_record.startswith('relation'): - entity_record = f'entity<|{entity_record}' - entity_relation_records = split_string_by_multi_markers( - # treat "relationship" and "relation" interchangeable - entity_record, - [ - f'{tuple_delimiter}relationship{tuple_delimiter}', - f'{tuple_delimiter}relation{tuple_delimiter}', - ], - ) - for entity_relation_record in 
entity_relation_records: - if not entity_relation_record.startswith('entity') and not entity_relation_record.startswith( - 'relation' - ): - entity_relation_record = f'relation{tuple_delimiter}{entity_relation_record}' - fixed_records.append(entity_relation_record) - else: - fixed_records.append(entity_record) + if not entity_record.startswith("entity") and not entity_record.startswith( + "relation" + ): + entity_record = f"entity<|{entity_record}" + entity_relation_records = split_string_by_multi_markers( + # treat "relationship" and "relation" interchangeable + entity_record, + [ + f"{tuple_delimiter}relationship{tuple_delimiter}", + f"{tuple_delimiter}relation{tuple_delimiter}", + ], + ) + for entity_relation_record in entity_relation_records: + if not entity_relation_record.startswith( + "entity" + ) and not entity_relation_record.startswith("relation"): + entity_relation_record = ( + f"relation{tuple_delimiter}{entity_relation_record}" + ) + fixed_records = fixed_records + [entity_relation_record] if len(fixed_records) != len(records): - logger.debug( - f'{chunk_key}: Recovered {len(fixed_records)} records from {len(records)} raw records' + logger.warning( + f"{chunk_key}: LLM output format error; found LLM using {tuple_delimiter} as record separator instead of new-line" ) for record in fixed_records: @@ -1248,20 +987,24 @@ async def _process_extraction_result( if delimiter_core != delimiter_core.lower(): # change delimiter_core to lower case, and fix again delimiter_core = delimiter_core.lower() - record = fix_tuple_delimiter_corruption(record, delimiter_core, tuple_delimiter) + record = fix_tuple_delimiter_corruption( + record, delimiter_core, tuple_delimiter + ) record_attributes = split_string_by_multi_markers(record, [tuple_delimiter]) # Try to parse as entity - entity_data = await _handle_single_entity_extraction(record_attributes, chunk_key, timestamp, file_path) + entity_data = await _handle_single_entity_extraction( + record_attributes, chunk_key, timestamp, file_path + ) if entity_data is not None: truncated_name = _truncate_entity_identifier( - entity_data['entity_name'], + entity_data["entity_name"], DEFAULT_ENTITY_NAME_MAX_LENGTH, chunk_key, - 'Entity name', + "Entity name", ) - entity_data['entity_name'] = truncated_name + entity_data["entity_name"] = truncated_name maybe_nodes[truncated_name].append(entity_data) continue @@ -1271,19 +1014,19 @@ async def _process_extraction_result( ) if relationship_data is not None: truncated_source = _truncate_entity_identifier( - relationship_data['src_id'], + relationship_data["src_id"], DEFAULT_ENTITY_NAME_MAX_LENGTH, chunk_key, - 'Relation entity', + "Relation entity", ) truncated_target = _truncate_entity_identifier( - relationship_data['tgt_id'], + relationship_data["tgt_id"], DEFAULT_ENTITY_NAME_MAX_LENGTH, chunk_key, - 'Relation entity', + "Relation entity", ) - relationship_data['src_id'] = truncated_source - relationship_data['tgt_id'] = truncated_target + relationship_data["src_id"] = truncated_source + relationship_data["tgt_id"] = truncated_target maybe_edges[(truncated_source, truncated_target)].append(relationship_data) return dict(maybe_nodes), dict(maybe_edges) @@ -1308,7 +1051,11 @@ async def _rebuild_from_extraction_result( # Get chunk data for file_path from storage chunk_data = await text_chunks_storage.get_by_id(chunk_id) - file_path = chunk_data.get('file_path', 'unknown_source') if chunk_data else 'unknown_source' + file_path = ( + chunk_data.get("file_path", "unknown_source") + if chunk_data + else "unknown_source" 
+ ) # Call the shared processing function return await _process_extraction_result( @@ -1316,8 +1063,8 @@ async def _rebuild_from_extraction_result( chunk_id, timestamp, file_path, - tuple_delimiter=PROMPTS['DEFAULT_TUPLE_DELIMITER'], - completion_delimiter=PROMPTS['DEFAULT_COMPLETION_DELIMITER'], + tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"], + completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"], ) @@ -1328,7 +1075,7 @@ async def _rebuild_single_entity( chunk_ids: list[str], chunk_entities: dict, llm_response_cache: BaseKVStorage, - global_config: dict[str, Any], + global_config: dict[str, str], entity_chunks_storage: BaseKVStorage | None = None, pipeline_status: dict | None = None, pipeline_status_lock=None, @@ -1346,49 +1093,49 @@ async def _rebuild_single_entity( entity_type: str, file_paths: list[str], source_chunk_ids: list[str], - truncation_info: str = '', + truncation_info: str = "", ): try: # Update entity in graph storage (critical path) updated_entity_data = { **current_entity, - 'description': final_description, - 'entity_type': entity_type, - 'source_id': GRAPH_FIELD_SEP.join(source_chunk_ids), - 'file_path': GRAPH_FIELD_SEP.join(file_paths) + "description": final_description, + "entity_type": entity_type, + "source_id": GRAPH_FIELD_SEP.join(source_chunk_ids), + "file_path": GRAPH_FIELD_SEP.join(file_paths) if file_paths - else current_entity.get('file_path', 'unknown_source'), - 'created_at': int(time.time()), - 'truncate': truncation_info, + else current_entity.get("file_path", "unknown_source"), + "created_at": int(time.time()), + "truncate": truncation_info, } await knowledge_graph_inst.upsert_node(entity_name, updated_entity_data) # Update entity in vector database (equally critical) - entity_vdb_id = compute_mdhash_id(entity_name, prefix='ent-') - entity_content = f'{entity_name}\n{final_description}' + entity_vdb_id = compute_mdhash_id(entity_name, prefix="ent-") + entity_content = f"{entity_name}\n{final_description}" vdb_data = { entity_vdb_id: { - 'content': entity_content, - 'entity_name': entity_name, - 'source_id': updated_entity_data['source_id'], - 'description': final_description, - 'entity_type': entity_type, - 'file_path': updated_entity_data['file_path'], + "content": entity_content, + "entity_name": entity_name, + "source_id": updated_entity_data["source_id"], + "description": final_description, + "entity_type": entity_type, + "file_path": updated_entity_data["file_path"], } } # Use safe operation wrapper - VDB failure must throw exception await safe_vdb_operation_with_exception( operation=lambda: entities_vdb.upsert(vdb_data), - operation_name='rebuild_entity_upsert', + operation_name="rebuild_entity_upsert", entity_name=entity_name, max_retries=3, retry_delay=0.1, ) except Exception as e: - error_msg = f'Failed to update entity storage for `{entity_name}`: {e}' + error_msg = f"Failed to update entity storage for `{entity_name}`: {e}" logger.error(error_msg) raise # Re-raise exception @@ -1399,19 +1146,21 @@ async def _rebuild_single_entity( await entity_chunks_storage.upsert( { entity_name: { - 'chunk_ids': normalized_chunk_ids, - 'count': len(normalized_chunk_ids), + "chunk_ids": normalized_chunk_ids, + "count": len(normalized_chunk_ids), } } ) - limit_method = global_config.get('source_ids_limit_method') or SOURCE_IDS_LIMIT_METHOD_KEEP + limit_method = ( + global_config.get("source_ids_limit_method") or SOURCE_IDS_LIMIT_METHOD_KEEP + ) limited_chunk_ids = apply_source_ids_limit( normalized_chunk_ids, - 
global_config['max_source_ids_per_entity'], + global_config["max_source_ids_per_entity"], limit_method, - identifier=f'`{entity_name}`', + identifier=f"`{entity_name}`", ) # Collect all entity data from relevant (limited) chunks @@ -1421,12 +1170,14 @@ async def _rebuild_single_entity( all_entity_data.extend(chunk_entities[chunk_id][entity_name]) if not all_entity_data: - logger.warning(f'No entity data found for `{entity_name}`, trying to rebuild from relationships') + logger.warning( + f"No entity data found for `{entity_name}`, trying to rebuild from relationships" + ) # Get all edges connected to this entity edges = await knowledge_graph_inst.get_node_edges(entity_name) if not edges: - logger.warning(f'No relations attached to entity `{entity_name}`') + logger.warning(f"No relations attached to entity `{entity_name}`") return # Collect relationship data to extract entity information @@ -1437,11 +1188,11 @@ async def _rebuild_single_entity( for src_id, tgt_id in edges: edge_data = await knowledge_graph_inst.get_edge(src_id, tgt_id) if edge_data: - if edge_data.get('description'): - relationship_descriptions.append(edge_data['description']) + if edge_data.get("description"): + relationship_descriptions.append(edge_data["description"]) - if edge_data.get('file_path'): - edge_file_paths = edge_data['file_path'].split(GRAPH_FIELD_SEP) + if edge_data.get("file_path"): + edge_file_paths = edge_data["file_path"].split(GRAPH_FIELD_SEP) file_paths.update(edge_file_paths) # deduplicate descriptions @@ -1450,7 +1201,7 @@ async def _rebuild_single_entity( # Generate final description from relationships or fallback to current if description_list: final_description, _ = await _handle_entity_relation_summary( - 'Entity', + "Entity", entity_name, description_list, GRAPH_FIELD_SEP, @@ -1458,13 +1209,13 @@ async def _rebuild_single_entity( llm_response_cache=llm_response_cache, ) else: - final_description = current_entity.get('description', '') + final_description = current_entity.get("description", "") - entity_type = current_entity.get('entity_type', 'UNKNOWN') + entity_type = current_entity.get("entity_type", "UNKNOWN") await _update_entity_storage( final_description, entity_type, - list(file_paths), + file_paths, limited_chunk_ids, ) return @@ -1476,20 +1227,22 @@ async def _rebuild_single_entity( seen_paths = set() for entity_data in all_entity_data: - if entity_data.get('description'): - descriptions.append(entity_data['description']) - if entity_data.get('entity_type'): - entity_types.append(entity_data['entity_type']) - if entity_data.get('file_path'): - file_path = entity_data['file_path'] + if entity_data.get("description"): + descriptions.append(entity_data["description"]) + if entity_data.get("entity_type"): + entity_types.append(entity_data["entity_type"]) + if entity_data.get("file_path"): + file_path = entity_data["file_path"] if file_path and file_path not in seen_paths: file_paths_list.append(file_path) seen_paths.add(file_path) # Apply MAX_FILE_PATHS limit - max_file_paths = int(global_config.get('max_file_paths', DEFAULT_MAX_FILE_PATHS) or DEFAULT_MAX_FILE_PATHS) - file_path_placeholder = global_config.get('file_path_more_placeholder', DEFAULT_FILE_PATH_MORE_PLACEHOLDER) - limit_method = global_config.get('source_ids_limit_method', DEFAULT_SOURCE_IDS_LIMIT_METHOD) + max_file_paths = global_config.get("max_file_paths") + file_path_placeholder = global_config.get( + "file_path_more_placeholder", DEFAULT_FILE_PATH_MORE_PLACEHOLDER + ) + limit_method = 
global_config.get("source_ids_limit_method") original_count = len(file_paths_list) if original_count > max_file_paths: @@ -1500,8 +1253,12 @@ async def _rebuild_single_entity( # KEEP: keep head (earliest), discard tail file_paths_list = file_paths_list[:max_file_paths] - file_paths_list.append(f'...{file_path_placeholder}...({limit_method} {max_file_paths}/{original_count})') - logger.info(f'Limited `{entity_name}`: file_path {original_count} -> {max_file_paths} ({limit_method})') + file_paths_list.append( + f"...{file_path_placeholder}...({limit_method} {max_file_paths}/{original_count})" + ) + logger.info( + f"Limited `{entity_name}`: file_path {original_count} -> {max_file_paths} ({limit_method})" + ) # Remove duplicates while preserving order description_list = list(dict.fromkeys(descriptions)) @@ -1509,13 +1266,15 @@ async def _rebuild_single_entity( # Get most common entity type entity_type = ( - max(set(entity_types), key=entity_types.count) if entity_types else current_entity.get('entity_type', 'UNKNOWN') + max(set(entity_types), key=entity_types.count) + if entity_types + else current_entity.get("entity_type", "UNKNOWN") ) # Generate final description from entities or fallback to current if description_list: final_description, _ = await _handle_entity_relation_summary( - 'Entity', + "Entity", entity_name, description_list, GRAPH_FIELD_SEP, @@ -1523,12 +1282,14 @@ async def _rebuild_single_entity( llm_response_cache=llm_response_cache, ) else: - final_description = current_entity.get('description', '') + final_description = current_entity.get("description", "") if len(limited_chunk_ids) < len(normalized_chunk_ids): - truncation_info = f'{limit_method} {len(limited_chunk_ids)}/{len(normalized_chunk_ids)}' + truncation_info = ( + f"{limit_method} {len(limited_chunk_ids)}/{len(normalized_chunk_ids)}" + ) else: - truncation_info = '' + truncation_info = "" await _update_entity_storage( final_description, @@ -1539,15 +1300,15 @@ async def _rebuild_single_entity( ) # Log rebuild completion with truncation info - status_message = f'Rebuild `{entity_name}` from {len(chunk_ids)} chunks' + status_message = f"Rebuild `{entity_name}` from {len(chunk_ids)} chunks" if truncation_info: - status_message += f' ({truncation_info})' + status_message += f" ({truncation_info})" logger.info(status_message) # Update pipeline status if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: - pipeline_status['latest_message'] = status_message - pipeline_status['history_messages'].append(status_message) + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) async def _rebuild_single_relationship( @@ -1559,7 +1320,7 @@ async def _rebuild_single_relationship( chunk_ids: list[str], chunk_relationships: dict, llm_response_cache: BaseKVStorage, - global_config: dict[str, Any], + global_config: dict[str, str], relation_chunks_storage: BaseKVStorage | None = None, entity_chunks_storage: BaseKVStorage | None = None, pipeline_status: dict | None = None, @@ -1584,18 +1345,20 @@ async def _rebuild_single_relationship( await relation_chunks_storage.upsert( { storage_key: { - 'chunk_ids': normalized_chunk_ids, - 'count': len(normalized_chunk_ids), + "chunk_ids": normalized_chunk_ids, + "count": len(normalized_chunk_ids), } } ) - limit_method = global_config.get('source_ids_limit_method') or SOURCE_IDS_LIMIT_METHOD_KEEP + limit_method = ( + global_config.get("source_ids_limit_method") or SOURCE_IDS_LIMIT_METHOD_KEEP + ) 
limited_chunk_ids = apply_source_ids_limit( normalized_chunk_ids, - global_config['max_source_ids_per_relation'], + global_config["max_source_ids_per_relation"], limit_method, - identifier=f'`{src}`~`{tgt}`', + identifier=f"`{src}`~`{tgt}`", ) # Collect all relationship data from relevant chunks @@ -1605,10 +1368,12 @@ async def _rebuild_single_relationship( # Check both (src, tgt) and (tgt, src) since relationships can be bidirectional for edge_key in [(src, tgt), (tgt, src)]: if edge_key in chunk_relationships[chunk_id]: - all_relationship_data.extend(chunk_relationships[chunk_id][edge_key]) + all_relationship_data.extend( + chunk_relationships[chunk_id][edge_key] + ) if not all_relationship_data: - logger.warning(f'No relation data found for `{src}-{tgt}`') + logger.warning(f"No relation data found for `{src}-{tgt}`") return # Merge descriptions and keywords @@ -1619,22 +1384,24 @@ async def _rebuild_single_relationship( seen_paths = set() for rel_data in all_relationship_data: - if rel_data.get('description'): - descriptions.append(rel_data['description']) - if rel_data.get('keywords'): - keywords.append(rel_data['keywords']) - if rel_data.get('weight'): - weights.append(rel_data['weight']) - if rel_data.get('file_path'): - file_path = rel_data['file_path'] + if rel_data.get("description"): + descriptions.append(rel_data["description"]) + if rel_data.get("keywords"): + keywords.append(rel_data["keywords"]) + if rel_data.get("weight"): + weights.append(rel_data["weight"]) + if rel_data.get("file_path"): + file_path = rel_data["file_path"] if file_path and file_path not in seen_paths: file_paths_list.append(file_path) seen_paths.add(file_path) # Apply count limit - max_file_paths = int(global_config.get('max_file_paths', DEFAULT_MAX_FILE_PATHS) or DEFAULT_MAX_FILE_PATHS) - file_path_placeholder = global_config.get('file_path_more_placeholder', DEFAULT_FILE_PATH_MORE_PLACEHOLDER) - limit_method = global_config.get('source_ids_limit_method', DEFAULT_SOURCE_IDS_LIMIT_METHOD) + max_file_paths = global_config.get("max_file_paths") + file_path_placeholder = global_config.get( + "file_path_more_placeholder", DEFAULT_FILE_PATH_MORE_PLACEHOLDER + ) + limit_method = global_config.get("source_ids_limit_method") original_count = len(file_paths_list) if original_count > max_file_paths: @@ -1645,22 +1412,30 @@ async def _rebuild_single_relationship( # KEEP: keep head (earliest), discard tail file_paths_list = file_paths_list[:max_file_paths] - file_paths_list.append(f'...{file_path_placeholder}...({limit_method} {max_file_paths}/{original_count})') - logger.info(f'Limited `{src}`~`{tgt}`: file_path {original_count} -> {max_file_paths} ({limit_method})') + file_paths_list.append( + f"...{file_path_placeholder}...({limit_method} {max_file_paths}/{original_count})" + ) + logger.info( + f"Limited `{src}`~`{tgt}`: file_path {original_count} -> {max_file_paths} ({limit_method})" + ) # Remove duplicates while preserving order description_list = list(dict.fromkeys(descriptions)) keywords = list(dict.fromkeys(keywords)) - combined_keywords = ', '.join(set(keywords)) if keywords else current_relationship.get('keywords', '') + combined_keywords = ( + ", ".join(set(keywords)) + if keywords + else current_relationship.get("keywords", "") + ) - weight = sum(weights) if weights else current_relationship.get('weight', 1.0) + weight = sum(weights) if weights else current_relationship.get("weight", 1.0) # Generate final description from relations or fallback to current if description_list: final_description, _ = await 
_handle_entity_relation_summary( - 'Relation', - f'{src}-{tgt}', + "Relation", + f"{src}-{tgt}", description_list, GRAPH_FIELD_SEP, global_config, @@ -1668,47 +1443,51 @@ async def _rebuild_single_relationship( ) else: # fallback to keep current(unchanged) - final_description = current_relationship.get('description', '') + final_description = current_relationship.get("description", "") if len(limited_chunk_ids) < len(normalized_chunk_ids): - truncation_info = f'{limit_method} {len(limited_chunk_ids)}/{len(normalized_chunk_ids)}' + truncation_info = ( + f"{limit_method} {len(limited_chunk_ids)}/{len(normalized_chunk_ids)}" + ) else: - truncation_info = '' + truncation_info = "" # Update relationship in graph storage updated_relationship_data = { **current_relationship, - 'description': final_description if final_description else current_relationship.get('description', ''), - 'keywords': combined_keywords, - 'weight': weight, - 'source_id': GRAPH_FIELD_SEP.join(limited_chunk_ids), - 'file_path': GRAPH_FIELD_SEP.join([fp for fp in file_paths_list if fp]) + "description": final_description + if final_description + else current_relationship.get("description", ""), + "keywords": combined_keywords, + "weight": weight, + "source_id": GRAPH_FIELD_SEP.join(limited_chunk_ids), + "file_path": GRAPH_FIELD_SEP.join([fp for fp in file_paths_list if fp]) if file_paths_list - else current_relationship.get('file_path', 'unknown_source'), - 'truncate': truncation_info, + else current_relationship.get("file_path", "unknown_source"), + "truncate": truncation_info, } # Ensure both endpoint nodes exist before writing the edge back # (certain storage backends require pre-existing nodes). node_description = ( - updated_relationship_data['description'] - if updated_relationship_data.get('description') - else current_relationship.get('description', '') + updated_relationship_data["description"] + if updated_relationship_data.get("description") + else current_relationship.get("description", "") ) - node_source_id = updated_relationship_data.get('source_id', '') - node_file_path = updated_relationship_data.get('file_path', 'unknown_source') + node_source_id = updated_relationship_data.get("source_id", "") + node_file_path = updated_relationship_data.get("file_path", "unknown_source") for node_id in {src, tgt}: if not (await knowledge_graph_inst.has_node(node_id)): node_created_at = int(time.time()) node_data = { - 'entity_id': node_id, - 'source_id': node_source_id, - 'description': node_description, - 'entity_type': 'UNKNOWN', - 'file_path': node_file_path, - 'created_at': node_created_at, - 'truncate': '', + "entity_id": node_id, + "source_id": node_source_id, + "description": node_description, + "entity_type": "UNKNOWN", + "file_path": node_file_path, + "created_at": node_created_at, + "truncate": "", } await knowledge_graph_inst.upsert_node(node_id, node_data=node_data) @@ -1717,28 +1496,28 @@ async def _rebuild_single_relationship( await entity_chunks_storage.upsert( { node_id: { - 'chunk_ids': limited_chunk_ids, - 'count': len(limited_chunk_ids), + "chunk_ids": limited_chunk_ids, + "count": len(limited_chunk_ids), } } ) # Update entity_vdb for the newly created entity if entities_vdb is not None: - entity_vdb_id = compute_mdhash_id(node_id, prefix='ent-') - entity_content = f'{node_id}\n{node_description}' + entity_vdb_id = compute_mdhash_id(node_id, prefix="ent-") + entity_content = f"{node_id}\n{node_description}" vdb_data = { entity_vdb_id: { - 'content': entity_content, - 'entity_name': node_id, - 
'source_id': node_source_id, - 'entity_type': 'UNKNOWN', - 'file_path': node_file_path, + "content": entity_content, + "entity_name": node_id, + "source_id": node_source_id, + "entity_type": "UNKNOWN", + "file_path": node_file_path, } } await safe_vdb_operation_with_exception( operation=lambda payload=vdb_data: entities_vdb.upsert(payload), - operation_name='rebuild_added_entity_upsert', + operation_name="rebuild_added_entity_upsert", entity_name=node_id, max_retries=3, retry_delay=0.1, @@ -1751,211 +1530,64 @@ async def _rebuild_single_relationship( if src > tgt: src, tgt = tgt, src try: - rel_vdb_id = compute_mdhash_id(src + tgt, prefix='rel-') - rel_vdb_id_reverse = compute_mdhash_id(tgt + src, prefix='rel-') + rel_vdb_id = compute_mdhash_id(src + tgt, prefix="rel-") + rel_vdb_id_reverse = compute_mdhash_id(tgt + src, prefix="rel-") # Delete old vector records first (both directions to be safe) try: await relationships_vdb.delete([rel_vdb_id, rel_vdb_id_reverse]) except Exception as e: - logger.debug(f'Could not delete old relationship vector records {rel_vdb_id}, {rel_vdb_id_reverse}: {e}') + logger.debug( + f"Could not delete old relationship vector records {rel_vdb_id}, {rel_vdb_id_reverse}: {e}" + ) # Insert new vector record - rel_content = f'{combined_keywords}\t{src}\n{tgt}\n{final_description}' + rel_content = f"{combined_keywords}\t{src}\n{tgt}\n{final_description}" vdb_data = { rel_vdb_id: { - 'src_id': src, - 'tgt_id': tgt, - 'source_id': updated_relationship_data['source_id'], - 'content': rel_content, - 'keywords': combined_keywords, - 'description': final_description, - 'weight': weight, - 'file_path': updated_relationship_data['file_path'], + "src_id": src, + "tgt_id": tgt, + "source_id": updated_relationship_data["source_id"], + "content": rel_content, + "keywords": combined_keywords, + "description": final_description, + "weight": weight, + "file_path": updated_relationship_data["file_path"], } } # Use safe operation wrapper - VDB failure must throw exception await safe_vdb_operation_with_exception( operation=lambda: relationships_vdb.upsert(vdb_data), - operation_name='rebuild_relationship_upsert', - entity_name=f'{src}-{tgt}', + operation_name="rebuild_relationship_upsert", + entity_name=f"{src}-{tgt}", max_retries=3, retry_delay=0.2, ) except Exception as e: - error_msg = f'Failed to rebuild relationship storage for `{src}-{tgt}`: {e}' + error_msg = f"Failed to rebuild relationship storage for `{src}-{tgt}`: {e}" logger.error(error_msg) raise # Re-raise exception # Log rebuild completion with truncation info - status_message = f'Rebuild `{src}`~`{tgt}` from {len(chunk_ids)} chunks' + status_message = f"Rebuild `{src}`~`{tgt}` from {len(chunk_ids)} chunks" if truncation_info: - status_message += f' ({truncation_info})' - elif len(limited_chunk_ids) < len(normalized_chunk_ids): - status_message += f' ({limit_method}:{len(limited_chunk_ids)}/{len(normalized_chunk_ids)})' + status_message += f" ({truncation_info})" + # Add truncation info from apply_source_ids_limit if truncation occurred + if len(limited_chunk_ids) < len(normalized_chunk_ids): + truncation_info = ( + f" ({limit_method}:{len(limited_chunk_ids)}/{len(normalized_chunk_ids)})" + ) + status_message += truncation_info logger.info(status_message) # Update pipeline status if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: - pipeline_status['latest_message'] = status_message - pipeline_status['history_messages'].append(status_message) - - -def 
_has_different_numeric_suffix(name_a: str, name_b: str) -> bool: - """Check if two names have different numeric components. - - This prevents false fuzzy matches between entities that differ only by number, - such as "Interleukin-4" vs "Interleukin-13" (88.9% similar but semantically distinct). - - Scientific/medical entities often use numbers as key identifiers: - - Interleukins: IL-4, IL-13, IL-17 - - Drug phases: Phase 1, Phase 2, Phase 3 - - Receptor types: Type 1, Type 2 - - Versions: v1.0, v2.0 - - Args: - name_a: First entity name - name_b: Second entity name - - Returns: - True if both names contain numbers but the numbers differ, False otherwise. - """ - # Extract all numeric patterns (integers and decimals) - pattern = r'(\d+(?:\.\d+)?)' - nums_a = re.findall(pattern, name_a) - nums_b = re.findall(pattern, name_b) - - # If both have numbers and they differ, these are likely distinct entities - return bool(nums_a and nums_b and nums_a != nums_b) - - -async def _build_pre_resolution_map( - entity_names: list[str], - entity_types: dict[str, str], - entity_vdb, - llm_fn, - config: EntityResolutionConfig, -) -> tuple[dict[str, str], dict[str, float]]: - """Build resolution map before parallel processing to prevent race conditions. - - This function resolves entities against each other within the batch (using - instant fuzzy matching) and against existing VDB entries. The resulting map - is applied during parallel entity processing. - - Args: - entity_names: List of entity names to resolve - entity_types: Dict mapping entity names to their types (e.g., "person", "organization"). - Used to prevent fuzzy matching between entities of different types. - entity_vdb: Entity vector database for checking existing entities - llm_fn: LLM function for semantic verification - config: Entity resolution configuration - - Returns: - Tuple of: - - resolution_map: Dict mapping original entity names to their resolved canonical names. - Only entities that need remapping are included. - - confidence_map: Dict mapping alias to confidence score (1.0 for exact, actual - similarity for fuzzy, result.confidence for VDB matches). - """ - resolution_map: dict[str, str] = {} - confidence_map: dict[str, float] = {} - # Track canonical entities with their types: [(name, type), ...] - canonical_entities: list[tuple[str, str]] = [] - - for entity_name in entity_names: - normalized = entity_name.lower().strip() - entity_type = entity_types.get(entity_name, '') - - # Skip if already resolved to something in this batch - if entity_name in resolution_map: - continue - - # Layer 1: Case-insensitive exact match within batch - matched = False - for canonical, _canonical_type in canonical_entities: - if canonical.lower().strip() == normalized: - resolution_map[entity_name] = canonical - confidence_map[entity_name] = 1.0 # Exact match = perfect confidence - logger.debug(f"Pre-resolution (case match): '{entity_name}' → '{canonical}'") - matched = True - break - - if matched: - continue - - # Layer 2: Fuzzy match within batch (catches typos like Dupixant→Dupixent) - # Only enabled when config.fuzzy_pre_resolution_enabled is True. - # Requires: similarity >= threshold AND matching types (or unknown). - if config.fuzzy_pre_resolution_enabled: - for canonical, canonical_type in canonical_entities: - similarity = fuzzy_similarity(entity_name, canonical) - if similarity >= config.fuzzy_threshold: - # Type compatibility check: skip if types differ and both known. 
- # Empty/unknown types are treated as compatible to avoid - # blocking legitimate matches when type info is incomplete. - types_compatible = not entity_type or not canonical_type or entity_type == canonical_type - if not types_compatible: - logger.debug( - f'Pre-resolution (fuzzy {similarity:.2f}): SKIPPED ' - f"'{entity_name}' ({entity_type}) → " - f"'{canonical}' ({canonical_type}) - type mismatch" - ) - continue - - # Numeric suffix check: skip if names have different numbers - # This prevents false matches like "Interleukin-4" → "Interleukin-13" - # where fuzzy similarity is high (88.9%) but entities are distinct - if _has_different_numeric_suffix(entity_name, canonical): - logger.debug( - f'Pre-resolution (fuzzy {similarity:.2f}): SKIPPED ' - f"'{entity_name}' → '{canonical}' - different numeric suffix" - ) - continue - - # Accept the fuzzy match - emit warning for review - resolution_map[entity_name] = canonical - confidence_map[entity_name] = similarity # Use actual similarity score - etype_display = entity_type or 'unknown' - ctype_display = canonical_type or 'unknown' - logger.warning( - f"Fuzzy pre-resolution accepted: '{entity_name}' → " - f"'{canonical}' (similarity={similarity:.3f}, " - f'types: {etype_display}→{ctype_display}). ' - f'Review for correctness; adjust fuzzy_threshold or ' - f'disable fuzzy_pre_resolution_enabled if needed.' - ) - matched = True - break - - if matched: - continue - - # Layer 3: Check existing VDB for cross-document deduplication - if entity_vdb and llm_fn: - try: - result = await resolve_entity_with_vdb(entity_name, entity_vdb, llm_fn, config) - if result.action == 'match' and result.matched_entity: - resolution_map[entity_name] = result.matched_entity - confidence_map[entity_name] = result.confidence # Use VDB result confidence - # Add canonical from VDB so batch entities can match it. - # VDB matches don't have type info available, use empty. - canonical_entities.append((result.matched_entity, '')) - logger.debug(f"Pre-resolution (VDB {result.method}): '{entity_name}' → '{result.matched_entity}'") - continue - except Exception as e: - logger.debug(f"Pre-resolution VDB check failed for '{entity_name}': {e}") - - # No match found - this is a new canonical entity - canonical_entities.append((entity_name, entity_type)) - - if resolution_map: - logger.info(f'Pre-resolution: {len(resolution_map)} entities mapped to canonical forms') - - return resolution_map, confidence_map + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) async def _merge_nodes_then_upsert( @@ -1964,152 +1596,39 @@ async def _merge_nodes_then_upsert( knowledge_graph_inst: BaseGraphStorage, entity_vdb: BaseVectorStorage | None, global_config: dict, - pipeline_status: dict | None = None, + pipeline_status: dict = None, pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, entity_chunks_storage: BaseKVStorage | None = None, - pre_resolution_map: dict[str, str] | None = None, - prefetched_nodes: dict[str, dict] | None = None, -) -> tuple[dict, str | None]: - """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert. - - Args: - prefetched_nodes: Optional dict mapping entity names to their existing node data. - If provided, avoids individual get_node() calls for better performance. - - Returns: - Tuple of (node_data, original_entity_name). original_entity_name is set if - entity resolution changed the name (e.g., "Dupixant" → "Dupixent"), - otherwise None. 
- """ - original_entity_name = entity_name # Track original before resolution - - # Apply pre-resolution map immediately (prevents race conditions in parallel processing) - pre_resolved = False - if pre_resolution_map and entity_name in pre_resolution_map: - entity_name = pre_resolution_map[entity_name] - pre_resolved = True - logger.debug(f"Applied pre-resolution: '{original_entity_name}' → '{entity_name}'") - - # Entity Resolution: Resolve new entity against existing entities - # Skip if already pre-resolved (to avoid redundant VDB queries) - entity_resolution_config_raw = global_config.get('entity_resolution_config') - entity_resolution_config = None - if entity_resolution_config_raw: - # Handle both dict (from asdict() serialization) and EntityResolutionConfig instances - if isinstance(entity_resolution_config_raw, EntityResolutionConfig): - entity_resolution_config = entity_resolution_config_raw - elif isinstance(entity_resolution_config_raw, dict): - try: - entity_resolution_config = EntityResolutionConfig(**entity_resolution_config_raw) - except TypeError as e: - logger.warning( - f'Invalid entity_resolution_config: {e}. ' - f'Config: {entity_resolution_config_raw}. Skipping resolution.' - ) - - # Safely check if entity resolution is enabled, handling both object and dict forms - def _is_resolution_enabled(config) -> bool: - if config is None: - return False - if isinstance(config, dict): - return config.get('enabled', False) - return getattr(config, 'enabled', False) - - # Skip VDB resolution if entity was already pre-resolved (prevents redundant queries) - if _is_resolution_enabled(entity_resolution_config) and entity_vdb and not pre_resolved: - resolution_config = cast(EntityResolutionConfig, entity_resolution_config) - original_name = entity_name - workspace = global_config.get('workspace', '') - # Try knowledge_graph_inst.db first (more reliable), fallback to entity_vdb.db - db = getattr(knowledge_graph_inst, 'db', None) or getattr(entity_vdb, 'db', None) - - # Layer 0: Check alias cache first (PostgreSQL-only - requires db connection) - # Note: Alias caching is only available when using PostgreSQL storage backend - if db is not None: - try: - cached = await get_cached_alias(original_name, db, workspace) - if cached: - canonical, method, _ = cached - logger.debug(f"Alias cache hit: '{original_name}' → '{canonical}' (method: {method})") - entity_name = canonical - except Exception as e: - logger.warning( - f"Entity resolution cache lookup failed for '{original_name}' " - f'(workspace: {workspace}): {type(e).__name__}: {e}. ' - 'Continuing without cache.' 
- ) - - # Layers 1-3: Full VDB resolution (if not found in cache) - if entity_name == original_name: - llm_fn = global_config.get('llm_model_func') - if llm_fn: - try: - resolution = await resolve_entity_with_vdb( - entity_name, - entity_vdb, - llm_fn, - resolution_config, - ) - if resolution.action == 'match' and resolution.matched_entity: - logger.info( - f"Entity resolution: '{entity_name}' → '{resolution.matched_entity}' " - f'(method: {resolution.method}, confidence: {resolution.confidence:.2f})' - ) - entity_name = resolution.matched_entity - - # Store in alias cache for next time (PostgreSQL-only) - # Note: Alias caching requires PostgreSQL storage backend - if db is not None: - try: - await store_alias( - original_name, - entity_name, - resolution.method, - resolution.confidence, - db, - workspace, - ) - except Exception as e: - logger.warning( - f"Failed to store entity alias '{original_name}' → '{entity_name}' " - f'(workspace: {workspace}): {type(e).__name__}: {e}. ' - 'Resolution succeeded but cache not updated.' - ) - except Exception as e: - logger.warning( - f"Entity resolution failed for '{original_name}' " - f'(workspace: {workspace}): {type(e).__name__}: {e}. ' - 'Continuing with original entity name.' - ) - +): + """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert.""" already_entity_types = [] already_source_ids = [] already_description = [] already_file_paths = [] - # 1. Get existing node data from knowledge graph (use prefetched if available) - if prefetched_nodes is not None and entity_name in prefetched_nodes: - already_node = prefetched_nodes[entity_name] - else: - # Fallback to individual fetch if not prefetched (e.g., after VDB resolution) - already_node = await knowledge_graph_inst.get_node(entity_name) + # 1. Get existing node data from knowledge graph + already_node = await knowledge_graph_inst.get_node(entity_name) if already_node: - already_entity_types.append(already_node['entity_type']) - already_source_ids.extend(already_node['source_id'].split(GRAPH_FIELD_SEP)) - already_file_paths.extend(already_node['file_path'].split(GRAPH_FIELD_SEP)) - already_description.extend(already_node['description'].split(GRAPH_FIELD_SEP)) + already_entity_types.append(already_node["entity_type"]) + already_source_ids.extend(already_node["source_id"].split(GRAPH_FIELD_SEP)) + already_file_paths.extend(already_node["file_path"].split(GRAPH_FIELD_SEP)) + already_description.extend(already_node["description"].split(GRAPH_FIELD_SEP)) - new_source_ids = [dp['source_id'] for dp in nodes_data if dp.get('source_id')] + new_source_ids = [dp["source_id"] for dp in nodes_data if dp.get("source_id")] existing_full_source_ids = [] if entity_chunks_storage is not None: stored_chunks = await entity_chunks_storage.get_by_id(entity_name) if stored_chunks and isinstance(stored_chunks, dict): - existing_full_source_ids = [chunk_id for chunk_id in stored_chunks.get('chunk_ids', []) if chunk_id] + existing_full_source_ids = [ + chunk_id for chunk_id in stored_chunks.get("chunk_ids", []) if chunk_id + ] if not existing_full_source_ids: - existing_full_source_ids = [chunk_id for chunk_id in already_source_ids if chunk_id] + existing_full_source_ids = [ + chunk_id for chunk_id in already_source_ids if chunk_id + ] # 2. 
Merging new source ids with existing ones full_source_ids = merge_source_ids(existing_full_source_ids, new_source_ids) @@ -2118,23 +1637,20 @@ async def _merge_nodes_then_upsert( await entity_chunks_storage.upsert( { entity_name: { - 'chunk_ids': full_source_ids, - 'count': len(full_source_ids), + "chunk_ids": full_source_ids, + "count": len(full_source_ids), } } ) # 3. Finalize source_id by applying source ids limit - limit_method = global_config.get('source_ids_limit_method', DEFAULT_SOURCE_IDS_LIMIT_METHOD) - max_source_limit = int( - global_config.get('max_source_ids_per_entity', DEFAULT_MAX_SOURCE_IDS_PER_ENTITY) - or DEFAULT_MAX_SOURCE_IDS_PER_ENTITY - ) + limit_method = global_config.get("source_ids_limit_method") + max_source_limit = global_config.get("max_source_ids_per_entity") source_ids = apply_source_ids_limit( full_source_ids, max_source_limit, limit_method, - identifier=f'`{entity_name}`', + identifier=f"`{entity_name}`", ) # 4. Only keep nodes not filter by apply_source_ids_limit if limit_method is KEEP @@ -2142,9 +1658,13 @@ async def _merge_nodes_then_upsert( allowed_source_ids = set(source_ids) filtered_nodes = [] for dp in nodes_data: - source_id = dp.get('source_id') + source_id = dp.get("source_id") # Skip descriptions sourced from chunks dropped by the limitation cap - if source_id and source_id not in allowed_source_ids and source_id not in existing_full_source_ids: + if ( + source_id + and source_id not in allowed_source_ids + and source_id not in existing_full_source_ids + ): continue filtered_nodes.append(dp) nodes_data = filtered_nodes @@ -2158,19 +1678,25 @@ async def _merge_nodes_then_upsert( and not nodes_data ): if already_node: - logger.info(f'Skipped `{entity_name}`: KEEP old chunks {already_source_ids}/{len(full_source_ids)}') + logger.info( + f"Skipped `{entity_name}`: KEEP old chunks {already_source_ids}/{len(full_source_ids)}" + ) existing_node_data = dict(already_node) - return existing_node_data, None + return existing_node_data else: - logger.error(f'Internal Error: already_node missing for `{entity_name}`') - raise ValueError(f'Internal Error: already_node missing for `{entity_name}`') + logger.error(f"Internal Error: already_node missing for `{entity_name}`") + raise ValueError( + f"Internal Error: already_node missing for `{entity_name}`" + ) # 6.1 Finalize source_id source_id = GRAPH_FIELD_SEP.join(source_ids) # 6.2 Finalize entity type by highest count entity_type = sorted( - Counter([dp['entity_type'] for dp in nodes_data] + already_entity_types).items(), + Counter( + [dp["entity_type"] for dp in nodes_data] + already_entity_types + ).items(), key=lambda x: x[1], reverse=True, )[0][0] @@ -2178,7 +1704,7 @@ async def _merge_nodes_then_upsert( # 7. 
Deduplicate nodes by description, keeping first occurrence in the same document unique_nodes = {} for dp in nodes_data: - desc = dp.get('description') + desc = dp.get("description") if not desc: continue if desc not in unique_nodes: @@ -2187,25 +1713,25 @@ async def _merge_nodes_then_upsert( # Sort description by timestamp, then by description length when timestamps are the same sorted_nodes = sorted( unique_nodes.values(), - key=lambda x: (x.get('timestamp', 0), -len(x.get('description', ''))), + key=lambda x: (x.get("timestamp", 0), -len(x.get("description", ""))), ) - sorted_descriptions = [dp['description'] for dp in sorted_nodes] + sorted_descriptions = [dp["description"] for dp in sorted_nodes] # Combine already_description with sorted new sorted descriptions description_list = already_description + sorted_descriptions if not description_list: - logger.error(f'Entity {entity_name} has no description') - raise ValueError(f'Entity {entity_name} has no description') + logger.error(f"Entity {entity_name} has no description") + raise ValueError(f"Entity {entity_name} has no description") # Check for cancellation before LLM summary if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: - if pipeline_status.get('cancellation_requested', False): - raise PipelineCancelledException('User cancelled during entity summary') + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException("User cancelled during entity summary") # 8. Get summary description an LLM usage status description, llm_was_used = await _handle_entity_relation_summary( - 'Entity', + "Entity", entity_name, description_list, GRAPH_FIELD_SEP, @@ -2218,12 +1744,14 @@ async def _merge_nodes_then_upsert( seen_paths = set() has_placeholder = False # Indicating file_path has been truncated before - max_file_paths = global_config.get('max_file_paths', DEFAULT_MAX_FILE_PATHS) - file_path_placeholder = global_config.get('file_path_more_placeholder', DEFAULT_FILE_PATH_MORE_PLACEHOLDER) + max_file_paths = global_config.get("max_file_paths", DEFAULT_MAX_FILE_PATHS) + file_path_placeholder = global_config.get( + "file_path_more_placeholder", DEFAULT_FILE_PATH_MORE_PLACEHOLDER + ) # Collect from already_file_paths, excluding placeholder for fp in already_file_paths: - if fp and fp.startswith(f'...{file_path_placeholder}'): # Skip placeholders + if fp and fp.startswith(f"...{file_path_placeholder}"): # Skip placeholders has_placeholder = True continue if fp and fp not in seen_paths: @@ -2232,28 +1760,36 @@ async def _merge_nodes_then_upsert( # Collect from new data for dp in nodes_data: - file_path_item = dp.get('file_path') + file_path_item = dp.get("file_path") if file_path_item and file_path_item not in seen_paths: file_paths_list.append(file_path_item) seen_paths.add(file_path_item) # Apply count limit if len(file_paths_list) > max_file_paths: - limit_method = global_config.get('source_ids_limit_method', SOURCE_IDS_LIMIT_METHOD_KEEP) - file_path_placeholder = global_config.get('file_path_more_placeholder', DEFAULT_FILE_PATH_MORE_PLACEHOLDER) + limit_method = global_config.get( + "source_ids_limit_method", SOURCE_IDS_LIMIT_METHOD_KEEP + ) + file_path_placeholder = global_config.get( + "file_path_more_placeholder", DEFAULT_FILE_PATH_MORE_PLACEHOLDER + ) # Add + sign to indicate actual file count is higher - original_count_str = f'{len(file_paths_list)}+' if has_placeholder else str(len(file_paths_list)) + original_count_str = ( + f"{len(file_paths_list)}+" if 
has_placeholder else str(len(file_paths_list)) + ) if limit_method == SOURCE_IDS_LIMIT_METHOD_FIFO: # FIFO: keep tail (newest), discard head file_paths_list = file_paths_list[-max_file_paths:] - file_paths_list.append(f'...{file_path_placeholder}...(FIFO)') + file_paths_list.append(f"...{file_path_placeholder}...(FIFO)") else: # KEEP: keep head (earliest), discard tail file_paths_list = file_paths_list[:max_file_paths] - file_paths_list.append(f'...{file_path_placeholder}...(KEEP Old)') + file_paths_list.append(f"...{file_path_placeholder}...(KEEP Old)") - logger.info(f'Limited `{entity_name}`: file_path {original_count_str} -> {max_file_paths} ({limit_method})') + logger.info( + f"Limited `{entity_name}`: file_path {original_count_str} -> {max_file_paths} ({limit_method})" + ) # Finalize file_path file_path = GRAPH_FIELD_SEP.join(file_paths_list) @@ -2261,73 +1797,75 @@ async def _merge_nodes_then_upsert( num_fragment = len(description_list) already_fragment = len(already_description) if llm_was_used: - status_message = f'LLMmrg: `{entity_name}` | {already_fragment}+{num_fragment - already_fragment}' + status_message = f"LLMmrg: `{entity_name}` | {already_fragment}+{num_fragment - already_fragment}" else: - status_message = f'Merged: `{entity_name}` | {already_fragment}+{num_fragment - already_fragment}' + status_message = f"Merged: `{entity_name}` | {already_fragment}+{num_fragment - already_fragment}" - truncation_info = truncation_info_log = '' + truncation_info = truncation_info_log = "" if len(source_ids) < len(full_source_ids): # Add truncation info from apply_source_ids_limit if truncation occurred - truncation_info_log = f'{limit_method} {len(source_ids)}/{len(full_source_ids)}' - truncation_info = truncation_info_log if limit_method == SOURCE_IDS_LIMIT_METHOD_FIFO else 'KEEP Old' + truncation_info_log = f"{limit_method} {len(source_ids)}/{len(full_source_ids)}" + if limit_method == SOURCE_IDS_LIMIT_METHOD_FIFO: + truncation_info = truncation_info_log + else: + truncation_info = "KEEP Old" deduplicated_num = already_fragment + len(nodes_data) - num_fragment - dd_message = '' + dd_message = "" if deduplicated_num > 0: - # Duplicated description detected across multiple chunks for the same entity - dd_message = f'dd {deduplicated_num}' + # Duplicated description detected across multiple trucks for the same entity + dd_message = f"dd {deduplicated_num}" if dd_message or truncation_info_log: - status_message += f' ({", ".join(filter(None, [truncation_info_log, dd_message]))})' + status_message += ( + f" ({', '.join(filter(None, [truncation_info_log, dd_message]))})" + ) # Add message to pipeline satus when merge happens if already_fragment > 0 or llm_was_used: logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: - pipeline_status['latest_message'] = status_message - pipeline_status['history_messages'].append(status_message) + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) else: logger.debug(status_message) # 11. 
Update both graph and vector db - node_data = { - 'entity_id': entity_name, - 'entity_type': entity_type, - 'description': description, - 'source_id': source_id, - 'file_path': file_path, - 'created_at': int(time.time()), - 'truncate': truncation_info, - } + node_data = dict( + entity_id=entity_name, + entity_type=entity_type, + description=description, + source_id=source_id, + file_path=file_path, + created_at=int(time.time()), + truncate=truncation_info, + ) await knowledge_graph_inst.upsert_node( entity_name, node_data=node_data, ) - node_data['entity_name'] = entity_name + node_data["entity_name"] = entity_name if entity_vdb is not None: - entity_vdb_id = compute_mdhash_id(str(entity_name), prefix='ent-') - entity_content = f'{entity_name}\n{description}' + entity_vdb_id = compute_mdhash_id(str(entity_name), prefix="ent-") + entity_content = f"{entity_name}\n{description}" data_for_vdb = { entity_vdb_id: { - 'entity_name': entity_name, - 'entity_type': entity_type, - 'content': entity_content, - 'source_id': source_id, - 'file_path': file_path, + "entity_name": entity_name, + "entity_type": entity_type, + "content": entity_content, + "source_id": source_id, + "file_path": file_path, } } await safe_vdb_operation_with_exception( operation=lambda payload=data_for_vdb: entity_vdb.upsert(payload), - operation_name='entity_upsert', + operation_name="entity_upsert", entity_name=entity_name, max_retries=3, retry_delay=0.1, ) - - # Return original name if resolution changed it, None otherwise - resolved_from = original_entity_name if entity_name != original_entity_name else None - return node_data, resolved_from + return node_data async def _merge_edges_then_upsert( @@ -2338,19 +1876,13 @@ async def _merge_edges_then_upsert( relationships_vdb: BaseVectorStorage | None, entity_vdb: BaseVectorStorage | None, global_config: dict, - pipeline_status: dict | None = None, + pipeline_status: dict = None, pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, - added_entities: list | None = None, # New parameter to track entities added during edge processing + added_entities: list = None, # New parameter to track entities added during edge processing relation_chunks_storage: BaseKVStorage | None = None, entity_chunks_storage: BaseKVStorage | None = None, - entity_resolution_map: dict[str, str] | None = None, # Map original→resolved names ): - # Apply entity resolution mapping to edge endpoints - if entity_resolution_map: - src_id = entity_resolution_map.get(src_id, src_id) - tgt_id = entity_resolution_map.get(tgt_id, tgt_id) - if src_id == tgt_id: return None @@ -2359,8 +1891,6 @@ async def _merge_edges_then_upsert( already_source_ids = [] already_description = [] already_keywords = [] - already_relationships = [] - already_types = [] already_file_paths = [] # 1. 
Get existing edge data from graph storage @@ -2368,50 +1898,50 @@ async def _merge_edges_then_upsert( already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id) # Handle the case where get_edge returns None or missing fields if already_edge: - # Get weight with default 1.0 if missing, convert from string if needed - weight_val = already_edge.get('weight', 1.0) - if isinstance(weight_val, str): - try: - weight_val = float(weight_val) - except (ValueError, TypeError): - weight_val = 1.0 - already_weights.append(weight_val) + # Get weight with default 1.0 if missing + already_weights.append(already_edge.get("weight", 1.0)) # Get source_id with empty string default if missing or None - if already_edge.get('source_id') is not None: - already_source_ids.extend(already_edge['source_id'].split(GRAPH_FIELD_SEP)) - - # Get file_path with empty string default if missing or None - if already_edge.get('file_path') is not None: - already_file_paths.extend(already_edge['file_path'].split(GRAPH_FIELD_SEP)) - - # Get description with empty string default if missing or None - if already_edge.get('description') is not None: - already_description.extend(already_edge['description'].split(GRAPH_FIELD_SEP)) - - # Get keywords with empty string default if missing or None - if already_edge.get('keywords') is not None: - already_keywords.extend(split_string_by_multi_markers(already_edge['keywords'], [GRAPH_FIELD_SEP])) - - if already_edge.get('relationship') is not None: - already_relationships.extend( - split_string_by_multi_markers(already_edge['relationship'], [GRAPH_FIELD_SEP, ',']) + if already_edge.get("source_id") is not None: + already_source_ids.extend( + already_edge["source_id"].split(GRAPH_FIELD_SEP) ) - if already_edge.get('type') is not None: - already_types.extend(split_string_by_multi_markers(already_edge['type'], [GRAPH_FIELD_SEP, ','])) + # Get file_path with empty string default if missing or None + if already_edge.get("file_path") is not None: + already_file_paths.extend( + already_edge["file_path"].split(GRAPH_FIELD_SEP) + ) - new_source_ids = [dp['source_id'] for dp in edges_data if dp.get('source_id')] + # Get description with empty string default if missing or None + if already_edge.get("description") is not None: + already_description.extend( + already_edge["description"].split(GRAPH_FIELD_SEP) + ) + + # Get keywords with empty string default if missing or None + if already_edge.get("keywords") is not None: + already_keywords.extend( + split_string_by_multi_markers( + already_edge["keywords"], [GRAPH_FIELD_SEP] + ) + ) + + new_source_ids = [dp["source_id"] for dp in edges_data if dp.get("source_id")] storage_key = make_relation_chunk_key(src_id, tgt_id) existing_full_source_ids = [] if relation_chunks_storage is not None: stored_chunks = await relation_chunks_storage.get_by_id(storage_key) if stored_chunks and isinstance(stored_chunks, dict): - existing_full_source_ids = [chunk_id for chunk_id in stored_chunks.get('chunk_ids', []) if chunk_id] + existing_full_source_ids = [ + chunk_id for chunk_id in stored_chunks.get("chunk_ids", []) if chunk_id + ] if not existing_full_source_ids: - existing_full_source_ids = [chunk_id for chunk_id in already_source_ids if chunk_id] + existing_full_source_ids = [ + chunk_id for chunk_id in already_source_ids if chunk_id + ] # 2. 
Merge new source ids with existing ones full_source_ids = merge_source_ids(existing_full_source_ids, new_source_ids) @@ -2420,32 +1950,37 @@ async def _merge_edges_then_upsert( await relation_chunks_storage.upsert( { storage_key: { - 'chunk_ids': full_source_ids, - 'count': len(full_source_ids), + "chunk_ids": full_source_ids, + "count": len(full_source_ids), } } ) # 3. Finalize source_id by applying source ids limit - limit_method = global_config.get('source_ids_limit_method', DEFAULT_SOURCE_IDS_LIMIT_METHOD) - max_source_limit = int( - global_config.get('max_source_ids_per_relation', DEFAULT_MAX_SOURCE_IDS_PER_RELATION) - or DEFAULT_MAX_SOURCE_IDS_PER_RELATION - ) + limit_method = global_config.get("source_ids_limit_method") + max_source_limit = global_config.get("max_source_ids_per_relation") source_ids = apply_source_ids_limit( full_source_ids, max_source_limit, limit_method, - identifier=f'`{src_id}`~`{tgt_id}`', + identifier=f"`{src_id}`~`{tgt_id}`", ) + limit_method = ( + global_config.get("source_ids_limit_method") or SOURCE_IDS_LIMIT_METHOD_KEEP + ) + # 4. Only keep edges with source_id in the final source_ids list if in KEEP mode if limit_method == SOURCE_IDS_LIMIT_METHOD_KEEP: allowed_source_ids = set(source_ids) filtered_edges = [] for dp in edges_data: - source_id = dp.get('source_id') + source_id = dp.get("source_id") # Skip relationship fragments sourced from chunks dropped by keep oldest cap - if source_id and source_id not in allowed_source_ids and source_id not in existing_full_source_ids: + if ( + source_id + and source_id not in allowed_source_ids + and source_id not in existing_full_source_ids + ): continue filtered_edges.append(dp) edges_data = filtered_edges @@ -2459,53 +1994,44 @@ async def _merge_edges_then_upsert( and not edges_data ): if already_edge: - logger.info(f'Skipped `{src_id}`~`{tgt_id}`: KEEP old chunks {already_source_ids}/{len(full_source_ids)}') + logger.info( + f"Skipped `{src_id}`~`{tgt_id}`: KEEP old chunks {already_source_ids}/{len(full_source_ids)}" + ) existing_edge_data = dict(already_edge) return existing_edge_data else: - logger.error(f'Internal Error: already_node missing for `{src_id}`~`{tgt_id}`') - raise ValueError(f'Internal Error: already_node missing for `{src_id}`~`{tgt_id}`') + logger.error( + f"Internal Error: already_node missing for `{src_id}`~`{tgt_id}`" + ) + raise ValueError( + f"Internal Error: already_node missing for `{src_id}`~`{tgt_id}`" + ) # 6.1 Finalize source_id source_id = GRAPH_FIELD_SEP.join(source_ids) # 6.2 Finalize weight by summing new edges and existing weights - weight = sum([dp['weight'] for dp in edges_data] + already_weights) + weight = sum([dp["weight"] for dp in edges_data] + already_weights) # 6.2 Finalize keywords by merging existing and new keywords all_keywords = set() # Process already_keywords (which are comma-separated) for keyword_str in already_keywords: if keyword_str: # Skip empty strings - all_keywords.update(k.strip() for k in keyword_str.split(',') if k.strip()) + all_keywords.update(k.strip() for k in keyword_str.split(",") if k.strip()) # Process new keywords from edges_data for edge in edges_data: - if edge.get('keywords'): - all_keywords.update(k.strip() for k in edge['keywords'].split(',') if k.strip()) + if edge.get("keywords"): + all_keywords.update( + k.strip() for k in edge["keywords"].split(",") if k.strip() + ) # Join all unique keywords with commas - keywords = ','.join(sorted(all_keywords)) - - # 6.3 Finalize relationship/type labels from explicit field or fallback to keywords 
- rel_labels = set() - for edge in edges_data: - if edge.get('relationship'): - rel_labels.update(k.strip() for k in edge['relationship'].split(',') if k.strip()) - rel_labels.update(k.strip() for k in already_relationships if k.strip()) - relationship_label = ','.join(sorted(rel_labels)) - if not relationship_label and keywords: - relationship_label = keywords.split(',')[0] - - type_labels = set() - for edge in edges_data: - if edge.get('type'): - type_labels.update(k.strip() for k in edge['type'].split(',') if k.strip()) - type_labels.update(k.strip() for k in already_types if k.strip()) - type_label = ','.join(sorted(type_labels)) or relationship_label + keywords = ",".join(sorted(all_keywords)) # 7. Deduplicate by description, keeping first occurrence in the same document unique_edges = {} for dp in edges_data: - description_value = dp.get('description') + description_value = dp.get("description") if not description_value: continue if description_value not in unique_edges: @@ -2514,26 +2040,28 @@ async def _merge_edges_then_upsert( # Sort description by timestamp, then by description length (largest to smallest) when timestamps are the same sorted_edges = sorted( unique_edges.values(), - key=lambda x: (x.get('timestamp', 0), -len(x.get('description', ''))), + key=lambda x: (x.get("timestamp", 0), -len(x.get("description", ""))), ) - sorted_descriptions = [dp['description'] for dp in sorted_edges] + sorted_descriptions = [dp["description"] for dp in sorted_edges] # Combine already_description with sorted new descriptions description_list = already_description + sorted_descriptions if not description_list: - logger.error(f'Relation {src_id}~{tgt_id} has no description') - raise ValueError(f'Relation {src_id}~{tgt_id} has no description') + logger.error(f"Relation {src_id}~{tgt_id} has no description") + raise ValueError(f"Relation {src_id}~{tgt_id} has no description") # Check for cancellation before LLM summary if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: - if pipeline_status.get('cancellation_requested', False): - raise PipelineCancelledException('User cancelled during relation summary') + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException( + "User cancelled during relation summary" + ) # 8. 
Get summary description an LLM usage status description, llm_was_used = await _handle_entity_relation_summary( - 'Relation', - f'({src_id}, {tgt_id})', + "Relation", + f"({src_id}, {tgt_id})", description_list, GRAPH_FIELD_SEP, global_config, @@ -2545,13 +2073,15 @@ async def _merge_edges_then_upsert( seen_paths = set() has_placeholder = False # Track if already_file_paths contains placeholder - max_file_paths = global_config.get('max_file_paths', DEFAULT_MAX_FILE_PATHS) - file_path_placeholder = global_config.get('file_path_more_placeholder', DEFAULT_FILE_PATH_MORE_PLACEHOLDER) + max_file_paths = global_config.get("max_file_paths", DEFAULT_MAX_FILE_PATHS) + file_path_placeholder = global_config.get( + "file_path_more_placeholder", DEFAULT_FILE_PATH_MORE_PLACEHOLDER + ) # Collect from already_file_paths, excluding placeholder for fp in already_file_paths: # Check if this is a placeholder record - if fp and fp.startswith(f'...{file_path_placeholder}'): # Skip placeholders + if fp and fp.startswith(f"...{file_path_placeholder}"): # Skip placeholders has_placeholder = True continue if fp and fp not in seen_paths: @@ -2560,32 +2090,38 @@ async def _merge_edges_then_upsert( # Collect from new data for dp in edges_data: - file_path_item = dp.get('file_path') + file_path_item = dp.get("file_path") if file_path_item and file_path_item not in seen_paths: file_paths_list.append(file_path_item) seen_paths.add(file_path_item) # Apply count limit - max_file_paths = int(global_config.get('max_file_paths', DEFAULT_MAX_FILE_PATHS) or DEFAULT_MAX_FILE_PATHS) + max_file_paths = global_config.get("max_file_paths") if len(file_paths_list) > max_file_paths: - limit_method = global_config.get('source_ids_limit_method', SOURCE_IDS_LIMIT_METHOD_KEEP) - file_path_placeholder = global_config.get('file_path_more_placeholder', DEFAULT_FILE_PATH_MORE_PLACEHOLDER) + limit_method = global_config.get( + "source_ids_limit_method", SOURCE_IDS_LIMIT_METHOD_KEEP + ) + file_path_placeholder = global_config.get( + "file_path_more_placeholder", DEFAULT_FILE_PATH_MORE_PLACEHOLDER + ) # Add + sign to indicate actual file count is higher - original_count_str = f'{len(file_paths_list)}+' if has_placeholder else str(len(file_paths_list)) + original_count_str = ( + f"{len(file_paths_list)}+" if has_placeholder else str(len(file_paths_list)) + ) if limit_method == SOURCE_IDS_LIMIT_METHOD_FIFO: # FIFO: keep tail (newest), discard head file_paths_list = file_paths_list[-max_file_paths:] - file_paths_list.append(f'...{file_path_placeholder}...(FIFO)') + file_paths_list.append(f"...{file_path_placeholder}...(FIFO)") else: # KEEP: keep head (earliest), discard tail file_paths_list = file_paths_list[:max_file_paths] - file_paths_list.append(f'...{file_path_placeholder}...(KEEP Old)') + file_paths_list.append(f"...{file_path_placeholder}...(KEEP Old)") logger.info( - f'Limited `{src_id}`~`{tgt_id}`: file_path {original_count_str} -> {max_file_paths} ({limit_method})' + f"Limited `{src_id}`~`{tgt_id}`: file_path {original_count_str} -> {max_file_paths} ({limit_method})" ) # Finalize file_path file_path = GRAPH_FIELD_SEP.join(file_paths_list) @@ -2594,32 +2130,37 @@ async def _merge_edges_then_upsert( num_fragment = len(description_list) already_fragment = len(already_description) if llm_was_used: - status_message = f'LLMmrg: `{src_id}`~`{tgt_id}` | {already_fragment}+{num_fragment - already_fragment}' + status_message = f"LLMmrg: `{src_id}`~`{tgt_id}` | {already_fragment}+{num_fragment - already_fragment}" else: - status_message = f'Merged: 
`{src_id}`~`{tgt_id}` | {already_fragment}+{num_fragment - already_fragment}' + status_message = f"Merged: `{src_id}`~`{tgt_id}` | {already_fragment}+{num_fragment - already_fragment}" - truncation_info = truncation_info_log = '' + truncation_info = truncation_info_log = "" if len(source_ids) < len(full_source_ids): # Add truncation info from apply_source_ids_limit if truncation occurred - truncation_info_log = f'{limit_method} {len(source_ids)}/{len(full_source_ids)}' - truncation_info = truncation_info_log if limit_method == SOURCE_IDS_LIMIT_METHOD_FIFO else 'KEEP Old' + truncation_info_log = f"{limit_method} {len(source_ids)}/{len(full_source_ids)}" + if limit_method == SOURCE_IDS_LIMIT_METHOD_FIFO: + truncation_info = truncation_info_log + else: + truncation_info = "KEEP Old" deduplicated_num = already_fragment + len(edges_data) - num_fragment - dd_message = '' + dd_message = "" if deduplicated_num > 0: - # Duplicated description detected across multiple chunks for the same entity - dd_message = f'dd {deduplicated_num}' + # Duplicated description detected across multiple trucks for the same entity + dd_message = f"dd {deduplicated_num}" if dd_message or truncation_info_log: - status_message += f' ({", ".join(filter(None, [truncation_info_log, dd_message]))})' + status_message += ( + f" ({', '.join(filter(None, [truncation_info_log, dd_message]))})" + ) # Add message to pipeline satus when merge happens if already_fragment > 0 or llm_was_used: logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: - pipeline_status['latest_message'] = status_message - pipeline_status['history_messages'].append(status_message) + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) else: logger.debug(status_message) @@ -2632,13 +2173,13 @@ async def _merge_edges_then_upsert( # Node doesn't exist - create new node node_created_at = int(time.time()) node_data = { - 'entity_id': need_insert_id, - 'source_id': source_id, - 'description': description, - 'entity_type': 'UNKNOWN', - 'file_path': file_path, - 'created_at': node_created_at, - 'truncate': '', + "entity_id": need_insert_id, + "source_id": source_id, + "description": description, + "entity_type": "UNKNOWN", + "file_path": file_path, + "created_at": node_created_at, + "truncate": "", } await knowledge_graph_inst.upsert_node(need_insert_id, node_data=node_data) @@ -2649,27 +2190,27 @@ async def _merge_edges_then_upsert( await entity_chunks_storage.upsert( { need_insert_id: { - 'chunk_ids': chunk_ids, - 'count': len(chunk_ids), + "chunk_ids": chunk_ids, + "count": len(chunk_ids), } } ) if entity_vdb is not None: - entity_vdb_id = compute_mdhash_id(need_insert_id, prefix='ent-') - entity_content = f'{need_insert_id}\n{description}' + entity_vdb_id = compute_mdhash_id(need_insert_id, prefix="ent-") + entity_content = f"{need_insert_id}\n{description}" vdb_data = { entity_vdb_id: { - 'content': entity_content, - 'entity_name': need_insert_id, - 'source_id': source_id, - 'entity_type': 'UNKNOWN', - 'file_path': file_path, + "content": entity_content, + "entity_name": need_insert_id, + "source_id": source_id, + "entity_type": "UNKNOWN", + "file_path": file_path, } } await safe_vdb_operation_with_exception( operation=lambda payload=vdb_data: entity_vdb.upsert(payload), - operation_name='added_entity_upsert', + operation_name="added_entity_upsert", entity_name=need_insert_id, max_retries=3, retry_delay=0.1, @@ -2678,12 +2219,12 
@@ async def _merge_edges_then_upsert( # Track entities added during edge processing if added_entities is not None: entity_data = { - 'entity_name': need_insert_id, - 'entity_type': 'UNKNOWN', - 'description': description, - 'source_id': source_id, - 'file_path': file_path, - 'created_at': node_created_at, + "entity_name": need_insert_id, + "entity_type": "UNKNOWN", + "description": description, + "source_id": source_id, + "file_path": file_path, + "created_at": node_created_at, } added_entities.append(entity_data) else: @@ -2695,68 +2236,87 @@ async def _merge_edges_then_upsert( if entity_chunks_storage is not None: stored_chunks = await entity_chunks_storage.get_by_id(need_insert_id) if stored_chunks and isinstance(stored_chunks, dict): - existing_full_source_ids = [chunk_id for chunk_id in stored_chunks.get('chunk_ids', []) if chunk_id] + existing_full_source_ids = [ + chunk_id + for chunk_id in stored_chunks.get("chunk_ids", []) + if chunk_id + ] # If not in entity_chunks_storage, get from graph database - if not existing_full_source_ids and existing_node.get('source_id'): - existing_full_source_ids = existing_node['source_id'].split(GRAPH_FIELD_SEP) + if not existing_full_source_ids: + if existing_node.get("source_id"): + existing_full_source_ids = existing_node["source_id"].split( + GRAPH_FIELD_SEP + ) # 2. Merge with new source_ids from this relationship - new_source_ids_from_relation = [chunk_id for chunk_id in source_ids if chunk_id] - merged_full_source_ids = merge_source_ids(existing_full_source_ids, new_source_ids_from_relation) + new_source_ids_from_relation = [ + chunk_id for chunk_id in source_ids if chunk_id + ] + merged_full_source_ids = merge_source_ids( + existing_full_source_ids, new_source_ids_from_relation + ) # 3. Save merged full list to entity_chunks_storage (conditional) - if entity_chunks_storage is not None and merged_full_source_ids != existing_full_source_ids: + if ( + entity_chunks_storage is not None + and merged_full_source_ids != existing_full_source_ids + ): updated = True await entity_chunks_storage.upsert( { need_insert_id: { - 'chunk_ids': merged_full_source_ids, - 'count': len(merged_full_source_ids), + "chunk_ids": merged_full_source_ids, + "count": len(merged_full_source_ids), } } ) # 4. Apply source_ids limit for graph and vector db - limit_method = global_config.get('source_ids_limit_method', SOURCE_IDS_LIMIT_METHOD_KEEP) - max_source_limit = int( - global_config.get('max_source_ids_per_entity', DEFAULT_MAX_SOURCE_IDS_PER_ENTITY) - or DEFAULT_MAX_SOURCE_IDS_PER_ENTITY + limit_method = global_config.get( + "source_ids_limit_method", SOURCE_IDS_LIMIT_METHOD_KEEP ) + max_source_limit = global_config.get("max_source_ids_per_entity") limited_source_ids = apply_source_ids_limit( merged_full_source_ids, max_source_limit, limit_method, - identifier=f'`{need_insert_id}`', + identifier=f"`{need_insert_id}`", ) # 5. 
Update graph database and vector database with limited source_ids (conditional) limited_source_id_str = GRAPH_FIELD_SEP.join(limited_source_ids) - if limited_source_id_str != existing_node.get('source_id', ''): + if limited_source_id_str != existing_node.get("source_id", ""): updated = True updated_node_data = { **existing_node, - 'source_id': limited_source_id_str, + "source_id": limited_source_id_str, } - await knowledge_graph_inst.upsert_node(need_insert_id, node_data=updated_node_data) + await knowledge_graph_inst.upsert_node( + need_insert_id, node_data=updated_node_data + ) # Update vector database if entity_vdb is not None: - entity_vdb_id = compute_mdhash_id(need_insert_id, prefix='ent-') - entity_content = f'{need_insert_id}\n{existing_node.get("description", "")}' + entity_vdb_id = compute_mdhash_id(need_insert_id, prefix="ent-") + entity_content = ( + f"{need_insert_id}\n{existing_node.get('description', '')}" + ) vdb_data = { entity_vdb_id: { - 'content': entity_content, - 'entity_name': need_insert_id, - 'source_id': limited_source_id_str, - 'entity_type': existing_node.get('entity_type', 'UNKNOWN'), - 'file_path': existing_node.get('file_path', 'unknown_source'), + "content": entity_content, + "entity_name": need_insert_id, + "source_id": limited_source_id_str, + "entity_type": existing_node.get("entity_type", "UNKNOWN"), + "file_path": existing_node.get( + "file_path", "unknown_source" + ), } } await safe_vdb_operation_with_exception( operation=lambda payload=vdb_data: entity_vdb.upsert(payload), - operation_name='existing_entity_update', + operation_name="existing_entity_update", entity_name=need_insert_id, max_retries=3, retry_delay=0.1, @@ -2764,74 +2324,70 @@ async def _merge_edges_then_upsert( # 6. Log once at the end if any update occurred if updated: - status_message = f'Chunks appended from relation: `{need_insert_id}`' + status_message = f"Chunks appended from relation: `{need_insert_id}`" logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: - pipeline_status['latest_message'] = status_message - pipeline_status['history_messages'].append(status_message) + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) edge_created_at = int(time.time()) await knowledge_graph_inst.upsert_edge( src_id, tgt_id, - edge_data={ - 'weight': str(weight), - 'description': description, - 'keywords': keywords, - 'relationship': relationship_label, - 'type': type_label, - 'source_id': source_id, - 'file_path': file_path, - 'created_at': str(edge_created_at), - 'truncate': truncation_info, - }, + edge_data=dict( + weight=weight, + description=description, + keywords=keywords, + source_id=source_id, + file_path=file_path, + created_at=edge_created_at, + truncate=truncation_info, + ), ) - edge_data = { - 'src_id': src_id, - 'tgt_id': tgt_id, - 'description': description, - 'keywords': keywords, - 'relationship': relationship_label, - 'type': type_label, - 'source_id': source_id, - 'file_path': file_path, - 'created_at': edge_created_at, - 'truncate': truncation_info, - 'weight': weight, - } + edge_data = dict( + src_id=src_id, + tgt_id=tgt_id, + description=description, + keywords=keywords, + source_id=source_id, + file_path=file_path, + created_at=edge_created_at, + truncate=truncation_info, + weight=weight, + ) # Sort src_id and tgt_id to ensure consistent ordering (smaller string first) if src_id > tgt_id: src_id, tgt_id = tgt_id, src_id if 
relationships_vdb is not None: - rel_vdb_id = compute_mdhash_id(src_id + tgt_id, prefix='rel-') - rel_vdb_id_reverse = compute_mdhash_id(tgt_id + src_id, prefix='rel-') + rel_vdb_id = compute_mdhash_id(src_id + tgt_id, prefix="rel-") + rel_vdb_id_reverse = compute_mdhash_id(tgt_id + src_id, prefix="rel-") try: await relationships_vdb.delete([rel_vdb_id, rel_vdb_id_reverse]) except Exception as e: - logger.debug(f'Could not delete old relationship vector records {rel_vdb_id}, {rel_vdb_id_reverse}: {e}') - rel_content = f'{keywords}\t{src_id}\n{tgt_id}\n{description}' + logger.debug( + f"Could not delete old relationship vector records {rel_vdb_id}, {rel_vdb_id_reverse}: {e}" + ) + rel_content = f"{keywords}\t{src_id}\n{tgt_id}\n{description}" vdb_data = { rel_vdb_id: { - 'src_id': src_id, - 'tgt_id': tgt_id, - 'relationship': relationship_label, - 'type': type_label, - 'source_id': source_id, - 'content': rel_content, - 'keywords': keywords, - 'description': description, - 'weight': weight, - 'file_path': file_path, + "src_id": src_id, + "tgt_id": tgt_id, + "source_id": source_id, + "content": rel_content, + "keywords": keywords, + "description": description, + "weight": weight, + "file_path": file_path, } } await safe_vdb_operation_with_exception( operation=lambda payload=vdb_data: relationships_vdb.upsert(payload), - operation_name='relationship_upsert', - entity_name=f'{src_id}-{tgt_id}', + operation_name="relationship_upsert", + entity_name=f"{src_id}-{tgt_id}", max_retries=3, retry_delay=0.2, ) @@ -2844,18 +2400,18 @@ async def merge_nodes_and_edges( knowledge_graph_inst: BaseGraphStorage, entity_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, - global_config: dict[str, Any], - full_entities_storage: BaseKVStorage | None = None, - full_relations_storage: BaseKVStorage | None = None, - doc_id: str | None = None, - pipeline_status: dict | None = None, + global_config: dict[str, str], + full_entities_storage: BaseKVStorage = None, + full_relations_storage: BaseKVStorage = None, + doc_id: str = None, + pipeline_status: dict = None, pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, entity_chunks_storage: BaseKVStorage | None = None, relation_chunks_storage: BaseKVStorage | None = None, current_file_number: int = 0, total_files: int = 0, - file_path: str = 'unknown_source', + file_path: str = "unknown_source", ) -> None: """Two-phase merge: process all entities first, then all relationships @@ -2883,18 +2439,11 @@ async def merge_nodes_and_edges( file_path: File path for logging """ - if full_entities_storage is None or full_relations_storage is None: - raise ValueError('full_entities_storage and full_relations_storage are required for merge operations') - if pipeline_status is None: - pipeline_status = {} - if pipeline_status_lock is None: - pipeline_status_lock = asyncio.Lock() - # Check for cancellation at the start of merge if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: - if pipeline_status.get('cancellation_requested', False): - raise PipelineCancelledException('User cancelled during merge phase') + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException("User cancelled during merge phase") # Collect all nodes and edges from all chunks all_nodes = defaultdict(list) @@ -2913,126 +2462,41 @@ async def merge_nodes_and_edges( total_entities_count = len(all_nodes) total_relations_count = len(all_edges) - log_message = f'Merging stage 
{current_file_number}/{total_files}: {file_path}' + log_message = f"Merging stage {current_file_number}/{total_files}: {file_path}" logger.info(log_message) async with pipeline_status_lock: - pipeline_status['latest_message'] = log_message - pipeline_status['history_messages'].append(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) # Get max async tasks limit from global_config for semaphore control - graph_max_async = global_config.get('llm_model_max_async', 4) * 2 + graph_max_async = global_config.get("llm_model_max_async", 4) * 2 semaphore = asyncio.Semaphore(graph_max_async) - # ===== Pre-Resolution Phase: Build entity resolution map ===== - # This prevents race conditions when parallel workers process similar entities - # IMPORTANT: Include BOTH entity names AND relation endpoints to catch all duplicates - pre_resolution_map: dict[str, str] = {} - entity_resolution_config_raw = global_config.get('entity_resolution_config') - if entity_resolution_config_raw: - # Handle both dict (from asdict() serialization) and EntityResolutionConfig instances - config = None - if isinstance(entity_resolution_config_raw, EntityResolutionConfig): - config = entity_resolution_config_raw - elif isinstance(entity_resolution_config_raw, dict): - try: - config = EntityResolutionConfig(**entity_resolution_config_raw) - except TypeError as e: - logger.warning( - f'Invalid entity_resolution_config: {e}. ' - f'Config: {entity_resolution_config_raw}. Skipping resolution.' - ) - if config and config.enabled: - llm_fn = global_config.get('llm_model_func') - # Build entity_types map for type-aware fuzzy matching. - # Use first non-empty type for entities with multiple occurrences. - entity_types: dict[str, str] = {} - for entity_name, entities in all_nodes.items(): - for entity_data in entities: - etype = entity_data.get('entity_type', '') - if etype: - entity_types[entity_name] = etype - break - - # Collect ALL entity names: from entities AND from relation endpoints - # This ensures relation endpoints like "EU Medicines Agency" get resolved - # against existing entities like "European Medicines Agency" - all_entity_names = set(all_nodes.keys()) - for src_id, tgt_id in all_edges: - all_entity_names.add(src_id) - all_entity_names.add(tgt_id) - - pre_resolution_map, confidence_map = await _build_pre_resolution_map( - list(all_entity_names), - entity_types, - entity_vdb, - llm_fn, - config, - ) - - # Cache pre-resolution aliases for future lookups (PostgreSQL-only) - # This ensures aliases discovered during batch processing are available - # for subsequent document ingestion without re-running resolution - db = getattr(knowledge_graph_inst, 'db', None) - if db is not None and pre_resolution_map: - workspace = global_config.get('workspace', '') - for alias, canonical in pre_resolution_map.items(): - # Don't cache self-references (entity → itself) - if alias.lower().strip() != canonical.lower().strip(): - try: - await store_alias( - alias=alias, - canonical=canonical, - method='pre_resolution', - confidence=confidence_map.get(alias, 1.0), - db=db, - workspace=workspace, - ) - except Exception as e: - logger.debug(f"Failed to cache pre-resolution alias '{alias}' → '{canonical}': {e}") - # ===== Phase 1: Process all entities concurrently ===== - log_message = f'Phase 1: Processing {total_entities_count} entities from {doc_id} (async: {graph_max_async})' + log_message = f"Phase 1: Processing {total_entities_count} entities from {doc_id} (async: 
{graph_max_async})" logger.info(log_message) async with pipeline_status_lock: - pipeline_status['latest_message'] = log_message - pipeline_status['history_messages'].append(log_message) - - # ===== Batch Prefetch: Load existing entity data in single query ===== - # Build list of entity names to prefetch (apply pre-resolution where applicable) - prefetch_entity_names = [] - for entity_name in all_nodes: - resolved_name = pre_resolution_map.get(entity_name, entity_name) - prefetch_entity_names.append(resolved_name) - - # Batch fetch existing nodes to avoid N+1 query pattern during parallel processing - prefetched_nodes: dict[str, dict] = {} - if prefetch_entity_names: - try: - prefetched_nodes = await knowledge_graph_inst.get_nodes_batch(prefetch_entity_names) - logger.debug(f'Prefetched {len(prefetched_nodes)}/{len(prefetch_entity_names)} existing entities for merge') - except Exception as e: - logger.warning(f'Batch entity prefetch failed: {e}. Falling back to individual fetches.') - prefetched_nodes = {} - - # Resolution map to track original→resolved entity names (e.g., "Dupixant"→"Dupixent") - # This will be used to remap edge endpoints in Phase 2 - entity_resolution_map: dict[str, str] = {} - resolution_map_lock = asyncio.Lock() + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) async def _locked_process_entity_name(entity_name, entities): async with semaphore: # Check for cancellation before processing entity if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: - if pipeline_status.get('cancellation_requested', False): - raise PipelineCancelledException('User cancelled during entity merge') + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException( + "User cancelled during entity merge" + ) - workspace = global_config.get('workspace', '') - namespace = f'{workspace}:GraphDB' if workspace else 'GraphDB' - async with get_storage_keyed_lock([entity_name], namespace=namespace, enable_logging=False): + workspace = global_config.get("workspace", "") + namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" + async with get_storage_keyed_lock( + [entity_name], namespace=namespace, enable_logging=False + ): try: - logger.debug(f'Processing entity {entity_name}') - entity_data, resolved_from = await _merge_nodes_then_upsert( + logger.debug(f"Processing entity {entity_name}") + entity_data = await _merge_nodes_then_upsert( entity_name, entities, knowledge_graph_inst, @@ -3042,33 +2506,32 @@ async def merge_nodes_and_edges( pipeline_status_lock, llm_response_cache, entity_chunks_storage, - pre_resolution_map, - prefetched_nodes, ) - # Track resolution mapping for edge remapping in Phase 2 - if resolved_from is not None: - resolved_to = entity_data.get('entity_name', entity_name) - async with resolution_map_lock: - entity_resolution_map[resolved_from] = resolved_to - return entity_data except Exception as e: - error_msg = f'Error processing entity `{entity_name}`: {e}' + error_msg = f"Error processing entity `{entity_name}`: {e}" logger.error(error_msg) # Try to update pipeline status, but don't let status update failure affect main exception try: - if pipeline_status is not None and pipeline_status_lock is not None: + if ( + pipeline_status is not None + and pipeline_status_lock is not None + ): async with pipeline_status_lock: - pipeline_status['latest_message'] = error_msg - pipeline_status['history_messages'].append(error_msg) + 
pipeline_status["latest_message"] = error_msg + pipeline_status["history_messages"].append(error_msg) except Exception as status_error: - logger.error(f'Failed to update pipeline status: {status_error}') + logger.error( + f"Failed to update pipeline status: {status_error}" + ) # Re-raise the original exception with a prefix - prefixed_exception = create_prefixed_exception(e, f'`{entity_name}`') + prefixed_exception = create_prefixed_exception( + e, f"`{entity_name}`" + ) raise prefixed_exception from e # Create entity processing tasks @@ -3080,7 +2543,9 @@ async def merge_nodes_and_edges( # Execute entity tasks with error handling processed_entities = [] if entity_tasks: - done, pending = await asyncio.wait(entity_tasks, return_when=asyncio.FIRST_EXCEPTION) + done, pending = await asyncio.wait( + entity_tasks, return_when=asyncio.FIRST_EXCEPTION + ) first_exception = None processed_entities = [] @@ -3109,22 +2574,24 @@ async def merge_nodes_and_edges( raise first_exception # ===== Phase 2: Process all relationships concurrently ===== - log_message = f'Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})' + log_message = f"Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})" logger.info(log_message) async with pipeline_status_lock: - pipeline_status['latest_message'] = log_message - pipeline_status['history_messages'].append(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) async def _locked_process_edges(edge_key, edges): async with semaphore: # Check for cancellation before processing edges if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: - if pipeline_status.get('cancellation_requested', False): - raise PipelineCancelledException('User cancelled during relation merge') + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException( + "User cancelled during relation merge" + ) - workspace = global_config.get('workspace', '') - namespace = f'{workspace}:GraphDB' if workspace else 'GraphDB' + workspace = global_config.get("workspace", "") + namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" sorted_edge_key = sorted([edge_key[0], edge_key[1]]) async with get_storage_keyed_lock( @@ -3135,7 +2602,7 @@ async def merge_nodes_and_edges( try: added_entities = [] # Track entities added during edge processing - logger.debug(f'Processing relation {sorted_edge_key}') + logger.debug(f"Processing relation {sorted_edge_key}") edge_data = await _merge_edges_then_upsert( edge_key[0], edge_key[1], @@ -3150,7 +2617,6 @@ async def merge_nodes_and_edges( added_entities, # Pass list to collect added entities relation_chunks_storage, entity_chunks_storage, # Add entity_chunks_storage parameter - entity_resolution_map, # Apply entity resolution to edge endpoints ) if edge_data is None: @@ -3159,50 +2625,33 @@ async def merge_nodes_and_edges( return edge_data, added_entities except Exception as e: - error_msg = f'Error processing relation `{sorted_edge_key}`: {e}' + error_msg = f"Error processing relation `{sorted_edge_key}`: {e}" logger.error(error_msg) # Try to update pipeline status, but don't let status update failure affect main exception try: - if pipeline_status is not None and pipeline_status_lock is not None: + if ( + pipeline_status is not None + and pipeline_status_lock is not None + ): async with pipeline_status_lock: - pipeline_status['latest_message'] = 
error_msg - pipeline_status['history_messages'].append(error_msg) + pipeline_status["latest_message"] = error_msg + pipeline_status["history_messages"].append(error_msg) except Exception as status_error: - logger.error(f'Failed to update pipeline status: {status_error}') + logger.error( + f"Failed to update pipeline status: {status_error}" + ) # Re-raise the original exception with a prefix - prefixed_exception = create_prefixed_exception(e, f'{sorted_edge_key}') + prefixed_exception = create_prefixed_exception( + e, f"{sorted_edge_key}" + ) raise prefixed_exception from e # Create relationship processing tasks - # Apply pre_resolution_map to edge endpoints to prevent duplicates from relation extraction - # Key fixes: sort for lock ordering, filter self-loops, deduplicate merged edges - resolved_edges: dict[tuple[str, str], list] = {} - for edge_key, edges in all_edges.items(): - # Remap edge endpoints using pre-resolution map - # This catches cases like "EU Medicines Agency" → "European Medicines Agency" - resolved_src = pre_resolution_map.get(edge_key[0], edge_key[0]) or '' - resolved_tgt = pre_resolution_map.get(edge_key[1], edge_key[1]) or '' - - # Skip self-loops created by resolution (e.g., both endpoints resolve to same entity) - if resolved_src == resolved_tgt: - logger.debug(f'Skipping self-loop after resolution: {edge_key} → ({resolved_src}, {resolved_tgt})') - continue - - # Sort for consistent lock ordering (prevents deadlocks) - sorted_edge = sorted([resolved_src, resolved_tgt]) - resolved_edge_key: tuple[str, str] = (sorted_edge[0], sorted_edge[1]) - - # Merge edges that resolve to same key (deduplication) - if resolved_edge_key not in resolved_edges: - resolved_edges[resolved_edge_key] = [] - resolved_edges[resolved_edge_key].extend(edges) - - # Create tasks from deduplicated edges edge_tasks = [] - for resolved_edge_key, merged_edges in resolved_edges.items(): - task = asyncio.create_task(_locked_process_edges(resolved_edge_key, merged_edges)) + for edge_key, edges in all_edges.items(): + task = asyncio.create_task(_locked_process_edges(edge_key, edges)) edge_tasks.append(task) # Execute relationship tasks with error handling @@ -3210,7 +2659,9 @@ async def merge_nodes_and_edges( all_added_entities = [] if edge_tasks: - done, pending = await asyncio.wait(edge_tasks, return_when=asyncio.FIRST_EXCEPTION) + done, pending = await asyncio.wait( + edge_tasks, return_when=asyncio.FIRST_EXCEPTION + ) first_exception = None @@ -3250,37 +2701,37 @@ async def merge_nodes_and_edges( # Add original processed entities for entity_data in processed_entities: - if entity_data and entity_data.get('entity_name'): - final_entity_names.add(entity_data['entity_name']) + if entity_data and entity_data.get("entity_name"): + final_entity_names.add(entity_data["entity_name"]) # Add entities that were added during relationship processing for added_entity in all_added_entities: - if added_entity and added_entity.get('entity_name'): - final_entity_names.add(added_entity['entity_name']) + if added_entity and added_entity.get("entity_name"): + final_entity_names.add(added_entity["entity_name"]) # Collect all relation pairs final_relation_pairs = set() for edge_data in processed_edges: if edge_data: - src_id = edge_data.get('src_id') - tgt_id = edge_data.get('tgt_id') + src_id = edge_data.get("src_id") + tgt_id = edge_data.get("tgt_id") if src_id and tgt_id: relation_pair = tuple(sorted([src_id, tgt_id])) final_relation_pairs.add(relation_pair) - log_message = f'Phase 3: Updating final 
{len(final_entity_names)}({len(processed_entities)}+{len(all_added_entities)}) entities and {len(final_relation_pairs)} relations from {doc_id}' + log_message = f"Phase 3: Updating final {len(final_entity_names)}({len(processed_entities)}+{len(all_added_entities)}) entities and {len(final_relation_pairs)} relations from {doc_id}" logger.info(log_message) async with pipeline_status_lock: - pipeline_status['latest_message'] = log_message - pipeline_status['history_messages'].append(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) # Update storage if final_entity_names: await full_entities_storage.upsert( { doc_id: { - 'entity_names': list(final_entity_names), - 'count': len(final_entity_names), + "entity_names": list(final_entity_names), + "count": len(final_entity_names), } } ) @@ -3289,71 +2740,75 @@ async def merge_nodes_and_edges( await full_relations_storage.upsert( { doc_id: { - 'relation_pairs': [list(pair) for pair in final_relation_pairs], - 'count': len(final_relation_pairs), + "relation_pairs": [ + list(pair) for pair in final_relation_pairs + ], + "count": len(final_relation_pairs), } } ) logger.debug( - f'Updated entity-relation index for document {doc_id}: {len(final_entity_names)} entities (original: {len(processed_entities)}, added: {len(all_added_entities)}), {len(final_relation_pairs)} relations' + f"Updated entity-relation index for document {doc_id}: {len(final_entity_names)} entities (original: {len(processed_entities)}, added: {len(all_added_entities)}), {len(final_relation_pairs)} relations" ) except Exception as e: - logger.error(f'Failed to update entity-relation index for document {doc_id}: {e}') + logger.error( + f"Failed to update entity-relation index for document {doc_id}: {e}" + ) # Don't raise exception to avoid affecting main flow - log_message = f'Completed merging: {len(processed_entities)} entities, {len(all_added_entities)} extra entities, {len(processed_edges)} relations' + log_message = f"Completed merging: {len(processed_entities)} entities, {len(all_added_entities)} extra entities, {len(processed_edges)} relations" logger.info(log_message) async with pipeline_status_lock: - pipeline_status['latest_message'] = log_message - pipeline_status['history_messages'].append(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) async def extract_entities( chunks: dict[str, TextChunkSchema], - global_config: dict[str, Any], - pipeline_status: dict | None = None, + global_config: dict[str, str], + pipeline_status: dict = None, pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, text_chunks_storage: BaseKVStorage | None = None, ) -> list: - if pipeline_status is None: - pipeline_status = {} - if pipeline_status_lock is None: - pipeline_status_lock = asyncio.Lock() # Check for cancellation at the start of entity extraction if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: - if pipeline_status.get('cancellation_requested', False): - raise PipelineCancelledException('User cancelled during entity extraction') + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException( + "User cancelled during entity extraction" + ) - use_llm_func: Callable[..., Any] = global_config['llm_model_func'] - entity_extract_max_gleaning = global_config['entity_extract_max_gleaning'] + use_llm_func: callable = global_config["llm_model_func"] + 
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] ordered_chunks = list(chunks.items()) # add language and example number params to prompt - language = global_config['addon_params'].get('language', DEFAULT_SUMMARY_LANGUAGE) - entity_types = global_config['addon_params'].get('entity_types', DEFAULT_ENTITY_TYPES) + language = global_config["addon_params"].get("language", DEFAULT_SUMMARY_LANGUAGE) + entity_types = global_config["addon_params"].get( + "entity_types", DEFAULT_ENTITY_TYPES + ) - examples = '\n'.join(PROMPTS['entity_extraction_examples']) + examples = "\n".join(PROMPTS["entity_extraction_examples"]) - example_context_base = { - 'tuple_delimiter': PROMPTS['DEFAULT_TUPLE_DELIMITER'], - 'completion_delimiter': PROMPTS['DEFAULT_COMPLETION_DELIMITER'], - 'entity_types': ', '.join(entity_types), - 'language': language, - } + example_context_base = dict( + tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"], + completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"], + entity_types=", ".join(entity_types), + language=language, + ) # add example's format examples = examples.format(**example_context_base) - context_base = { - 'tuple_delimiter': PROMPTS['DEFAULT_TUPLE_DELIMITER'], - 'completion_delimiter': PROMPTS['DEFAULT_COMPLETION_DELIMITER'], - 'entity_types': ','.join(entity_types), - 'examples': examples, - 'language': language, - } + context_base = dict( + tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"], + completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"], + entity_types=",".join(entity_types), + examples=examples, + language=language, + ) processed_chunks = 0 total_chunks = len(ordered_chunks) @@ -3369,35 +2824,39 @@ async def extract_entities( nonlocal processed_chunks chunk_key = chunk_key_dp[0] chunk_dp = chunk_key_dp[1] - content = chunk_dp.get('content', '') + content = chunk_dp["content"] # Get file path from chunk data or use default - file_path = chunk_dp.get('file_path') or 'unknown_source' + file_path = chunk_dp.get("file_path", "unknown_source") # Create cache keys collector for batch processing cache_keys_collector = [] # Get initial extraction - entity_extraction_system_prompt = PROMPTS['entity_extraction_system_prompt'].format( - **{**context_base, 'input_text': content} - ) - entity_extraction_user_prompt = PROMPTS['entity_extraction_user_prompt'].format( - **{**context_base, 'input_text': content} - ) - entity_continue_extraction_user_prompt = PROMPTS['entity_continue_extraction_user_prompt'].format( - **{**context_base, 'input_text': content} + # Format system prompt without input_text for each chunk (enables OpenAI prompt caching across chunks) + entity_extraction_system_prompt = PROMPTS[ + "entity_extraction_system_prompt" + ].format(**context_base) + # Format user prompts with input_text for each chunk + entity_extraction_user_prompt = PROMPTS["entity_extraction_user_prompt"].format( + **{**context_base, "input_text": content} ) + entity_continue_extraction_user_prompt = PROMPTS[ + "entity_continue_extraction_user_prompt" + ].format(**{**context_base, "input_text": content}) final_result, timestamp = await use_llm_func_with_cache( entity_extraction_user_prompt, use_llm_func, system_prompt=entity_extraction_system_prompt, llm_response_cache=llm_response_cache, - cache_type='extract', + cache_type="extract", chunk_id=chunk_key, cache_keys_collector=cache_keys_collector, ) - history = pack_user_ass_to_openai_messages(entity_extraction_user_prompt, final_result) + history = pack_user_ass_to_openai_messages( + 
entity_extraction_user_prompt, final_result + ) # Process initial extraction with file path maybe_nodes, maybe_edges = await _process_extraction_result( @@ -3405,8 +2864,8 @@ async def extract_entities( chunk_key, timestamp, file_path, - tuple_delimiter=context_base['tuple_delimiter'], - completion_delimiter=context_base['completion_delimiter'], + tuple_delimiter=context_base["tuple_delimiter"], + completion_delimiter=context_base["completion_delimiter"], ) # Process additional gleaning results only 1 time when entity_extract_max_gleaning is greater than zero. @@ -3417,7 +2876,7 @@ async def extract_entities( system_prompt=entity_extraction_system_prompt, llm_response_cache=llm_response_cache, history_messages=history, - cache_type='extract', + cache_type="extract", chunk_id=chunk_key, cache_keys_collector=cache_keys_collector, ) @@ -3428,16 +2887,18 @@ async def extract_entities( chunk_key, timestamp, file_path, - tuple_delimiter=context_base['tuple_delimiter'], - completion_delimiter=context_base['completion_delimiter'], + tuple_delimiter=context_base["tuple_delimiter"], + completion_delimiter=context_base["completion_delimiter"], ) # Merge results - compare description lengths to choose better version for entity_name, glean_entities in glean_nodes.items(): if entity_name in maybe_nodes: # Compare description lengths and keep the better one - original_desc_len = len(maybe_nodes[entity_name][0].get('description', '') or '') - glean_desc_len = len(glean_entities[0].get('description', '') or '') + original_desc_len = len( + maybe_nodes[entity_name][0].get("description", "") or "" + ) + glean_desc_len = len(glean_entities[0].get("description", "") or "") if glean_desc_len > original_desc_len: maybe_nodes[entity_name] = list(glean_entities) @@ -3446,18 +2907,20 @@ async def extract_entities( # New entity from gleaning stage maybe_nodes[entity_name] = list(glean_entities) - for edge_key, glean_edge_list in glean_edges.items(): + for edge_key, glean_edges in glean_edges.items(): if edge_key in maybe_edges: # Compare description lengths and keep the better one - original_desc_len = len(maybe_edges[edge_key][0].get('description', '') or '') - glean_desc_len = len(glean_edge_list[0].get('description', '') or '') + original_desc_len = len( + maybe_edges[edge_key][0].get("description", "") or "" + ) + glean_desc_len = len(glean_edges[0].get("description", "") or "") if glean_desc_len > original_desc_len: - maybe_edges[edge_key] = list(glean_edge_list) + maybe_edges[edge_key] = list(glean_edges) # Otherwise keep original version else: # New edge from gleaning stage - maybe_edges[edge_key] = list(glean_edge_list) + maybe_edges[edge_key] = list(glean_edges) # Batch update chunk's llm_cache_list with all collected cache keys if cache_keys_collector and text_chunks_storage: @@ -3465,24 +2928,24 @@ async def extract_entities( chunk_key, text_chunks_storage, cache_keys_collector, - 'entity_extraction', + "entity_extraction", ) processed_chunks += 1 entities_count = len(maybe_nodes) relations_count = len(maybe_edges) - log_message = f'Chunk {processed_chunks} of {total_chunks} extracted {entities_count} Ent + {relations_count} Rel {chunk_key}' + log_message = f"Chunk {processed_chunks} of {total_chunks} extracted {entities_count} Ent + {relations_count} Rel {chunk_key}" logger.info(log_message) if pipeline_status is not None: async with pipeline_status_lock: - pipeline_status['latest_message'] = log_message - pipeline_status['history_messages'].append(log_message) + pipeline_status["latest_message"] = 
log_message + pipeline_status["history_messages"].append(log_message) # Return the extracted nodes and edges for centralized processing return maybe_nodes, maybe_edges # Get max async tasks limit from global_config - chunk_max_async = global_config.get('llm_model_max_async', 4) + chunk_max_async = global_config.get("llm_model_max_async", 4) semaphore = asyncio.Semaphore(chunk_max_async) async def _process_with_semaphore(chunk): @@ -3490,8 +2953,10 @@ async def extract_entities( # Check for cancellation before processing chunk if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: - if pipeline_status.get('cancellation_requested', False): - raise PipelineCancelledException('User cancelled during chunk processing') + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException( + "User cancelled during chunk processing" + ) try: return await _process_single_content(chunk) @@ -3536,7 +3001,7 @@ async def extract_entities( await asyncio.wait(pending) # Add progress prefix to the exception message - progress_prefix = f'C[{processed_chunks + 1}/{total_chunks}]' + progress_prefix = f"C[{processed_chunks + 1}/{total_chunks}]" # Re-raise the original exception with a prefix prefixed_exception = create_prefixed_exception(first_exception, progress_prefix) @@ -3554,10 +3019,10 @@ async def kg_query( relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, - global_config: dict[str, Any], + global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, system_prompt: str | None = None, - chunks_vdb: BaseVectorStorage | None = None, + chunks_vdb: BaseVectorStorage = None, ) -> QueryResult | None: """ Execute knowledge graph query and return unified QueryResult object. @@ -3590,35 +3055,36 @@ async def kg_query( Returns None when no relevant context could be constructed for the query. 
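    Illustrative usage (a minimal sketch, not a documented default: the leading
    positional arguments follow the parameter order used elsewhere in this
    module, and the QueryParam values shown are assumptions):

        param = QueryParam(mode="mix", stream=False)
        result = await kg_query(
            "How are entity A and entity B related?",
            knowledge_graph_inst,
            entities_vdb,
            relationships_vdb,
            text_chunks_db,
            param,
            global_config,
            hashing_kv=llm_response_cache,
            chunks_vdb=chunks_vdb,
        )
        if result is None:
            print("No relevant context could be built for this query.")
        elif isinstance(result.content, str):
            # Depending on only_need_context / only_need_prompt this is the LLM
            # answer, the raw context, or the assembled prompt.
            print(result.content)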
""" if not query: - return QueryResult(content=PROMPTS['fail_response']) + return QueryResult(content=PROMPTS["fail_response"]) if query_param.model_func: use_model_func = query_param.model_func else: - use_model_func = global_config['llm_model_func'] + use_model_func = global_config["llm_model_func"] # Apply higher priority (5) to query relation LLM function use_model_func = partial(use_model_func, _priority=5) - llm_callable = cast(Callable[..., Awaitable[str | AsyncIterator[str]]], use_model_func) - hl_keywords, ll_keywords = await get_keywords_from_query(query, query_param, global_config, hashing_kv) + hl_keywords, ll_keywords = await get_keywords_from_query( + query, query_param, global_config, hashing_kv + ) - logger.debug(f'High-level keywords: {hl_keywords}') - logger.debug(f'Low-level keywords: {ll_keywords}') + logger.debug(f"High-level keywords: {hl_keywords}") + logger.debug(f"Low-level keywords: {ll_keywords}") # Handle empty keywords - if ll_keywords == [] and query_param.mode in ['local', 'hybrid', 'mix']: - logger.warning('low_level_keywords is empty') - if hl_keywords == [] and query_param.mode in ['global', 'hybrid', 'mix']: - logger.warning('high_level_keywords is empty') + if ll_keywords == [] and query_param.mode in ["local", "hybrid", "mix"]: + logger.warning("low_level_keywords is empty") + if hl_keywords == [] and query_param.mode in ["global", "hybrid", "mix"]: + logger.warning("high_level_keywords is empty") if hl_keywords == [] and ll_keywords == []: if len(query) < 50: - logger.warning(f'Forced low_level_keywords to origin query: {query}') + logger.warning(f"Forced low_level_keywords to origin query: {query}") ll_keywords = [query] else: - return QueryResult(content=PROMPTS['fail_response']) + return QueryResult(content=PROMPTS["fail_response"]) - ll_keywords_str = ', '.join(ll_keywords) if ll_keywords else '' - hl_keywords_str = ', '.join(hl_keywords) if hl_keywords else '' + ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else "" + hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else "" # Build query context (unified interface) context_result = await _build_query_context( @@ -3634,44 +3100,41 @@ async def kg_query( ) if context_result is None: - logger.info('[kg_query] No query context could be built; returning no-result.') + logger.info("[kg_query] No query context could be built; returning no-result.") return None # Return different content based on query parameters if query_param.only_need_context and not query_param.only_need_prompt: - return QueryResult(content=context_result.context, raw_data=context_result.raw_data) + return QueryResult( + content=context_result.context, raw_data=context_result.raw_data + ) - user_prompt = f'\n\n{query_param.user_prompt}' if query_param.user_prompt else 'n/a' - response_type = query_param.response_type if query_param.response_type else 'Multiple Paragraphs' - - # Build coverage guidance based on context sparsity - coverage_guidance = ( - PROMPTS['coverage_guidance_limited'] - if context_result.coverage_level == 'limited' - else PROMPTS['coverage_guidance_good'] + user_prompt = f"\n\n{query_param.user_prompt}" if query_param.user_prompt else "n/a" + response_type = ( + query_param.response_type + if query_param.response_type + else "Multiple Paragraphs" ) - logger.debug(f'[kg_query] Coverage level: {context_result.coverage_level}') # Build system prompt - sys_prompt_temp = system_prompt if system_prompt else PROMPTS['rag_response'] + sys_prompt_temp = system_prompt if system_prompt else 
PROMPTS["rag_response"] sys_prompt = sys_prompt_temp.format( response_type=response_type, user_prompt=user_prompt, context_data=context_result.context, - coverage_guidance=coverage_guidance, ) user_query = query if query_param.only_need_prompt: - prompt_content = '\n\n'.join([sys_prompt, '---User Query---', user_query]) + prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query]) return QueryResult(content=prompt_content, raw_data=context_result.raw_data) # Call LLM - tokenizer: Tokenizer = global_config['tokenizer'] + tokenizer: Tokenizer = global_config["tokenizer"] len_of_prompts = len(tokenizer.encode(query + sys_prompt)) logger.debug( - f'[kg_query] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})' + f"[kg_query] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})" ) # Handle cache @@ -3686,18 +3149,22 @@ async def kg_query( query_param.max_total_tokens, hl_keywords_str, ll_keywords_str, - query_param.user_prompt or '', + query_param.user_prompt or "", query_param.enable_rerank, ) - cached_result = await handle_cache(hashing_kv, args_hash, user_query, query_param.mode, cache_type='query') + cached_result = await handle_cache( + hashing_kv, args_hash, user_query, query_param.mode, cache_type="query" + ) if cached_result is not None: cached_response, _ = cached_result # Extract content, ignore timestamp - logger.info(' == LLM cache == Query cache hit, using cached response as query result') + logger.info( + " == LLM cache == Query cache hit, using cached response as query result" + ) response = cached_response else: - response = await llm_callable( + response = await use_model_func( user_query, system_prompt=sys_prompt, history_messages=query_param.conversation_history, @@ -3705,19 +3172,19 @@ async def kg_query( stream=query_param.stream, ) - if isinstance(response, str) and hashing_kv and hashing_kv.global_config.get('enable_llm_cache'): + if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): queryparam_dict = { - 'mode': query_param.mode, - 'response_type': query_param.response_type, - 'top_k': query_param.top_k, - 'chunk_top_k': query_param.chunk_top_k, - 'max_entity_tokens': query_param.max_entity_tokens, - 'max_relation_tokens': query_param.max_relation_tokens, - 'max_total_tokens': query_param.max_total_tokens, - 'hl_keywords': hl_keywords_str, - 'll_keywords': ll_keywords_str, - 'user_prompt': query_param.user_prompt or '', - 'enable_rerank': query_param.enable_rerank, + "mode": query_param.mode, + "response_type": query_param.response_type, + "top_k": query_param.top_k, + "chunk_top_k": query_param.chunk_top_k, + "max_entity_tokens": query_param.max_entity_tokens, + "max_relation_tokens": query_param.max_relation_tokens, + "max_total_tokens": query_param.max_total_tokens, + "hl_keywords": hl_keywords_str, + "ll_keywords": ll_keywords_str, + "user_prompt": query_param.user_prompt or "", + "enable_rerank": query_param.enable_rerank, } await save_to_cache( hashing_kv, @@ -3726,7 +3193,7 @@ async def kg_query( content=response, prompt=query, mode=query_param.mode, - cache_type='query', + cache_type="query", queryparam=queryparam_dict, ), ) @@ -3736,12 +3203,12 @@ async def kg_query( # Non-streaming response (string) if len(response) > len(sys_prompt): response = ( - response.replace(sys_prompt, '') - .replace('user', '') - .replace('model', '') - .replace(query, '') - .replace('', '') - .replace('', 
'') + response.replace(sys_prompt, "") + .replace("user", "") + .replace("model", "") + .replace(query, "") + .replace("", "") + .replace("", "") .strip() ) @@ -3758,7 +3225,7 @@ async def kg_query( async def get_keywords_from_query( query: str, query_param: QueryParam, - global_config: dict[str, Any], + global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, ) -> tuple[list[str], list[str]]: """ @@ -3781,14 +3248,16 @@ async def get_keywords_from_query( return query_param.hl_keywords, query_param.ll_keywords # Extract keywords using extract_keywords_only function which already supports conversation history - hl_keywords, ll_keywords = await extract_keywords_only(query, query_param, global_config, hashing_kv) + hl_keywords, ll_keywords = await extract_keywords_only( + query, query_param, global_config, hashing_kv + ) return hl_keywords, ll_keywords async def extract_keywords_only( text: str, param: QueryParam, - global_config: dict[str, Any], + global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, ) -> tuple[list[str], list[str]]: """ @@ -3797,89 +3266,88 @@ async def extract_keywords_only( It ONLY extracts keywords (hl_keywords, ll_keywords). """ - # 1. Handle cache if needed - add cache type for keywords + # 1. Build the examples + examples = "\n".join(PROMPTS["keywords_extraction_examples"]) + + language = global_config["addon_params"].get("language", DEFAULT_SUMMARY_LANGUAGE) + + # 2. Handle cache if needed - add cache type for keywords args_hash = compute_args_hash( param.mode, text, + language, + ) + cached_result = await handle_cache( + hashing_kv, args_hash, text, param.mode, cache_type="keywords" ) - cached_result = await handle_cache(hashing_kv, args_hash, text, param.mode, cache_type='keywords') if cached_result is not None: cached_response, _ = cached_result # Extract content, ignore timestamp try: keywords_data = json_repair.loads(cached_response) - if isinstance(keywords_data, dict): - return keywords_data.get('high_level_keywords', []), keywords_data.get('low_level_keywords', []) + return keywords_data.get("high_level_keywords", []), keywords_data.get( + "low_level_keywords", [] + ) except (json.JSONDecodeError, KeyError): - logger.warning('Invalid cache format for keywords, proceeding with extraction') - - # 2. Build the examples - examples = '\n'.join(PROMPTS['keywords_extraction_examples']) - - language = global_config['addon_params'].get('language', DEFAULT_SUMMARY_LANGUAGE) + logger.warning( + "Invalid cache format for keywords, proceeding with extraction" + ) # 3. Build the keyword-extraction prompt - kw_prompt = PROMPTS['keywords_extraction'].format( + kw_prompt = PROMPTS["keywords_extraction"].format( query=text, examples=examples, language=language, ) - tokenizer: Tokenizer = global_config['tokenizer'] + tokenizer: Tokenizer = global_config["tokenizer"] len_of_prompts = len(tokenizer.encode(kw_prompt)) - logger.debug(f'[extract_keywords] Sending to LLM: {len_of_prompts:,} tokens (Prompt: {len_of_prompts})') + logger.debug( + f"[extract_keywords] Sending to LLM: {len_of_prompts:,} tokens (Prompt: {len_of_prompts})" + ) # 4. 
Call the LLM for keyword extraction if param.model_func: use_model_func = param.model_func else: - use_model_func = global_config['llm_model_func'] + use_model_func = global_config["llm_model_func"] # Apply higher priority (5) to query relation LLM function use_model_func = partial(use_model_func, _priority=5) - llm_callable = cast(Callable[..., Awaitable[str]], use_model_func) - result = await llm_callable(kw_prompt, keyword_extraction=True) + result = await use_model_func(kw_prompt, keyword_extraction=True) # 5. Parse out JSON from the LLM response result = remove_think_tags(result) try: keywords_data = json_repair.loads(result) - if not keywords_data or not isinstance(keywords_data, dict): - logger.error('No JSON-like structure found in the LLM respond.') + if not keywords_data: + logger.error("No JSON-like structure found in the LLM response.") return [], [] except json.JSONDecodeError as e: - logger.error(f'JSON parsing error: {e}') - logger.error(f'LLM respond: {result}') + logger.error(f"JSON parsing error: {e}") + logger.error(f"LLM response: {result}") return [], [] - hl_keywords = keywords_data.get('high_level_keywords', []) - ll_keywords = keywords_data.get('low_level_keywords', []) - - # 5b. Extract years from query that LLM might have missed (defensive heuristic) - # This applies to ANY domain - financial reports, historical events, etc. - years_in_query = re.findall(r'\b(19|20)\d{2}\b', text) - for year in years_in_query: - if year not in ll_keywords: - ll_keywords.append(year) - logger.debug(f'Added year "{year}" to low_level_keywords from query') + hl_keywords = keywords_data.get("high_level_keywords", []) + ll_keywords = keywords_data.get("low_level_keywords", []) # 6. Cache only the processed keywords with cache type if hl_keywords or ll_keywords: cache_data = { - 'high_level_keywords': hl_keywords, - 'low_level_keywords': ll_keywords, + "high_level_keywords": hl_keywords, + "low_level_keywords": ll_keywords, } - if hashing_kv and hashing_kv.global_config.get('enable_llm_cache'): + if hashing_kv.global_config.get("enable_llm_cache"): # Save to cache with query parameters queryparam_dict = { - 'mode': param.mode, - 'response_type': param.response_type, - 'top_k': param.top_k, - 'chunk_top_k': param.chunk_top_k, - 'max_entity_tokens': param.max_entity_tokens, - 'max_relation_tokens': param.max_relation_tokens, - 'max_total_tokens': param.max_total_tokens, - 'user_prompt': param.user_prompt or '', - 'enable_rerank': param.enable_rerank, + "mode": param.mode, + "response_type": param.response_type, + "top_k": param.top_k, + "chunk_top_k": param.chunk_top_k, + "max_entity_tokens": param.max_entity_tokens, + "max_relation_tokens": param.max_relation_tokens, + "max_total_tokens": param.max_total_tokens, + "user_prompt": param.user_prompt or "", + "enable_rerank": param.enable_rerank, } await save_to_cache( hashing_kv, @@ -3888,7 +3356,7 @@ async def extract_keywords_only( content=json.dumps(cache_data), prompt=text, mode=param.mode, - cache_type='keywords', + cache_type="keywords", queryparam=queryparam_dict, ), ) @@ -3900,8 +3368,7 @@ async def _get_vector_context( query: str, chunks_vdb: BaseVectorStorage, query_param: QueryParam, - query_embedding: list[float] | None = None, - entity_keywords: list[str] | None = None, + query_embedding: list[float] = None, ) -> list[dict]: """ Retrieve text chunks from the vector database without reranking or truncation.
@@ -3909,88 +3376,48 @@ async def _get_vector_context( This function performs vector search to find relevant text chunks for a query. Reranking and truncation will be handled later in the unified processing. - When reranking is enabled, retrieves more candidates (controlled by RETRIEVAL_MULTIPLIER) - to allow the reranker to surface hidden relevant chunks from beyond the initial top-k. - Args: query: The query string to search for chunks_vdb: Vector database containing document chunks query_param: Query parameters including chunk_top_k and ids query_embedding: Optional pre-computed query embedding to avoid redundant embedding calls - entity_keywords: Optional list of entity names from query for entity-aware boosting Returns: List of text chunks with metadata """ try: - base_top_k = query_param.chunk_top_k or query_param.top_k + # Use chunk_top_k if specified, otherwise fall back to top_k + search_top_k = query_param.chunk_top_k or query_param.top_k cosine_threshold = chunks_vdb.cosine_better_than_threshold - # Two-stage retrieval: when reranking is enabled, retrieve more candidates - # so the reranker can surface hidden relevant chunks from beyond base_top_k - if query_param.enable_rerank: - search_top_k = base_top_k * DEFAULT_RETRIEVAL_MULTIPLIER - logger.debug( - f'Two-stage retrieval: {search_top_k} candidates (base={base_top_k}, x{DEFAULT_RETRIEVAL_MULTIPLIER})' - ) - else: - search_top_k = base_top_k - - # Three-stage retrieval: - # Stage 1: BM25+vector fusion (if available) - # Stage 1.5: Entity-aware boosting (if entity_keywords provided) - # Stage 2: Reranking (happens later in the pipeline) - - if hasattr(chunks_vdb, 'hybrid_search_with_entity_boost') and entity_keywords: - # Use entity-boosted hybrid search when entity keywords are available - results = await chunks_vdb.hybrid_search_with_entity_boost( - query, - top_k=search_top_k, - entity_keywords=entity_keywords, - query_embedding=query_embedding, - ) - search_method = 'hybrid+entity_boost' - elif hasattr(chunks_vdb, 'hybrid_search'): - results = await chunks_vdb.hybrid_search( - query, top_k=search_top_k, query_embedding=query_embedding - ) - search_method = 'hybrid (BM25+vector)' - else: - results = await chunks_vdb.query( - query, top_k=search_top_k, query_embedding=query_embedding - ) - search_method = 'vector' - + results = await chunks_vdb.query( + query, top_k=search_top_k, query_embedding=query_embedding + ) if not results: - logger.info(f'Naive query: 0 chunks (chunk_top_k:{search_top_k} cosine:{cosine_threshold})') + logger.info( + f"Naive query: 0 chunks (chunk_top_k:{search_top_k} cosine:{cosine_threshold})" + ) return [] valid_chunks = [] - boosted_count = 0 for result in results: - if 'content' in result: + if "content" in result: chunk_with_metadata = { - 'content': result['content'], - 'created_at': result.get('created_at', None), - 'file_path': result.get('file_path', 'unknown_source'), - 'source_type': 'vector', # Mark the source type - 'chunk_id': result.get('id'), # Add chunk_id for deduplication + "content": result["content"], + "created_at": result.get("created_at", None), + "file_path": result.get("file_path", "unknown_source"), + "source_type": "vector", # Mark the source type + "chunk_id": result.get("id"), # Add chunk_id for deduplication } - if result.get('entity_boosted'): - chunk_with_metadata['entity_boosted'] = True - boosted_count += 1 valid_chunks.append(chunk_with_metadata) - boost_info = f', {boosted_count} entity-boosted' if boosted_count > 0 else '' - two_stage_info = f' [two-stage: 
{search_top_k}→{base_top_k}]' if query_param.enable_rerank else '' logger.info( - f'Vector retrieval: {len(valid_chunks)} chunks via {search_method} ' - f'(top_k:{search_top_k}{boost_info}{two_stage_info})' + f"Naive query: {len(valid_chunks)} chunks (chunk_top_k:{search_top_k} cosine:{cosine_threshold})" ) return valid_chunks except Exception as e: - logger.error(f'Error in _get_vector_context: {e}') + logger.error(f"Error in _get_vector_context: {e}") return [] @@ -4003,7 +3430,7 @@ async def _perform_kg_search( relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, - chunks_vdb: BaseVectorStorage | None = None, + chunks_vdb: BaseVectorStorage = None, ) -> dict[str, Any]: """ Pure search logic that retrieves raw entities, relations, and vector chunks. @@ -4023,18 +3450,26 @@ async def _perform_kg_search( # Track chunk sources and metadata for final logging chunk_tracking = {} # chunk_id -> {source, frequency, order} - # Pre-compute query embedding once for all vector operations (with caching) - kg_chunk_pick_method = text_chunks_db.global_config.get('kg_chunk_pick_method', DEFAULT_KG_CHUNK_PICK_METHOD) + # Pre-compute query embedding once for all vector operations + kg_chunk_pick_method = text_chunks_db.global_config.get( + "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD + ) query_embedding = None - if query and (kg_chunk_pick_method == 'VECTOR' or chunks_vdb): + if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb): actual_embedding_func = text_chunks_db.embedding_func if actual_embedding_func: - query_embedding = await get_cached_query_embedding(query, actual_embedding_func) - if query_embedding is not None: - logger.debug('Pre-computed query embedding for all vector operations') + try: + query_embedding = await actual_embedding_func([query]) + query_embedding = query_embedding[ + 0 + ] # Extract first embedding from batch result + logger.debug("Pre-computed query embedding for all vector operations") + except Exception as e: + logger.warning(f"Failed to pre-compute query embedding: {e}") + query_embedding = None # Handle local and global modes - if query_param.mode == 'local' and len(ll_keywords) > 0: + if query_param.mode == "local" and len(ll_keywords) > 0: local_entities, local_relations = await _get_node_data( ll_keywords, knowledge_graph_inst, @@ -4042,7 +3477,7 @@ async def _perform_kg_search( query_param, ) - elif query_param.mode == 'global' and len(hl_keywords) > 0: + elif query_param.mode == "global" and len(hl_keywords) > 0: global_relations, global_entities = await _get_edge_data( hl_keywords, knowledge_graph_inst, @@ -4067,33 +3502,24 @@ async def _perform_kg_search( ) # Get vector chunks for mix mode - if query_param.mode == 'mix' and chunks_vdb: - logger.info(f'[MIX DEBUG] Starting vector search for query: {query[:60]}...') - # Parse ll_keywords into list for entity-aware boosting - entity_keywords_list = [kw.strip() for kw in ll_keywords.split(',') if kw.strip()] if ll_keywords else None + if query_param.mode == "mix" and chunks_vdb: vector_chunks = await _get_vector_context( query, chunks_vdb, query_param, query_embedding, - entity_keywords=entity_keywords_list, ) - logger.info(f'[MIX DEBUG] Vector search returned {len(vector_chunks)} chunks') - if not vector_chunks: - logger.warning( - f'[MIX DEBUG] ⚠️ NO VECTOR CHUNKS! 
chunk_top_k={query_param.chunk_top_k}, top_k={query_param.top_k}' - ) # Track vector chunks with source metadata for i, chunk in enumerate(vector_chunks): - chunk_id = chunk.get('chunk_id') or chunk.get('id') + chunk_id = chunk.get("chunk_id") or chunk.get("id") if chunk_id: chunk_tracking[chunk_id] = { - 'source': 'C', - 'frequency': 1, # Vector chunks always have frequency 1 - 'order': i + 1, # 1-based order in vector search results + "source": "C", + "frequency": 1, # Vector chunks always have frequency 1 + "order": i + 1, # 1-based order in vector search results } else: - logger.warning(f'Vector chunk missing chunk_id: {chunk}') + logger.warning(f"Vector chunk missing chunk_id: {chunk}") # Round-robin merge entities final_entities = [] @@ -4103,7 +3529,7 @@ async def _perform_kg_search( # First from local if i < len(local_entities): entity = local_entities[i] - entity_name = entity.get('entity_name') + entity_name = entity.get("entity_name") if entity_name and entity_name not in seen_entities: final_entities.append(entity) seen_entities.add(entity_name) @@ -4111,7 +3537,7 @@ async def _perform_kg_search( # Then from global if i < len(global_entities): entity = global_entities[i] - entity_name = entity.get('entity_name') + entity_name = entity.get("entity_name") if entity_name and entity_name not in seen_entities: final_entities.append(entity) seen_entities.add(entity_name) @@ -4125,10 +3551,12 @@ async def _perform_kg_search( if i < len(local_relations): relation = local_relations[i] # Build relation unique identifier - if 'src_tgt' in relation: - rel_key = tuple(sorted(relation['src_tgt'])) + if "src_tgt" in relation: + rel_key = tuple(sorted(relation["src_tgt"])) else: - rel_key = tuple(sorted([relation.get('src_id'), relation.get('tgt_id')])) + rel_key = tuple( + sorted([relation.get("src_id"), relation.get("tgt_id")]) + ) if rel_key not in seen_relations: final_relations.append(relation) @@ -4138,211 +3566,64 @@ async def _perform_kg_search( if i < len(global_relations): relation = global_relations[i] # Build relation unique identifier - if 'src_tgt' in relation: - rel_key = tuple(sorted(relation['src_tgt'])) + if "src_tgt" in relation: + rel_key = tuple(sorted(relation["src_tgt"])) else: - rel_key = tuple(sorted([relation.get('src_id'), relation.get('tgt_id')])) + rel_key = tuple( + sorted([relation.get("src_id"), relation.get("tgt_id")]) + ) if rel_key not in seen_relations: final_relations.append(relation) seen_relations.add(rel_key) logger.info( - f'Raw search results: {len(final_entities)} entities, {len(final_relations)} relations, {len(vector_chunks)} vector chunks' + f"Raw search results: {len(final_entities)} entities, {len(final_relations)} relations, {len(vector_chunks)} vector chunks" ) return { - 'final_entities': final_entities, - 'final_relations': final_relations, - 'vector_chunks': vector_chunks, - 'chunk_tracking': chunk_tracking, - 'query_embedding': query_embedding, + "final_entities": final_entities, + "final_relations": final_relations, + "vector_chunks": vector_chunks, + "chunk_tracking": chunk_tracking, + "query_embedding": query_embedding, } -def _check_topic_connectivity( - entities_context: list[dict], - relations_context: list[dict], - min_relationship_density: float = 0.3, - min_entity_coverage: float = 0.5, - min_doc_coverage: float = 0.15, # Lower threshold for document connectivity -) -> tuple[bool, str]: - """ - Check if retrieved entities are topically connected. - - Connectivity is established if EITHER: - 1. 
Entities come from the same source document (file_path) - uses min_doc_coverage threshold - 2. Entities are connected via graph relationships (BFS) - uses min_entity_coverage threshold - - The document coverage threshold (min_doc_coverage) is lower than graph coverage because - vector search often pulls semantically similar entities from multiple documents. - Even 15% from a single document indicates strong topical relevance. - - Args: - entities_context: List of entity dictionaries from retrieval - relations_context: List of relationship dictionaries from retrieval - min_relationship_density: Minimum ratio of relationships to entities (0.0-1.0) - min_entity_coverage: Minimum ratio of entities in largest graph component (0.0-1.0) - min_doc_coverage: Minimum ratio of entities from a single document (0.0-1.0) - - Returns: - (is_connected, reason) where: - - is_connected: True if topics form a connected graph or share document - - reason: Explanation if connectivity check failed - """ - if not entities_context: - return True, '' # No entities = let existing empty-context logic handle it - - # Build set of entity names and track file_path for each (case-insensitive) - entity_names = set() - entity_files: dict[str, str] = {} # entity_name -> primary file_path - - for e in entities_context: - name = e.get('entity', e.get('entity_name', '')).lower() - if not name: - continue - entity_names.add(name) - - # Extract primary file_path (before if multiple) - file_path = e.get('file_path', '') - if file_path: - primary_file = file_path.split('')[0].strip() - if primary_file: - entity_files[name] = primary_file - - if not entity_names: - return True, '' # Can't check without entity names - - # Check 1: Document connectivity - # If entities from the same document meet the coverage threshold, pass - file_to_entities: dict[str, set[str]] = {} - for name, file_path in entity_files.items(): - file_to_entities.setdefault(file_path, set()).add(name) - - # Debug: Log document distribution - if file_to_entities: - doc_summary = ', '.join( - f'"{fp}": {len(ents)}' for fp, ents in sorted(file_to_entities.items(), key=lambda x: -len(x[1]))[:5] - ) - logger.info( - f'Topic connectivity: Document distribution (top 5): [{doc_summary}] total_entities={len(entity_names)}' - ) - else: - logger.info(f'Topic connectivity: No file_path data available for {len(entity_names)} entities') - - for file_path, file_entities in file_to_entities.items(): - doc_coverage = len(file_entities) / len(entity_names) - if doc_coverage >= min_doc_coverage: - logger.info( - f'Topic connectivity: PASS (document) - {len(file_entities)}/{len(entity_names)} ' - f'entities from "{file_path}" ({doc_coverage:.0%} coverage, threshold={min_doc_coverage:.0%})' - ) - return True, '' - - # Check 2: Graph connectivity (original logic) - if not relations_context: - # Log document distribution for debugging - doc_summary = ', '.join(f'{fp}: {len(ents)}' for fp, ents in file_to_entities.items()) - logger.info(f'Topic connectivity: FAIL - no relations, docs=[{doc_summary}]') - return False, 'Retrieved entities have no connecting relationships or common document' - - # Build adjacency list from relationships - # Only include edges where BOTH endpoints are in our entity set - adjacency: dict[str, set[str]] = {name: set() for name in entity_names} - edges_in_subgraph = 0 - - for rel in relations_context: - src = (rel.get('src_id') or rel.get('source') or rel.get('entity1') or '').lower() - tgt = (rel.get('tgt_id') or rel.get('target') or rel.get('entity2') or 
'').lower() - - # Only count edges where BOTH endpoints are in our retrieved entities - if src in entity_names and tgt in entity_names and src != tgt: - adjacency[src].add(tgt) - adjacency[tgt].add(src) - edges_in_subgraph += 1 - - # Check if we have enough edges connecting entities to each other - if edges_in_subgraph == 0: - doc_summary = ', '.join(f'{fp}: {len(ents)}' for fp, ents in file_to_entities.items()) - logger.info( - f'Topic connectivity: FAIL - {len(relations_context)} relations but 0 connect ' - f'retrieved entities to each other, docs=[{doc_summary}]' - ) - return False, 'Retrieved entities are not connected to each other by relationships or common document' - - # Use BFS to find connected components - visited = set() - components = [] - - for start_node in entity_names: - if start_node in visited: - continue - - # BFS from this node - component = set() - queue = [start_node] - while queue: - node = queue.pop(0) - if node in visited: - continue - visited.add(node) - component.add(node) - for neighbor in adjacency[node]: - if neighbor not in visited: - queue.append(neighbor) - - components.append(component) - - # Check component coverage - largest_component = max(components, key=len) if components else set() - coverage = len(largest_component) / len(entity_names) if entity_names else 0 - - logger.info( - f'Topic connectivity: {len(entity_names)} entities, {edges_in_subgraph} connecting edges, ' - f'{len(components)} components, largest={len(largest_component)} ({coverage:.0%} coverage)' - ) - - if coverage < min_entity_coverage: - logger.info(f'Topic connectivity: FAIL - graph coverage {coverage:.0%} < threshold {min_entity_coverage:.0%}') - return False, f'Entities form {len(components)} disconnected clusters (largest covers {coverage:.0%})' - - return True, '' - - async def _apply_token_truncation( search_result: dict[str, Any], query_param: QueryParam, - global_config: dict[str, Any], + global_config: dict[str, str], ) -> dict[str, Any]: """ Apply token-based truncation to entities and relations for LLM efficiency. 
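
    Illustrative sketch (not part of the library API): a simplified,
    standalone approximation of the token-budget truncation applied below,
    using a whitespace split as a stand-in tokenizer.

        def truncate_by_token_budget(items, encode, max_tokens):
            # Keep items in order until the cumulative token count would
            # exceed the budget, then stop.
            kept, used = [], 0
            for item in items:
                cost = len(encode(item))
                if used + cost > max_tokens:
                    break
                kept.append(item)
                used += cost
            return kept

        descriptions = ["Entity A builds X.", "Entity B funds A.", "Entity C rivals B."]
        print(truncate_by_token_budget(descriptions, str.split, max_tokens=8))
        # -> ['Entity A builds X.', 'Entity B funds A.']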
""" - tokenizer = global_config.get('tokenizer') + tokenizer = global_config.get("tokenizer") if not tokenizer: - logger.warning('No tokenizer found, skipping truncation') + logger.warning("No tokenizer found, skipping truncation") return { - 'entities_context': [], - 'relations_context': [], - 'filtered_entities': search_result['final_entities'], - 'filtered_relations': search_result['final_relations'], - 'entity_id_to_original': {}, - 'relation_id_to_original': {}, + "entities_context": [], + "relations_context": [], + "filtered_entities": search_result["final_entities"], + "filtered_relations": search_result["final_relations"], + "entity_id_to_original": {}, + "relation_id_to_original": {}, } # Get token limits from query_param with fallbacks max_entity_tokens = getattr( query_param, - 'max_entity_tokens', - global_config.get('max_entity_tokens', DEFAULT_MAX_ENTITY_TOKENS), + "max_entity_tokens", + global_config.get("max_entity_tokens", DEFAULT_MAX_ENTITY_TOKENS), ) max_relation_tokens = getattr( query_param, - 'max_relation_tokens', - global_config.get('max_relation_tokens', DEFAULT_MAX_RELATION_TOKENS), + "max_relation_tokens", + global_config.get("max_relation_tokens", DEFAULT_MAX_RELATION_TOKENS), ) - final_entities = search_result['final_entities'] - final_relations = search_result['final_relations'] + final_entities = search_result["final_entities"] + final_relations = search_result["final_relations"] # Create mappings from entity/relation identifiers to original data entity_id_to_original = {} @@ -4350,37 +3631,37 @@ async def _apply_token_truncation( # Generate entities context for truncation entities_context = [] - for _i, entity in enumerate(final_entities): - entity_name = entity['entity_name'] - created_at = entity.get('created_at', 'UNKNOWN') + for i, entity in enumerate(final_entities): + entity_name = entity["entity_name"] + created_at = entity.get("created_at", "UNKNOWN") if isinstance(created_at, (int, float)): - created_at = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(created_at)) + created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) # Store mapping from entity name to original data entity_id_to_original[entity_name] = entity entities_context.append( { - 'entity': entity_name, - 'type': entity.get('entity_type', 'UNKNOWN'), - 'description': entity.get('description', 'UNKNOWN'), - 'created_at': created_at, - 'file_path': entity.get('file_path', 'unknown_source'), + "entity": entity_name, + "type": entity.get("entity_type", "UNKNOWN"), + "description": entity.get("description", "UNKNOWN"), + "created_at": created_at, + "file_path": entity.get("file_path", "unknown_source"), } ) # Generate relations context for truncation relations_context = [] - for _i, relation in enumerate(final_relations): - created_at = relation.get('created_at', 'UNKNOWN') + for i, relation in enumerate(final_relations): + created_at = relation.get("created_at", "UNKNOWN") if isinstance(created_at, (int, float)): - created_at = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(created_at)) + created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) # Handle different relation data formats - if 'src_tgt' in relation: - entity1, entity2 = relation['src_tgt'] + if "src_tgt" in relation: + entity1, entity2 = relation["src_tgt"] else: - entity1, entity2 = relation.get('src_id'), relation.get('tgt_id') + entity1, entity2 = relation.get("src_id"), relation.get("tgt_id") # Store mapping from relation pair to original data relation_key = (entity1, entity2) @@ 
-4388,103 +3669,67 @@ async def _apply_token_truncation( relations_context.append( { - 'entity1': entity1, - 'entity2': entity2, - 'description': relation.get('description', 'UNKNOWN'), - 'created_at': created_at, - 'file_path': relation.get('file_path', 'unknown_source'), + "entity1": entity1, + "entity2": entity2, + "description": relation.get("description", "UNKNOWN"), + "created_at": created_at, + "file_path": relation.get("file_path", "unknown_source"), } ) - original_entity_count = len(entities_context) - original_relation_count = len(relations_context) + logger.debug( + f"Before truncation: {len(entities_context)} entities, {len(relations_context)} relations" + ) - # Check if truncation is disabled - disable_truncation = getattr(query_param, 'disable_truncation', False) - - # Track truncated counts for logging - truncated_entity_count = original_entity_count - truncated_relation_count = original_relation_count - - # Apply token-based truncation (safety ceiling, not aggressive limiting) - # Top-k already limits retrieval count; this is a fallback for very long descriptions + # Apply token-based truncation if entities_context: - # Remove file_path and created_at for token calculation (metadata not sent to LLM) + # Remove file_path and created_at for token calculation entities_context_for_truncation = [] for entity in entities_context: entity_copy = entity.copy() - entity_copy.pop('file_path', None) - entity_copy.pop('created_at', None) + entity_copy.pop("file_path", None) + entity_copy.pop("created_at", None) entities_context_for_truncation.append(entity_copy) - truncated_entities = truncate_list_by_token_size( + entities_context = truncate_list_by_token_size( entities_context_for_truncation, - key=lambda x: '\n'.join(json.dumps(item, ensure_ascii=False) for item in [x]), + key=lambda x: "\n".join( + json.dumps(item, ensure_ascii=False) for item in [x] + ), max_token_size=max_entity_tokens, tokenizer=tokenizer, ) - # Restore file_path from original entities_context (needed for connectivity check) - entity_name_to_file_path = {e['entity']: e.get('file_path', '') for e in entities_context} - for entity in truncated_entities: - entity['file_path'] = entity_name_to_file_path.get(entity['entity'], '') - - truncated_entity_count = len(truncated_entities) - if not disable_truncation: - entities_context = truncated_entities if relations_context: # Remove file_path and created_at for token calculation relations_context_for_truncation = [] for relation in relations_context: relation_copy = relation.copy() - relation_copy.pop('file_path', None) - relation_copy.pop('created_at', None) + relation_copy.pop("file_path", None) + relation_copy.pop("created_at", None) relations_context_for_truncation.append(relation_copy) - truncated_relations = truncate_list_by_token_size( + relations_context = truncate_list_by_token_size( relations_context_for_truncation, - key=lambda x: '\n'.join(json.dumps(item, ensure_ascii=False) for item in [x]), + key=lambda x: "\n".join( + json.dumps(item, ensure_ascii=False) for item in [x] + ), max_token_size=max_relation_tokens, tokenizer=tokenizer, ) - # Restore file_path from original relations_context (for consistency) - relation_key_to_file_path = {(r['entity1'], r['entity2']): r.get('file_path', '') for r in relations_context} - for relation in truncated_relations: - relation['file_path'] = relation_key_to_file_path.get((relation['entity1'], relation['entity2']), '') - truncated_relation_count = len(truncated_relations) - if not disable_truncation: - relations_context 
= truncated_relations - - # Calculate how many would be dropped - entities_would_drop = original_entity_count - truncated_entity_count - relations_would_drop = original_relation_count - truncated_relation_count - - # Log if token ceiling was/would be hit - if entities_would_drop > 0 or relations_would_drop > 0: - if disable_truncation: - logger.warning( - f'Token ceiling exceeded (truncation disabled): would have dropped ' - f'{entities_would_drop} entities, {relations_would_drop} relations. ' - f'Keeping all {original_entity_count} entities, {original_relation_count} relations.' - ) - else: - logger.warning( - f'Token ceiling hit: dropped {entities_would_drop} entities (limit={max_entity_tokens} tokens), ' - f'{relations_would_drop} relations (limit={max_relation_tokens} tokens). ' - f'Increase token limits or set disable_truncation=True.' - ) - - logger.info(f'Context: {len(entities_context)} entities, {len(relations_context)} relations') + logger.info( + f"After truncation: {len(entities_context)} entities, {len(relations_context)} relations" + ) # Create filtered original data based on truncated context filtered_entities = [] filtered_entity_id_to_original = {} if entities_context: - final_entity_names = {e['entity'] for e in entities_context} + final_entity_names = {e["entity"] for e in entities_context} seen_nodes = set() for entity in final_entities: - name = entity.get('entity_name') + name = entity.get("entity_name") if name in final_entity_names and name not in seen_nodes: filtered_entities.append(entity) filtered_entity_id_to_original[name] = entity @@ -4493,12 +3738,12 @@ async def _apply_token_truncation( filtered_relations = [] filtered_relation_id_to_original = {} if relations_context: - final_relation_pairs = {(r['entity1'], r['entity2']) for r in relations_context} + final_relation_pairs = {(r["entity1"], r["entity2"]) for r in relations_context} seen_edges = set() for relation in final_relations: - src, tgt = relation.get('src_id'), relation.get('tgt_id') + src, tgt = relation.get("src_id"), relation.get("tgt_id") if src is None or tgt is None: - src, tgt = relation.get('src_tgt', (None, None)) + src, tgt = relation.get("src_tgt", (None, None)) pair = (src, tgt) if pair in final_relation_pairs and pair not in seen_edges: @@ -4507,12 +3752,12 @@ async def _apply_token_truncation( seen_edges.add(pair) return { - 'entities_context': entities_context, - 'relations_context': relations_context, - 'filtered_entities': filtered_entities, - 'filtered_relations': filtered_relations, - 'entity_id_to_original': filtered_entity_id_to_original, - 'relation_id_to_original': filtered_relation_id_to_original, + "entities_context": entities_context, + "relations_context": relations_context, + "filtered_entities": filtered_entities, + "filtered_relations": filtered_relations, + "entity_id_to_original": filtered_entity_id_to_original, + "relation_id_to_original": filtered_relation_id_to_original, } @@ -4520,21 +3765,17 @@ async def _merge_all_chunks( filtered_entities: list[dict], filtered_relations: list[dict], vector_chunks: list[dict], - query: str = '', - knowledge_graph_inst: BaseGraphStorage | None = None, - text_chunks_db: BaseKVStorage | None = None, - query_param: QueryParam | None = None, - chunks_vdb: BaseVectorStorage | None = None, - chunk_tracking: dict | None = None, - query_embedding: list[float] | None = None, + query: str = "", + knowledge_graph_inst: BaseGraphStorage = None, + text_chunks_db: BaseKVStorage = None, + query_param: QueryParam = None, + chunks_vdb: 
BaseVectorStorage = None, + chunk_tracking: dict = None, + query_embedding: list[float] = None, ) -> list[dict]: """ Merge chunks from different sources: vector_chunks + entity_chunks + relation_chunks. """ - if query_param is None: - raise ValueError('query_param is required for merging chunks') - if knowledge_graph_inst is None or chunks_vdb is None: - raise ValueError('knowledge_graph_inst and chunks_vdb are required for chunk merging') if chunk_tracking is None: chunk_tracking = {} @@ -4576,47 +3817,47 @@ async def _merge_all_chunks( # Add from vector chunks first (Naive mode) if i < len(vector_chunks): chunk = vector_chunks[i] - chunk_id = chunk.get('chunk_id') or chunk.get('id') + chunk_id = chunk.get("chunk_id") or chunk.get("id") if chunk_id and chunk_id not in seen_chunk_ids: seen_chunk_ids.add(chunk_id) merged_chunks.append( { - 'content': chunk['content'], - 'file_path': chunk.get('file_path', 'unknown_source'), - 'chunk_id': chunk_id, + "content": chunk["content"], + "file_path": chunk.get("file_path", "unknown_source"), + "chunk_id": chunk_id, } ) # Add from entity chunks (Local mode) if i < len(entity_chunks): chunk = entity_chunks[i] - chunk_id = chunk.get('chunk_id') or chunk.get('id') + chunk_id = chunk.get("chunk_id") or chunk.get("id") if chunk_id and chunk_id not in seen_chunk_ids: seen_chunk_ids.add(chunk_id) merged_chunks.append( { - 'content': chunk['content'], - 'file_path': chunk.get('file_path', 'unknown_source'), - 'chunk_id': chunk_id, + "content": chunk["content"], + "file_path": chunk.get("file_path", "unknown_source"), + "chunk_id": chunk_id, } ) # Add from relation chunks (Global mode) if i < len(relation_chunks): chunk = relation_chunks[i] - chunk_id = chunk.get('chunk_id') or chunk.get('id') + chunk_id = chunk.get("chunk_id") or chunk.get("id") if chunk_id and chunk_id not in seen_chunk_ids: seen_chunk_ids.add(chunk_id) merged_chunks.append( { - 'content': chunk['content'], - 'file_path': chunk.get('file_path', 'unknown_source'), - 'chunk_id': chunk_id, + "content": chunk["content"], + "file_path": chunk.get("file_path", "unknown_source"), + "chunk_id": chunk_id, } ) logger.info( - f'Round-robin merged chunks: {origin_len} -> {len(merged_chunks)} (deduplicated {origin_len - len(merged_chunks)})' + f"Round-robin merged chunks: {origin_len} -> {len(merged_chunks)} (deduplicated {origin_len - len(merged_chunks)})" ) return merged_chunks @@ -4628,18 +3869,18 @@ async def _build_context_str( merged_chunks: list[dict], query: str, query_param: QueryParam, - global_config: dict[str, Any], - chunk_tracking: dict | None = None, - entity_id_to_original: dict | None = None, - relation_id_to_original: dict | None = None, + global_config: dict[str, str], + chunk_tracking: dict = None, + entity_id_to_original: dict = None, + relation_id_to_original: dict = None, ) -> tuple[str, dict[str, Any]]: """ Build the final LLM context string with token processing. This includes dynamic token calculation and final chunk truncation. """ - tokenizer = global_config.get('tokenizer') + tokenizer = global_config.get("tokenizer") if not tokenizer: - logger.error('Missing tokenizer, cannot build LLM context') + logger.error("Missing tokenizer, cannot build LLM context") # Return empty raw data structure when no tokenizer empty_raw_data = convert_to_user_format( [], @@ -4648,49 +3889,64 @@ async def _build_context_str( [], query_param.mode, ) - empty_raw_data['status'] = 'failure' - empty_raw_data['message'] = 'Missing tokenizer, cannot build LLM context.' 
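# The round-robin merge above interleaves vector, entity and relation chunks
# while deduplicating on chunk_id. A minimal, standalone sketch of the same
# pattern (the sample chunk dicts below are invented for illustration only):
def round_robin_merge(*chunk_lists):
    seen, merged = set(), []
    longest = max((len(lst) for lst in chunk_lists), default=0)
    for i in range(longest):
        for lst in chunk_lists:
            if i < len(lst):
                chunk_id = lst[i].get("chunk_id") or lst[i].get("id")
                if chunk_id and chunk_id not in seen:
                    seen.add(chunk_id)
                    merged.append(lst[i])
    return merged

vector_sample = [{"chunk_id": "c1", "content": "..."}, {"chunk_id": "c2", "content": "..."}]
entity_sample = [{"chunk_id": "c2", "content": "..."}, {"chunk_id": "c3", "content": "..."}]
relation_sample = [{"chunk_id": "c4", "content": "..."}]
print([c["chunk_id"] for c in round_robin_merge(vector_sample, entity_sample, relation_sample)])
# -> ['c1', 'c2', 'c4', 'c3']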
- return '', empty_raw_data + empty_raw_data["status"] = "failure" + empty_raw_data["message"] = "Missing tokenizer, cannot build LLM context." + return "", empty_raw_data # Get token limits max_total_tokens = getattr( query_param, - 'max_total_tokens', - global_config.get('max_total_tokens', DEFAULT_MAX_TOTAL_TOKENS), + "max_total_tokens", + global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS), ) # Get the system prompt template from PROMPTS or global_config - sys_prompt_template = global_config.get('system_prompt_template', PROMPTS['rag_response']) + sys_prompt_template = global_config.get( + "system_prompt_template", PROMPTS["rag_response"] + ) - kg_context_template = PROMPTS['kg_query_context'] - user_prompt = query_param.user_prompt if query_param.user_prompt else '' - response_type = query_param.response_type if query_param.response_type else 'Multiple Paragraphs' + kg_context_template = PROMPTS["kg_query_context"] + user_prompt = query_param.user_prompt if query_param.user_prompt else "" + response_type = ( + query_param.response_type + if query_param.response_type + else "Multiple Paragraphs" + ) - entities_str = '\n'.join(json.dumps(entity, ensure_ascii=False) for entity in entities_context) - relations_str = '\n'.join(json.dumps(relation, ensure_ascii=False) for relation in relations_context) + entities_str = "\n".join( + json.dumps(entity, ensure_ascii=False) for entity in entities_context + ) + relations_str = "\n".join( + json.dumps(relation, ensure_ascii=False) for relation in relations_context + ) # Calculate preliminary kg context tokens pre_kg_context = kg_context_template.format( entities_str=entities_str, relations_str=relations_str, - text_chunks_str='', - reference_list_str='', + text_chunks_str="", + reference_list_str="", ) kg_context_tokens = len(tokenizer.encode(pre_kg_context)) # Calculate preliminary system prompt tokens pre_sys_prompt = sys_prompt_template.format( - context_data='', # Empty for overhead calculation + context_data="", # Empty for overhead calculation response_type=response_type, user_prompt=user_prompt, - coverage_guidance='', # Empty for overhead calculation ) sys_prompt_tokens = len(tokenizer.encode(pre_sys_prompt)) # Calculate available tokens for text chunks query_tokens = len(tokenizer.encode(query)) buffer_tokens = 200 # reserved for reference list and safety buffer - available_chunk_tokens = max_total_tokens - (sys_prompt_tokens + kg_context_tokens + query_tokens + buffer_tokens) + available_chunk_tokens = max_total_tokens - ( + sys_prompt_tokens + kg_context_tokens + query_tokens + buffer_tokens + ) + + logger.debug( + f"Token allocation - Total: {max_total_tokens}, SysPrompt: {sys_prompt_tokens}, Query: {query_tokens}, KG: {kg_context_tokens}, Buffer: {buffer_tokens}, Available for chunks: {available_chunk_tokens}" + ) # Apply token truncation to chunks using the dynamic limit truncated_chunks = await process_chunks_unified( @@ -4703,26 +3959,32 @@ async def _build_context_str( ) # Generate reference list from truncated chunks using the new common function - reference_list, truncated_chunks = generate_reference_list_from_chunks(truncated_chunks) + reference_list, truncated_chunks = generate_reference_list_from_chunks( + truncated_chunks + ) # Rebuild chunks_context with truncated chunks # The actual tokens may be slightly less than available_chunk_tokens due to deduplication logic chunks_context = [] - for _i, chunk in enumerate(truncated_chunks): + for i, chunk in enumerate(truncated_chunks): chunks_context.append( { - 
'reference_id': chunk['reference_id'], - 'content': chunk['content'], + "reference_id": chunk["reference_id"], + "content": chunk["content"], } ) - text_units_str = '\n'.join(json.dumps(text_unit, ensure_ascii=False) for text_unit in chunks_context) - reference_list_str = '\n'.join( - f'[{ref["reference_id"]}] {ref["file_path"]}' for ref in reference_list if ref['reference_id'] + text_units_str = "\n".join( + json.dumps(text_unit, ensure_ascii=False) for text_unit in chunks_context + ) + reference_list_str = "\n".join( + f"[{ref['reference_id']}] {ref['file_path']}" + for ref in reference_list + if ref["reference_id"] ) logger.info( - f'Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(chunks_context)} chunks' + f"Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(chunks_context)} chunks" ) # not necessary to use LLM to generate a response @@ -4735,27 +3997,27 @@ async def _build_context_str( [], query_param.mode, ) - empty_raw_data['status'] = 'failure' - empty_raw_data['message'] = 'Query returned empty dataset.' - return '', empty_raw_data + empty_raw_data["status"] = "failure" + empty_raw_data["message"] = "Query returned empty dataset." + return "", empty_raw_data # output chunks tracking infomations # format: / (e.g., E5/2 R2/1 C1/1) if truncated_chunks and chunk_tracking: chunk_tracking_log = [] for chunk in truncated_chunks: - chunk_id = chunk.get('chunk_id') + chunk_id = chunk.get("chunk_id") if chunk_id and chunk_id in chunk_tracking: tracking_info = chunk_tracking[chunk_id] - source = tracking_info['source'] - frequency = tracking_info['frequency'] - order = tracking_info['order'] - chunk_tracking_log.append(f'{source}{frequency}/{order}') + source = tracking_info["source"] + frequency = tracking_info["frequency"] + order = tracking_info["order"] + chunk_tracking_log.append(f"{source}{frequency}/{order}") else: - chunk_tracking_log.append('?0/0') + chunk_tracking_log.append("?0/0") if chunk_tracking_log: - logger.info(f'Final chunks S+F/O: {" ".join(chunk_tracking_log)}') + logger.info(f"Final chunks S+F/O: {' '.join(chunk_tracking_log)}") result = kg_context_template.format( entities_str=entities_str, @@ -4766,7 +4028,7 @@ async def _build_context_str( # Always return both context and complete data structure (unified approach) logger.debug( - f'[_build_context_str] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks' + f"[_build_context_str] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks" ) final_data = convert_to_user_format( entities_context, @@ -4778,7 +4040,7 @@ async def _build_context_str( relation_id_to_original, ) logger.debug( - f'[_build_context_str] Final data after conversion: {len(final_data.get("entities", []))} entities, {len(final_data.get("relationships", []))} relationships, {len(final_data.get("chunks", []))} chunks' + f"[_build_context_str] Final data after conversion: {len(final_data.get('entities', []))} entities, {len(final_data.get('relationships', []))} relationships, {len(final_data.get('chunks', []))} chunks" ) return result, final_data @@ -4793,7 +4055,7 @@ async def _build_query_context( relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, - chunks_vdb: BaseVectorStorage | None = None, + chunks_vdb: BaseVectorStorage = None, ) -> QueryContextResult | None: """ Main query 
context building function using the new 4-stage architecture: @@ -4803,7 +4065,7 @@ async def _build_query_context( """ if not query: - logger.warning('Query is empty, skipping context building') + logger.warning("Query is empty, skipping context building") return None # Stage 1: Pure search @@ -4819,11 +4081,11 @@ async def _build_query_context( chunks_vdb, ) - if not search_result['final_entities'] and not search_result['final_relations']: - if query_param.mode != 'mix': + if not search_result["final_entities"] and not search_result["final_relations"]: + if query_param.mode != "mix": return None else: - if not search_result['chunk_tracking']: + if not search_result["chunk_tracking"]: return None # Stage 2: Apply token truncation for LLM efficiency @@ -4833,91 +4095,75 @@ async def _build_query_context( text_chunks_db.global_config, ) - # Stage 2.5: Check topic connectivity - # Skip for: naive (no graph), bypass (no retrieval) - # This prevents hallucination when querying unrelated topics like "diabetes + renewable energy" - if query_param.check_topic_connectivity and query_param.mode not in ('naive', 'bypass'): - is_connected, reason = _check_topic_connectivity( - truncation_result['entities_context'], - truncation_result['relations_context'], - min_relationship_density=query_param.min_relationship_density, - min_entity_coverage=query_param.min_entity_coverage, - ) - if not is_connected: - logger.info(f'Topic connectivity check failed: {reason}') - return None # Return None to trigger "no context" response - # Stage 3: Merge chunks using filtered entities/relations merged_chunks = await _merge_all_chunks( - filtered_entities=truncation_result['filtered_entities'], - filtered_relations=truncation_result['filtered_relations'], - vector_chunks=search_result['vector_chunks'], + filtered_entities=truncation_result["filtered_entities"], + filtered_relations=truncation_result["filtered_relations"], + vector_chunks=search_result["vector_chunks"], query=query, knowledge_graph_inst=knowledge_graph_inst, text_chunks_db=text_chunks_db, query_param=query_param, chunks_vdb=chunks_vdb, - chunk_tracking=search_result['chunk_tracking'], - query_embedding=search_result['query_embedding'], + chunk_tracking=search_result["chunk_tracking"], + query_embedding=search_result["query_embedding"], ) - if not merged_chunks and not truncation_result['entities_context'] and not truncation_result['relations_context']: + if ( + not merged_chunks + and not truncation_result["entities_context"] + and not truncation_result["relations_context"] + ): return None # Stage 4: Build final LLM context with dynamic token processing # _build_context_str now always returns tuple[str, dict] context, raw_data = await _build_context_str( - entities_context=truncation_result['entities_context'], - relations_context=truncation_result['relations_context'], + entities_context=truncation_result["entities_context"], + relations_context=truncation_result["relations_context"], merged_chunks=merged_chunks, query=query, query_param=query_param, global_config=text_chunks_db.global_config, - chunk_tracking=search_result['chunk_tracking'], - entity_id_to_original=truncation_result['entity_id_to_original'], - relation_id_to_original=truncation_result['relation_id_to_original'], + chunk_tracking=search_result["chunk_tracking"], + entity_id_to_original=truncation_result["entity_id_to_original"], + relation_id_to_original=truncation_result["relation_id_to_original"], ) # Convert keywords strings to lists and add complete metadata to raw_data - 
hl_keywords_list = hl_keywords.split(', ') if hl_keywords else [] - ll_keywords_list = ll_keywords.split(', ') if ll_keywords else [] + hl_keywords_list = hl_keywords.split(", ") if hl_keywords else [] + ll_keywords_list = ll_keywords.split(", ") if ll_keywords else [] # Add complete metadata to raw_data (preserve existing metadata including query_mode) - if 'metadata' not in raw_data: - raw_data['metadata'] = {} + if "metadata" not in raw_data: + raw_data["metadata"] = {} # Update keywords while preserving existing metadata - raw_data['metadata']['keywords'] = { - 'high_level': hl_keywords_list, - 'low_level': ll_keywords_list, + raw_data["metadata"]["keywords"] = { + "high_level": hl_keywords_list, + "low_level": ll_keywords_list, } - raw_data['metadata']['processing_info'] = { - 'total_entities_found': len(search_result.get('final_entities', [])), - 'total_relations_found': len(search_result.get('final_relations', [])), - 'entities_after_truncation': len(truncation_result.get('filtered_entities', [])), - 'relations_after_truncation': len(truncation_result.get('filtered_relations', [])), - 'merged_chunks_count': len(merged_chunks), - 'final_chunks_count': len(raw_data.get('data', {}).get('chunks', [])), + raw_data["metadata"]["processing_info"] = { + "total_entities_found": len(search_result.get("final_entities", [])), + "total_relations_found": len(search_result.get("final_relations", [])), + "entities_after_truncation": len( + truncation_result.get("filtered_entities", []) + ), + "relations_after_truncation": len( + truncation_result.get("filtered_relations", []) + ), + "merged_chunks_count": len(merged_chunks), + "final_chunks_count": len(raw_data.get("data", {}).get("chunks", [])), } - # Calculate context sparsity to guide LLM response - # Sparse context = LLM should be more conservative about inferences - entity_count = len(truncation_result.get('filtered_entities', [])) - relation_count = len(truncation_result.get('filtered_relations', [])) - chunk_count = len(raw_data.get('data', {}).get('chunks', [])) - is_sparse = entity_count < 3 or relation_count < 2 or chunk_count < 2 - coverage_level = 'limited' if is_sparse else 'good' - logger.debug( - f'[_build_query_context] Context sparsity: entities={entity_count}, relations={relation_count}, chunks={chunk_count} -> coverage={coverage_level}' + f"[_build_query_context] Context length: {len(context) if context else 0}" + ) + logger.debug( + f"[_build_query_context] Raw data entities: {len(raw_data.get('data', {}).get('entities', []))}, relationships: {len(raw_data.get('data', {}).get('relationships', []))}, chunks: {len(raw_data.get('data', {}).get('chunks', []))}" ) - logger.debug(f'[_build_query_context] Context length: {len(context) if context else 0}') - logger.debug( - f'[_build_query_context] Raw data entities: {len(raw_data.get("data", {}).get("entities", []))}, relationships: {len(raw_data.get("data", {}).get("relationships", []))}, chunks: {len(raw_data.get("data", {}).get("chunks", []))}' - ) - - return QueryContextResult(context=context, raw_data=raw_data, coverage_level=coverage_level) + return QueryContextResult(context=context, raw_data=raw_data) async def _get_node_data( @@ -4927,7 +4173,9 @@ async def _get_node_data( query_param: QueryParam, ): # get similar entities - logger.info(f'Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})') + logger.info( + f"Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})" + ) results = 
await entities_vdb.query(query, top_k=query_param.top_k) @@ -4935,7 +4183,7 @@ async def _get_node_data( return [], [] # Extract all entity IDs from your results list - node_ids = [r['entity_name'] for r in results] + node_ids = [r["entity_name"] for r in results] # Call the batch node retrieval and degree functions concurrently. nodes_dict, degrees_dict = await asyncio.gather( @@ -4947,17 +4195,17 @@ async def _get_node_data( node_datas = [nodes_dict.get(nid) for nid in node_ids] node_degrees = [degrees_dict.get(nid, 0) for nid in node_ids] - if not all(n is not None for n in node_datas): - logger.warning('Some nodes are missing, maybe the storage is damaged') + if not all([n is not None for n in node_datas]): + logger.warning("Some nodes are missing, maybe the storage is damaged") node_datas = [ { **n, - 'entity_name': k['entity_name'], - 'rank': d, - 'created_at': k.get('created_at'), + "entity_name": k["entity_name"], + "rank": d, + "created_at": k.get("created_at"), } - for k, n, d in zip(results, node_datas, node_degrees, strict=False) + for k, n, d in zip(results, node_datas, node_degrees) if n is not None ] @@ -4967,7 +4215,9 @@ async def _get_node_data( knowledge_graph_inst, ) - logger.info(f'Local query: {len(node_datas)} entites, {len(use_relations)} relations') + logger.info( + f"Local query: {len(node_datas)} entites, {len(use_relations)} relations" + ) # Entities are sorted by cosine similarity # Relations are sorted by rank + weight @@ -4979,7 +4229,7 @@ async def _find_most_related_edges_from_entities( query_param: QueryParam, knowledge_graph_inst: BaseGraphStorage, ): - node_names = [dp['entity_name'] for dp in node_datas] + node_names = [dp["entity_name"] for dp in node_datas] batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names) all_edges = [] @@ -4995,7 +4245,7 @@ async def _find_most_related_edges_from_entities( # Prepare edge pairs in two forms: # For the batch edge properties function, use dicts. - edge_pairs_dicts = [{'src': e[0], 'tgt': e[1]} for e in all_edges] + edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges] # For edge degrees, use tuples. 
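# Aside: a standalone illustration of the edge handling that follows. Sorting
# each (src, tgt) pair gives a direction-independent key, and a descending
# (rank, weight) sort puts the best-connected, heaviest edges first. The
# sample edges are invented; only the sorting idea mirrors the code below.
sample_edges = [
    ("B", "A", {"rank": 3, "weight": 0.5}),
    ("A", "C", {"rank": 3, "weight": 0.9}),
    ("C", "B", {"rank": 1, "weight": 2.0}),
]
ranked = sorted(
    ({"src_tgt": tuple(sorted((src, tgt))), **props} for src, tgt, props in sample_edges),
    key=lambda e: (e["rank"], e["weight"]),
    reverse=True,
)
print([e["src_tgt"] for e in ranked])
# -> [('A', 'C'), ('A', 'B'), ('B', 'C')]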
edge_pairs_tuples = list(all_edges) # all_edges is already a list of tuples @@ -5010,18 +4260,22 @@ async def _find_most_related_edges_from_entities( for pair in all_edges: edge_props = edge_data_dict.get(pair) if edge_props is not None: - if 'weight' not in edge_props: - logger.warning(f"Edge {pair} missing 'weight' attribute, using default value 1.0") - edge_props['weight'] = 1.0 + if "weight" not in edge_props: + logger.warning( + f"Edge {pair} missing 'weight' attribute, using default value 1.0" + ) + edge_props["weight"] = 1.0 combined = { - 'src_tgt': pair, - 'rank': edge_degrees_dict.get(pair, 0), + "src_tgt": pair, + "rank": edge_degrees_dict.get(pair, 0), **edge_props, } all_edges_data.append(combined) - all_edges_data = sorted(all_edges_data, key=lambda x: (x['rank'], x['weight']), reverse=True) + all_edges_data = sorted( + all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True + ) return all_edges_data @@ -5031,9 +4285,9 @@ async def _find_related_text_unit_from_entities( query_param: QueryParam, text_chunks_db: BaseKVStorage, knowledge_graph_inst: BaseGraphStorage, - query: str | None = None, - chunks_vdb: BaseVectorStorage | None = None, - chunk_tracking: dict | None = None, + query: str = None, + chunks_vdb: BaseVectorStorage = None, + chunk_tracking: dict = None, query_embedding=None, ): """ @@ -5043,7 +4297,7 @@ async def _find_related_text_unit_from_entities( 1. WEIGHT: Linear gradient weighted polling based on chunk occurrence count 2. VECTOR: Vector similarity-based selection using embedding cosine similarity """ - logger.debug(f'Finding text chunks from {len(node_datas)} entities') + logger.debug(f"Finding text chunks from {len(node_datas)} entities") if not node_datas: return [] @@ -5051,30 +4305,38 @@ async def _find_related_text_unit_from_entities( # Step 1: Collect all text chunks for each entity entities_with_chunks = [] for entity in node_datas: - if entity.get('source_id'): - chunks = split_string_by_multi_markers(entity['source_id'], [GRAPH_FIELD_SEP]) + if entity.get("source_id"): + chunks = split_string_by_multi_markers( + entity["source_id"], [GRAPH_FIELD_SEP] + ) if chunks: entities_with_chunks.append( { - 'entity_name': entity['entity_name'], - 'chunks': chunks, - 'entity_data': entity, + "entity_name": entity["entity_name"], + "chunks": chunks, + "entity_data": entity, } ) if not entities_with_chunks: - logger.warning('No entities with text chunks found') + logger.warning("No entities with text chunks found") return [] - kg_chunk_pick_method = text_chunks_db.global_config.get('kg_chunk_pick_method', DEFAULT_KG_CHUNK_PICK_METHOD) - max_related_chunks = text_chunks_db.global_config.get('related_chunk_number', DEFAULT_RELATED_CHUNK_NUMBER) + kg_chunk_pick_method = text_chunks_db.global_config.get( + "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD + ) + max_related_chunks = text_chunks_db.global_config.get( + "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER + ) # Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned entities) chunk_occurrence_count = {} for entity_info in entities_with_chunks: deduplicated_chunks = [] - for chunk_id in entity_info['chunks']: - chunk_occurrence_count[chunk_id] = chunk_occurrence_count.get(chunk_id, 0) + 1 + for chunk_id in entity_info["chunks"]: + chunk_occurrence_count[chunk_id] = ( + chunk_occurrence_count.get(chunk_id, 0) + 1 + ) # If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position) if chunk_occurrence_count[chunk_id] == 
1: @@ -5082,17 +4344,17 @@ async def _find_related_text_unit_from_entities( # count > 1 means this chunk appeared in an earlier entity, so skip it # Update entity's chunks to deduplicated chunks - entity_info['chunks'] = deduplicated_chunks + entity_info["chunks"] = deduplicated_chunks # Step 3: Sort chunks for each entity by occurrence count (higher count = higher priority) total_entity_chunks = 0 for entity_info in entities_with_chunks: sorted_chunks = sorted( - entity_info['chunks'], + entity_info["chunks"], key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0), reverse=True, ) - entity_info['sorted_chunks'] = sorted_chunks + entity_info["sorted_chunks"] = sorted_chunks total_entity_chunks += len(sorted_chunks) selected_chunk_ids = [] # Initialize to avoid UnboundLocalError @@ -5101,14 +4363,14 @@ async def _find_related_text_unit_from_entities( # Pick by vector similarity: # The order of text chunks aligns with the naive retrieval's destination. # When reranking is disabled, the text chunks delivered to the LLM tend to favor naive retrieval. - if kg_chunk_pick_method == 'VECTOR' and query and chunks_vdb: + if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb: num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2) # Get embedding function from global config actual_embedding_func = text_chunks_db.embedding_func if not actual_embedding_func: - logger.warning('No embedding function found, falling back to WEIGHT method') - kg_chunk_pick_method = 'WEIGHT' + logger.warning("No embedding function found, falling back to WEIGHT method") + kg_chunk_pick_method = "WEIGHT" else: try: selected_chunk_ids = await pick_by_vector_similarity( @@ -5122,50 +4384,56 @@ async def _find_related_text_unit_from_entities( ) if selected_chunk_ids == []: - kg_chunk_pick_method = 'WEIGHT' + kg_chunk_pick_method = "WEIGHT" logger.warning( - 'No entity-related chunks selected by vector similarity, falling back to WEIGHT method' + "No entity-related chunks selected by vector similarity, falling back to WEIGHT method" ) else: logger.info( - f'Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by vector similarity' + f"Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by vector similarity" ) except Exception as e: - logger.error(f'Error in vector similarity sorting: {e}, falling back to WEIGHT method') - kg_chunk_pick_method = 'WEIGHT' + logger.error( + f"Error in vector similarity sorting: {e}, falling back to WEIGHT method" + ) + kg_chunk_pick_method = "WEIGHT" - if kg_chunk_pick_method == 'WEIGHT': + if kg_chunk_pick_method == "WEIGHT": # Pick by entity and chunk weight: # When reranking is disabled, delivered more solely KG related chunks to the LLM - selected_chunk_ids = pick_by_weighted_polling(entities_with_chunks, max_related_chunks, min_related_chunks=1) + selected_chunk_ids = pick_by_weighted_polling( + entities_with_chunks, max_related_chunks, min_related_chunks=1 + ) logger.info( - f'Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by weighted polling' + f"Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by weighted polling" ) if not selected_chunk_ids: return [] # Step 5: Batch retrieve chunk data - unique_chunk_ids = list(dict.fromkeys(selected_chunk_ids)) # Remove duplicates while preserving order + unique_chunk_ids = list( + dict.fromkeys(selected_chunk_ids) + ) # Remove duplicates while preserving order chunk_data_list = await 
text_chunks_db.get_by_ids(unique_chunk_ids) # Step 6: Build result chunks with valid data and update chunk tracking result_chunks = [] - for i, (chunk_id, chunk_data) in enumerate(zip(unique_chunk_ids, chunk_data_list, strict=False)): - if chunk_data is not None and 'content' in chunk_data: + for i, (chunk_id, chunk_data) in enumerate(zip(unique_chunk_ids, chunk_data_list)): + if chunk_data is not None and "content" in chunk_data: chunk_data_copy = chunk_data.copy() - chunk_data_copy['source_type'] = 'entity' - chunk_data_copy['chunk_id'] = chunk_id # Add chunk_id for deduplication + chunk_data_copy["source_type"] = "entity" + chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication result_chunks.append(chunk_data_copy) # Update chunk tracking if provided if chunk_tracking is not None: chunk_tracking[chunk_id] = { - 'source': 'E', - 'frequency': chunk_occurrence_count.get(chunk_id, 1), - 'order': i + 1, # 1-based order in final entity-related results + "source": "E", + "frequency": chunk_occurrence_count.get(chunk_id, 1), + "order": i + 1, # 1-based order in final entity-related results } return result_chunks @@ -5178,7 +4446,7 @@ async def _get_edge_data( query_param: QueryParam, ): logger.info( - f'Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})' + f"Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})" ) results = await relationships_vdb.query(keywords, top_k=query_param.top_k) @@ -5188,24 +4456,26 @@ async def _get_edge_data( # Prepare edge pairs in two forms: # For the batch edge properties function, use dicts. - edge_pairs_dicts = [{'src': r['src_id'], 'tgt': r['tgt_id']} for r in results] + edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results] edge_data_dict = await knowledge_graph_inst.get_edges_batch(edge_pairs_dicts) # Reconstruct edge_datas list in the same order as results. 
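# Aside: the reconstruction below can be pictured with a tiny standalone
# example. Results from the vector search keep their order; each (src, tgt)
# pair is looked up in the batch response, and edges missing a weight get a
# default of 1.0. The sample data is invented for illustration only.
results_sample = [{"src_id": "A", "tgt_id": "B"}, {"src_id": "B", "tgt_id": "C"}]
edge_props_by_pair = {
    ("A", "B"): {"description": "A collaborates with B"},
    ("B", "C"): {"weight": 2.0, "description": "B supplies C"},
}
ordered_edges = []
for r in results_sample:
    pair = (r["src_id"], r["tgt_id"])
    props = edge_props_by_pair.get(pair)
    if props is not None:
        props.setdefault("weight", 1.0)
        ordered_edges.append({"src_id": pair[0], "tgt_id": pair[1], **props})
print([(e["src_id"], e["tgt_id"], e["weight"]) for e in ordered_edges])
# -> [('A', 'B', 1.0), ('B', 'C', 2.0)]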
edge_datas = [] for k in results: - pair = (k['src_id'], k['tgt_id']) + pair = (k["src_id"], k["tgt_id"]) edge_props = edge_data_dict.get(pair) if edge_props is not None: - if 'weight' not in edge_props: - logger.warning(f"Edge {pair} missing 'weight' attribute, using default value 1.0") - edge_props['weight'] = 1.0 + if "weight" not in edge_props: + logger.warning( + f"Edge {pair} missing 'weight' attribute, using default value 1.0" + ) + edge_props["weight"] = 1.0 # Keep edge data without rank, maintain vector search order combined = { - 'src_id': k['src_id'], - 'tgt_id': k['tgt_id'], - 'created_at': k.get('created_at', None), + "src_id": k["src_id"], + "tgt_id": k["tgt_id"], + "created_at": k.get("created_at", None), **edge_props, } edge_datas.append(combined) @@ -5218,7 +4488,9 @@ async def _get_edge_data( knowledge_graph_inst, ) - logger.info(f'Global query: {len(use_entities)} entites, {len(edge_datas)} relations') + logger.info( + f"Global query: {len(use_entities)} entites, {len(edge_datas)} relations" + ) return edge_datas, use_entities @@ -5232,12 +4504,12 @@ async def _find_most_related_entities_from_relationships( seen = set() for e in edge_datas: - if e['src_id'] not in seen: - entity_names.append(e['src_id']) - seen.add(e['src_id']) - if e['tgt_id'] not in seen: - entity_names.append(e['tgt_id']) - seen.add(e['tgt_id']) + if e["src_id"] not in seen: + entity_names.append(e["src_id"]) + seen.add(e["src_id"]) + if e["tgt_id"] not in seen: + entity_names.append(e["tgt_id"]) + seen.add(e["tgt_id"]) # Only get nodes data, no need for node degrees nodes_dict = await knowledge_graph_inst.get_nodes_batch(entity_names) @@ -5250,7 +4522,7 @@ async def _find_most_related_entities_from_relationships( logger.warning(f"Node '{entity_name}' not found in batch retrieval.") continue # Combine the node data with the entity name, no rank needed - combined = {**node, 'entity_name': entity_name} + combined = {**node, "entity_name": entity_name} node_datas.append(combined) return node_datas @@ -5260,10 +4532,10 @@ async def _find_related_text_unit_from_relations( edge_datas: list[dict], query_param: QueryParam, text_chunks_db: BaseKVStorage, - entity_chunks: list[dict] | None = None, - query: str | None = None, - chunks_vdb: BaseVectorStorage | None = None, - chunk_tracking: dict | None = None, + entity_chunks: list[dict] = None, + query: str = None, + chunks_vdb: BaseVectorStorage = None, + chunk_tracking: dict = None, query_embedding=None, ): """ @@ -5273,7 +4545,7 @@ async def _find_related_text_unit_from_relations( 1. WEIGHT: Linear gradient weighted polling based on chunk occurrence count 2. 
VECTOR: Vector similarity-based selection using embedding cosine similarity """ - logger.debug(f'Finding text chunks from {len(edge_datas)} relations') + logger.debug(f"Finding text chunks from {len(edge_datas)} relations") if not edge_datas: return [] @@ -5281,29 +4553,37 @@ async def _find_related_text_unit_from_relations( # Step 1: Collect all text chunks for each relationship relations_with_chunks = [] for relation in edge_datas: - if relation.get('source_id'): - chunks = split_string_by_multi_markers(relation['source_id'], [GRAPH_FIELD_SEP]) + if relation.get("source_id"): + chunks = split_string_by_multi_markers( + relation["source_id"], [GRAPH_FIELD_SEP] + ) if chunks: # Build relation identifier - if 'src_tgt' in relation: - rel_key = tuple(sorted(relation['src_tgt'])) + if "src_tgt" in relation: + rel_key = tuple(sorted(relation["src_tgt"])) else: - rel_key = tuple(sorted([relation.get('src_id', ''), relation.get('tgt_id', '')])) + rel_key = tuple( + sorted([relation.get("src_id"), relation.get("tgt_id")]) + ) relations_with_chunks.append( { - 'relation_key': rel_key, - 'chunks': chunks, - 'relation_data': relation, + "relation_key": rel_key, + "chunks": chunks, + "relation_data": relation, } ) if not relations_with_chunks: - logger.warning('No relation-related chunks found') + logger.warning("No relation-related chunks found") return [] - kg_chunk_pick_method = text_chunks_db.global_config.get('kg_chunk_pick_method', DEFAULT_KG_CHUNK_PICK_METHOD) - max_related_chunks = text_chunks_db.global_config.get('related_chunk_number', DEFAULT_RELATED_CHUNK_NUMBER) + kg_chunk_pick_method = text_chunks_db.global_config.get( + "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD + ) + max_related_chunks = text_chunks_db.global_config.get( + "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER + ) # Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned relationships) # Also remove duplicates with entity_chunks @@ -5312,7 +4592,7 @@ async def _find_related_text_unit_from_relations( entity_chunk_ids = set() if entity_chunks: for chunk in entity_chunks: - chunk_id = chunk.get('chunk_id') + chunk_id = chunk.get("chunk_id") if chunk_id: entity_chunk_ids.add(chunk_id) @@ -5322,14 +4602,16 @@ async def _find_related_text_unit_from_relations( for relation_info in relations_with_chunks: deduplicated_chunks = [] - for chunk_id in relation_info['chunks']: + for chunk_id in relation_info["chunks"]: # Skip chunks that already exist in entity_chunks if chunk_id in entity_chunk_ids: # Only count each unique chunk_id once removed_entity_chunk_ids.add(chunk_id) continue - chunk_occurrence_count[chunk_id] = chunk_occurrence_count.get(chunk_id, 0) + 1 + chunk_occurrence_count[chunk_id] = ( + chunk_occurrence_count.get(chunk_id, 0) + 1 + ) # If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position) if chunk_occurrence_count[chunk_id] == 1: @@ -5337,41 +4619,47 @@ async def _find_related_text_unit_from_relations( # count > 1 means this chunk appeared in an earlier relationship, so skip it # Update relationship's chunks to deduplicated chunks - relation_info['chunks'] = deduplicated_chunks + relation_info["chunks"] = deduplicated_chunks # Check if any relations still have chunks after deduplication - relations_with_chunks = [relation_info for relation_info in relations_with_chunks if relation_info['chunks']] + relations_with_chunks = [ + relation_info + for relation_info in relations_with_chunks + if relation_info["chunks"] + ] if not 
relations_with_chunks: - logger.info(f'Find no additional relations-related chunks from {len(edge_datas)} relations') + logger.info( + f"Found no additional relation-related chunks from {len(edge_datas)} relations" + ) return [] # Step 3: Sort chunks for each relationship by occurrence count (higher count = higher priority) total_relation_chunks = 0 for relation_info in relations_with_chunks: sorted_chunks = sorted( - relation_info['chunks'], + relation_info["chunks"], key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0), reverse=True, ) - relation_info['sorted_chunks'] = sorted_chunks + relation_info["sorted_chunks"] = sorted_chunks total_relation_chunks += len(sorted_chunks) logger.info( - f'Find {total_relation_chunks} additional chunks in {len(relations_with_chunks)} relations (deduplicated {len(removed_entity_chunk_ids)})' + f"Found {total_relation_chunks} additional chunks in {len(relations_with_chunks)} relations (deduplicated {len(removed_entity_chunk_ids)})" + ) # Step 4: Apply the selected chunk selection algorithm selected_chunk_ids = [] # Initialize to avoid UnboundLocalError - if kg_chunk_pick_method == 'VECTOR' and query and chunks_vdb: + if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb: num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2) # Get embedding function from global config actual_embedding_func = text_chunks_db.embedding_func if not actual_embedding_func: - logger.warning('No embedding function found, falling back to WEIGHT method') - kg_chunk_pick_method = 'WEIGHT' + logger.warning("No embedding function found, falling back to WEIGHT method") + kg_chunk_pick_method = "WEIGHT" else: try: selected_chunk_ids = await pick_by_vector_similarity( @@ -5385,63 +4673,93 @@ async def _find_related_text_unit_from_relations( ) if selected_chunk_ids == []: - kg_chunk_pick_method = 'WEIGHT' + kg_chunk_pick_method = "WEIGHT" logger.warning( - 'No relation-related chunks selected by vector similarity, falling back to WEIGHT method' + "No relation-related chunks selected by vector similarity, falling back to WEIGHT method" ) else: logger.info( - f'Selecting {len(selected_chunk_ids)} from {total_relation_chunks} relation-related chunks by vector similarity' + f"Selecting {len(selected_chunk_ids)} from {total_relation_chunks} relation-related chunks by vector similarity" ) except Exception as e: - logger.error(f'Error in vector similarity sorting: {e}, falling back to WEIGHT method') - kg_chunk_pick_method = 'WEIGHT' + logger.error( + f"Error in vector similarity sorting: {e}, falling back to WEIGHT method" + ) + kg_chunk_pick_method = "WEIGHT" - if kg_chunk_pick_method == 'WEIGHT': + if kg_chunk_pick_method == "WEIGHT": # Apply linear gradient weighted polling algorithm - selected_chunk_ids = pick_by_weighted_polling(relations_with_chunks, max_related_chunks, min_related_chunks=1) + selected_chunk_ids = pick_by_weighted_polling( + relations_with_chunks, max_related_chunks, min_related_chunks=1 + ) logger.info( - f'Selecting {len(selected_chunk_ids)} from {total_relation_chunks} relation-related chunks by weighted polling' + f"Selecting {len(selected_chunk_ids)} from {total_relation_chunks} relation-related chunks by weighted polling" + ) logger.debug( - f'KG related chunks: {len(entity_chunks or [])} from entitys, {len(selected_chunk_ids)} from relations' + f"KG related chunks: {len(entity_chunks or [])} from entities, {len(selected_chunk_ids)} from relations" ) if not selected_chunk_ids: return [] # Step 5: Batch retrieve chunk data - unique_chunk_ids =
list(dict.fromkeys(selected_chunk_ids)) # Remove duplicates while preserving order + unique_chunk_ids = list( + dict.fromkeys(selected_chunk_ids) + ) # Remove duplicates while preserving order chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids) # Step 6: Build result chunks with valid data and update chunk tracking result_chunks = [] - for i, (chunk_id, chunk_data) in enumerate(zip(unique_chunk_ids, chunk_data_list, strict=False)): - if chunk_data is not None and 'content' in chunk_data: + for i, (chunk_id, chunk_data) in enumerate(zip(unique_chunk_ids, chunk_data_list)): + if chunk_data is not None and "content" in chunk_data: chunk_data_copy = chunk_data.copy() - chunk_data_copy['source_type'] = 'relationship' - chunk_data_copy['chunk_id'] = chunk_id # Add chunk_id for deduplication + chunk_data_copy["source_type"] = "relationship" + chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication result_chunks.append(chunk_data_copy) # Update chunk tracking if provided if chunk_tracking is not None: chunk_tracking[chunk_id] = { - 'source': 'R', - 'frequency': chunk_occurrence_count.get(chunk_id, 1), - 'order': i + 1, # 1-based order in final relation-related results + "source": "R", + "frequency": chunk_occurrence_count.get(chunk_id, 1), + "order": i + 1, # 1-based order in final relation-related results } return result_chunks +@overload +async def naive_query( + query: str, + chunks_vdb: BaseVectorStorage, + query_param: QueryParam, + global_config: dict[str, str], + hashing_kv: BaseKVStorage | None = None, + system_prompt: str | None = None, + return_raw_data: Literal[True] = True, +) -> dict[str, Any]: ... + + +@overload +async def naive_query( + query: str, + chunks_vdb: BaseVectorStorage, + query_param: QueryParam, + global_config: dict[str, str], + hashing_kv: BaseKVStorage | None = None, + system_prompt: str | None = None, + return_raw_data: Literal[False] = False, +) -> str | AsyncIterator[str]: ... + + async def naive_query( query: str, chunks_vdb: BaseVectorStorage, query_param: QueryParam, - global_config: dict[str, Any], + global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, system_prompt: str | None = None, ) -> QueryResult | None: @@ -5467,64 +4785,65 @@ async def naive_query( """ if not query: - return QueryResult(content=PROMPTS['fail_response']) + return QueryResult(content=PROMPTS["fail_response"]) if query_param.model_func: use_model_func = query_param.model_func else: - use_model_func = global_config['llm_model_func'] + use_model_func = global_config["llm_model_func"] # Apply higher priority (5) to query relation LLM function use_model_func = partial(use_model_func, _priority=5) - llm_callable = cast(Callable[..., Awaitable[str | AsyncIterator[str]]], use_model_func) - tokenizer: Tokenizer = global_config['tokenizer'] + tokenizer: Tokenizer = global_config["tokenizer"] if not tokenizer: - logger.error('Tokenizer not found in global configuration.') - return QueryResult(content=PROMPTS['fail_response']) + logger.error("Tokenizer not found in global configuration.") + return QueryResult(content=PROMPTS["fail_response"]) chunks = await _get_vector_context(query, chunks_vdb, query_param, None) if chunks is None or len(chunks) == 0: - logger.info('[naive_query] No relevant document chunks found; returning no-result.') + logger.info( + "[naive_query] No relevant document chunks found; returning no-result." 
+ ) return None # Calculate dynamic token limit for chunks max_total_tokens = getattr( query_param, - 'max_total_tokens', - global_config.get('max_total_tokens', DEFAULT_MAX_TOTAL_TOKENS), + "max_total_tokens", + global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS), ) # Calculate system prompt template tokens (excluding content_data) - user_prompt = f'\n\n{query_param.user_prompt}' if query_param.user_prompt else 'n/a' - response_type = query_param.response_type if query_param.response_type else 'Multiple Paragraphs' + user_prompt = f"\n\n{query_param.user_prompt}" if query_param.user_prompt else "n/a" + response_type = ( + query_param.response_type + if query_param.response_type + else "Multiple Paragraphs" + ) # Use the provided system prompt or default - sys_prompt_template = system_prompt if system_prompt else PROMPTS['naive_rag_response'] - - # Detect sparse context for coverage guidance (adapted from KG mode) - chunk_count = len(chunks) if chunks else 0 - is_sparse = chunk_count < 5 # Slightly higher threshold for naive mode - coverage_guidance = ( - PROMPTS['coverage_guidance_limited'] if is_sparse else PROMPTS['coverage_guidance_good'] + sys_prompt_template = ( + system_prompt if system_prompt else PROMPTS["naive_rag_response"] ) # Create a preliminary system prompt with empty content_data to calculate overhead pre_sys_prompt = sys_prompt_template.format( response_type=response_type, user_prompt=user_prompt, - content_data='', # Empty for overhead calculation - coverage_guidance=coverage_guidance, + content_data="", # Empty for overhead calculation ) # Calculate available tokens for chunks sys_prompt_tokens = len(tokenizer.encode(pre_sys_prompt)) query_tokens = len(tokenizer.encode(query)) buffer_tokens = 200 # reserved for reference list and safety buffer - available_chunk_tokens = max_total_tokens - (sys_prompt_tokens + query_tokens + buffer_tokens) + available_chunk_tokens = max_total_tokens - ( + sys_prompt_tokens + query_tokens + buffer_tokens + ) logger.debug( - f'Naive query token allocation - Total: {max_total_tokens}, SysPrompt: {sys_prompt_tokens}, Query: {query_tokens}, Buffer: {buffer_tokens}, Available for chunks: {available_chunk_tokens}' + f"Naive query token allocation - Total: {max_total_tokens}, SysPrompt: {sys_prompt_tokens}, Query: {query_tokens}, Buffer: {buffer_tokens}, Available for chunks: {available_chunk_tokens}" ) # Process chunks using unified processing with dynamic token limit @@ -5533,14 +4852,16 @@ async def naive_query( unique_chunks=chunks, query_param=query_param, global_config=global_config, - source_type='vector', + source_type="vector", chunk_token_limit=available_chunk_tokens, # Pass dynamic limit ) # Generate reference list from processed chunks using the new common function - reference_list, processed_chunks_with_ref_ids = generate_reference_list_from_chunks(processed_chunks) + reference_list, processed_chunks_with_ref_ids = generate_reference_list_from_chunks( + processed_chunks + ) - logger.info(f'Final context: {len(processed_chunks_with_ref_ids)} chunks') + logger.info(f"Final context: {len(processed_chunks_with_ref_ids)} chunks") # Build raw data structure for naive mode using processed chunks with reference IDs raw_data = convert_to_user_format( @@ -5548,37 +4869,41 @@ async def naive_query( [], # naive mode has no relationships processed_chunks_with_ref_ids, reference_list, - 'naive', + "naive", ) # Add complete metadata for naive mode - if 'metadata' not in raw_data: - raw_data['metadata'] = {} - 
raw_data['metadata']['keywords'] = { - 'high_level': [], # naive mode has no keyword extraction - 'low_level': [], # naive mode has no keyword extraction + if "metadata" not in raw_data: + raw_data["metadata"] = {} + raw_data["metadata"]["keywords"] = { + "high_level": [], # naive mode has no keyword extraction + "low_level": [], # naive mode has no keyword extraction } - raw_data['metadata']['processing_info'] = { - 'total_chunks_found': len(chunks), - 'final_chunks_count': len(processed_chunks_with_ref_ids), + raw_data["metadata"]["processing_info"] = { + "total_chunks_found": len(chunks), + "final_chunks_count": len(processed_chunks_with_ref_ids), } # Build chunks_context from processed chunks with reference IDs chunks_context = [] - for _i, chunk in enumerate(processed_chunks_with_ref_ids): + for i, chunk in enumerate(processed_chunks_with_ref_ids): chunks_context.append( { - 'reference_id': chunk['reference_id'], - 'content': chunk['content'], + "reference_id": chunk["reference_id"], + "content": chunk["content"], } ) - text_units_str = '\n'.join(json.dumps(text_unit, ensure_ascii=False) for text_unit in chunks_context) - reference_list_str = '\n'.join( - f'[{ref["reference_id"]}] {ref["file_path"]}' for ref in reference_list if ref['reference_id'] + text_units_str = "\n".join( + json.dumps(text_unit, ensure_ascii=False) for text_unit in chunks_context + ) + reference_list_str = "\n".join( + f"[{ref['reference_id']}] {ref['file_path']}" + for ref in reference_list + if ref["reference_id"] ) - naive_context_template = PROMPTS['naive_query_context'] + naive_context_template = PROMPTS["naive_query_context"] context_content = naive_context_template.format( text_chunks_str=text_units_str, reference_list_str=reference_list_str, @@ -5591,13 +4916,12 @@ async def naive_query( response_type=query_param.response_type, user_prompt=user_prompt, content_data=context_content, - coverage_guidance=coverage_guidance, ) user_query = query if query_param.only_need_prompt: - prompt_content = '\n\n'.join([sys_prompt, '---User Query---', user_query]) + prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query]) return QueryResult(content=prompt_content, raw_data=raw_data) # Handle cache @@ -5610,16 +4934,20 @@ async def naive_query( query_param.max_entity_tokens, query_param.max_relation_tokens, query_param.max_total_tokens, - query_param.user_prompt or '', + query_param.user_prompt or "", query_param.enable_rerank, ) - cached_result = await handle_cache(hashing_kv, args_hash, user_query, query_param.mode, cache_type='query') + cached_result = await handle_cache( + hashing_kv, args_hash, user_query, query_param.mode, cache_type="query" + ) if cached_result is not None: cached_response, _ = cached_result # Extract content, ignore timestamp - logger.info(' == LLM cache == Query cache hit, using cached response as query result') + logger.info( + " == LLM cache == Query cache hit, using cached response as query result" + ) response = cached_response else: - response = await llm_callable( + response = await use_model_func( user_query, system_prompt=sys_prompt, history_messages=query_param.conversation_history, @@ -5627,17 +4955,17 @@ async def naive_query( stream=query_param.stream, ) - if isinstance(response, str) and hashing_kv and hashing_kv.global_config.get('enable_llm_cache'): + if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): queryparam_dict = { - 'mode': query_param.mode, - 'response_type': query_param.response_type, - 'top_k': query_param.top_k, - 'chunk_top_k': 
query_param.chunk_top_k, - 'max_entity_tokens': query_param.max_entity_tokens, - 'max_relation_tokens': query_param.max_relation_tokens, - 'max_total_tokens': query_param.max_total_tokens, - 'user_prompt': query_param.user_prompt or '', - 'enable_rerank': query_param.enable_rerank, + "mode": query_param.mode, + "response_type": query_param.response_type, + "top_k": query_param.top_k, + "chunk_top_k": query_param.chunk_top_k, + "max_entity_tokens": query_param.max_entity_tokens, + "max_relation_tokens": query_param.max_relation_tokens, + "max_total_tokens": query_param.max_total_tokens, + "user_prompt": query_param.user_prompt or "", + "enable_rerank": query_param.enable_rerank, } await save_to_cache( hashing_kv, @@ -5646,7 +4974,7 @@ async def naive_query( content=response, prompt=query, mode=query_param.mode, - cache_type='query', + cache_type="query", queryparam=queryparam_dict, ), ) @@ -5657,16 +4985,18 @@ async def naive_query( if len(response) > len(sys_prompt): response = ( response[len(sys_prompt) :] - .replace(sys_prompt, '') - .replace('user', '') - .replace('model', '') - .replace(query, '') - .replace('', '') - .replace('', '') + .replace(sys_prompt, "") + .replace("user", "") + .replace("model", "") + .replace(query, "") + .replace("", "") + .replace("", "") .strip() ) return QueryResult(content=response, raw_data=raw_data) else: # Streaming response (AsyncIterator) - return QueryResult(response_iterator=response, raw_data=raw_data, is_streaming=True) + return QueryResult( + response_iterator=response, raw_data=raw_data, is_streaming=True + ) diff --git a/lightrag/prompt.py b/lightrag/prompt.py index d36a2244..dcd829d4 100644 --- a/lightrag/prompt.py +++ b/lightrag/prompt.py @@ -1,260 +1,348 @@ from __future__ import annotations - from typing import Any + PROMPTS: dict[str, Any] = {} -# All delimiters must be formatted as "<|TOKEN|>" style markers (e.g., "<|#|>" or "<|COMPLETE|>") -PROMPTS['DEFAULT_TUPLE_DELIMITER'] = '<|#|>' -PROMPTS['DEFAULT_COMPLETION_DELIMITER'] = '<|COMPLETE|>' +# All delimiters must be formatted as "<|UPPER_CASE_STRING|>" +PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|#|>" +PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>" -PROMPTS['entity_extraction_system_prompt'] = """---Role--- -You are a Knowledge Graph Specialist extracting entities and relationships from text. +PROMPTS["entity_extraction_system_prompt"] = """---Role--- +You are a Knowledge Graph Specialist responsible for extracting entities and relationships from the input text. ----Output Format--- -Output raw lines only—NO markdown, NO headers, NO backticks. +---Instructions--- +1. **Entity Extraction & Output:** + * **Identification:** Identify clearly defined and meaningful entities in the input text. + * **Entity Details:** For each identified entity, extract the following information: + * `entity_name`: The name of the entity. If the entity name is case-insensitive, capitalize the first letter of each significant word (title case). Ensure **consistent naming** across the entire extraction process. + * `entity_type`: Categorize the entity using one of the following types: `{entity_types}`. If none of the provided entity types apply, do not add new entity type and classify it as `Other`. + * `entity_description`: Provide a concise yet comprehensive description of the entity's attributes and activities, based *solely* on the information present in the input text. + * **Output Format - Entities:** Output a total of 4 fields for each entity, delimited by `{tuple_delimiter}`, on a single line. 
The first field *must* be the literal string `entity`. + * Format: `entity{tuple_delimiter}entity_name{tuple_delimiter}entity_type{tuple_delimiter}entity_description` -Entity: entity{tuple_delimiter}name{tuple_delimiter}type{tuple_delimiter}description -Relation: relation{tuple_delimiter}source{tuple_delimiter}target{tuple_delimiter}keywords{tuple_delimiter}description +2. **Relationship Extraction & Output:** + * **Identification:** Identify direct, clearly stated, and meaningful relationships between previously extracted entities. + * **N-ary Relationship Decomposition:** If a single statement describes a relationship involving more than two entities (an N-ary relationship), decompose it into multiple binary (two-entity) relationship pairs for separate description. + * **Example:** For "Alice, Bob, and Carol collaborated on Project X," extract binary relationships such as "Alice collaborated with Project X," "Bob collaborated with Project X," and "Carol collaborated with Project X," or "Alice collaborated with Bob," based on the most reasonable binary interpretations. + * **Relationship Details:** For each binary relationship, extract the following fields: + * `source_entity`: The name of the source entity. Ensure **consistent naming** with entity extraction. Capitalize the first letter of each significant word (title case) if the name is case-insensitive. + * `target_entity`: The name of the target entity. Ensure **consistent naming** with entity extraction. Capitalize the first letter of each significant word (title case) if the name is case-insensitive. + * `relationship_keywords`: One or more high-level keywords summarizing the overarching nature, concepts, or themes of the relationship. Multiple keywords within this field must be separated by a comma `,`. **DO NOT use `{tuple_delimiter}` for separating multiple keywords within this field.** + * `relationship_description`: A concise explanation of the nature of the relationship between the source and target entities, providing a clear rationale for their connection. + * **Output Format - Relationships:** Output a total of 5 fields for each relationship, delimited by `{tuple_delimiter}`, on a single line. The first field *must* be the literal string `relation`. + * Format: `relation{tuple_delimiter}source_entity{tuple_delimiter}target_entity{tuple_delimiter}relationship_keywords{tuple_delimiter}relationship_description` -Use Title Case for names. Separate keywords with commas. Output entities first, then relations. End with {completion_delimiter}. +3. **Delimiter Usage Protocol:** + * The `{tuple_delimiter}` is a complete, atomic marker and **must not be filled with content**. It serves strictly as a field separator. + * **Incorrect Example:** `entity{tuple_delimiter}Tokyo<|location|>Tokyo is the capital of Japan.` + * **Correct Example:** `entity{tuple_delimiter}Tokyo{tuple_delimiter}location{tuple_delimiter}Tokyo is the capital of Japan.` ----Entity Extraction--- -Extract BOTH concrete and abstract entities: -- **Concrete:** Named people, organizations, places, products, dates -- **Abstract:** Concepts, events, categories, processes mentioned in text (e.g., "market selloff", "merger", "pandemic") +4. **Relationship Direction & Duplication:** + * Treat all relationships as **undirected** unless explicitly stated otherwise. Swapping the source and target entities for an undirected relationship does not constitute a new relationship. + * Avoid outputting duplicate relationships. -Types: `{entity_types}` (use `Other` if none fit) +5. 
**Output Order & Prioritization:** + * Output all extracted entities first, followed by all extracted relationships. + * Within the list of relationships, prioritize and output those relationships that are **most significant** to the core meaning of the input text first. ----Relationship Extraction--- -Extract meaningful relationships: -- **Direct:** explicit interactions, actions, connections -- **Categorical:** entities sharing group membership or classification -- **Causal:** cause-effect relationships -- **Hierarchical:** part-of, member-of, type-of +6. **Context & Objectivity:** + * Ensure all entity names and descriptions are written in the **third person**. + * Explicitly name the subject or object; **avoid using pronouns** such as `this article`, `this paper`, `our company`, `I`, `you`, and `he/she`. -Create intermediate concept entities when they help connect related items (e.g., "Vaccines" connecting Pfizer/Moderna/AstraZeneca). +7. **Language & Proper Nouns:** + * The entire output (entity names, keywords, and descriptions) must be written in `{language}`. + * Proper nouns (e.g., personal names, place names, organization names) should be retained in their original language if a proper, widely accepted translation is not available or would cause ambiguity. -For N-ary relationships, decompose into binary pairs. Avoid duplicates. - ----Guidelines--- -- Third person only; no pronouns like "this article", "I", "you" -- Output in `{language}`. Keep proper nouns in original language. +8. **Completion Signal:** Output the literal string `{completion_delimiter}` only after all entities and relationships, following all criteria, have been completely extracted and outputted. ---Examples--- {examples} +""" ----Input--- -Entity_types: [{entity_types}] -Text: +PROMPTS["entity_extraction_user_prompt"] = """---Task--- +Extract entities and relationships from the input text in Data to be Processed below. + +---Instructions--- +1. **Strict Adherence to Format:** Strictly adhere to all format requirements for entity and relationship lists, including output order, field delimiters, and proper noun handling, as specified in the system prompt. +2. **Output Content Only:** Output *only* the extracted list of entities and relationships. Do not include any introductory or concluding remarks, explanations, or additional text before or after the list. +3. **Completion Signal:** Output `{completion_delimiter}` as the final line after all relevant entities and relationships have been extracted and presented. +4. **Output Language:** Ensure the output language is {language}. Proper nouns (e.g., personal names, place names, organization names) must be kept in their original language and not translated. + +---Data to be Processed--- + +[{entity_types}] + + ``` {input_text} ``` -""" - -PROMPTS['entity_extraction_user_prompt'] = """---Task--- -Extract entities and relationships from the text. Include both concrete entities AND abstract concepts/events. - -Follow format exactly. Output only extractions—no explanations. End with `{completion_delimiter}`. -Output in {language}; keep proper nouns in original language. """ -PROMPTS['entity_continue_extraction_user_prompt'] = """---Task--- -Review extraction for missed entities/relationships. +PROMPTS["entity_continue_extraction_user_prompt"] = """---Task--- +Based on the last extraction task, identify and extract any **missed or incorrectly formatted** entities and relationships from the input text. -Check for: -1. 
Abstract concepts that could serve as hubs (events, categories, processes) -2. Orphan entities that need connections -3. Formatting errors - -Only output NEW or CORRECTED items. End with `{completion_delimiter}`. Output in {language}. +---Instructions--- +1. **Strict Adherence to System Format:** Strictly adhere to all format requirements for entity and relationship lists, including output order, field delimiters, and proper noun handling, as specified in the system instructions. +2. **Focus on Corrections/Additions:** + * **Do NOT** re-output entities and relationships that were **correctly and fully** extracted in the last task. + * If an entity or relationship was **missed** in the last task, extract and output it now according to the system format. + * If an entity or relationship was **truncated, had missing fields, or was otherwise incorrectly formatted** in the last task, re-output the *corrected and complete* version in the specified format. +3. **Output Format - Entities:** Output a total of 4 fields for each entity, delimited by `{tuple_delimiter}`, on a single line. The first field *must* be the literal string `entity`. +4. **Output Format - Relationships:** Output a total of 5 fields for each relationship, delimited by `{tuple_delimiter}`, on a single line. The first field *must* be the literal string `relation`. +5. **Output Content Only:** Output *only* the extracted list of entities and relationships. Do not include any introductory or concluding remarks, explanations, or additional text before or after the list. +6. **Completion Signal:** Output `{completion_delimiter}` as the final line after all relevant missing or corrected entities and relationships have been extracted and presented. +7. **Output Language:** Ensure the output language is {language}. Proper nouns (e.g., personal names, place names, organization names) must be kept in their original language and not translated. """ -PROMPTS['entity_extraction_examples'] = [ - # Example 1: Shows abstract concept extraction (Market Selloff as hub) - """ +PROMPTS["entity_extraction_examples"] = [ + """ +["Person","Creature","Organization","Location","Event","Concept","Method","Content","Data","Artifact","NaturalObject"] + + ``` -Stock markets faced a sharp downturn as tech giants saw significant declines, with the global tech index dropping 3.4%. Nexon Technologies saw its stock plummet 7.8% after lower-than-expected earnings. In contrast, Omega Energy posted a 2.1% gain driven by rising oil prices. +while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order. -Gold futures rose 1.5% to $2,080/oz as investors sought safe-haven assets. The Federal Reserve's upcoming policy announcement is expected to influence market stability. +Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. "If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us." + +The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce. 
+ +It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths ``` -entity{tuple_delimiter}Market Selloff{tuple_delimiter}event{tuple_delimiter}Significant decline in stock values due to investor concerns. -entity{tuple_delimiter}Global Tech Index{tuple_delimiter}category{tuple_delimiter}Tracks major tech stocks; dropped 3.4% today. -entity{tuple_delimiter}Nexon Technologies{tuple_delimiter}organization{tuple_delimiter}Tech company whose stock fell 7.8% after disappointing earnings. -entity{tuple_delimiter}Omega Energy{tuple_delimiter}organization{tuple_delimiter}Energy company that gained 2.1% due to rising oil prices. -entity{tuple_delimiter}Gold Futures{tuple_delimiter}product{tuple_delimiter}Rose 1.5% to $2,080/oz as safe-haven investment. -entity{tuple_delimiter}Federal Reserve{tuple_delimiter}organization{tuple_delimiter}Central bank whose policy may impact markets. -relation{tuple_delimiter}Global Tech Index{tuple_delimiter}Market Selloff{tuple_delimiter}market decline{tuple_delimiter}Tech index drop is part of broader selloff. -relation{tuple_delimiter}Nexon Technologies{tuple_delimiter}Market Selloff{tuple_delimiter}tech decline{tuple_delimiter}Nexon among hardest hit in selloff. -relation{tuple_delimiter}Omega Energy{tuple_delimiter}Market Selloff{tuple_delimiter}contrast, resilience{tuple_delimiter}Omega gained while broader market sold off. -relation{tuple_delimiter}Gold Futures{tuple_delimiter}Market Selloff{tuple_delimiter}safe-haven{tuple_delimiter}Gold rose as investors fled stocks. -relation{tuple_delimiter}Federal Reserve{tuple_delimiter}Market Selloff{tuple_delimiter}policy impact{tuple_delimiter}Fed policy expectations contributed to volatility. +entity{tuple_delimiter}Alex{tuple_delimiter}person{tuple_delimiter}Alex is a character who experiences frustration and is observant of the dynamics among other characters. +entity{tuple_delimiter}Taylor{tuple_delimiter}person{tuple_delimiter}Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective. +entity{tuple_delimiter}Jordan{tuple_delimiter}person{tuple_delimiter}Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device. +entity{tuple_delimiter}Cruz{tuple_delimiter}person{tuple_delimiter}Cruz is associated with a vision of control and order, influencing the dynamics among other characters. +entity{tuple_delimiter}The Device{tuple_delimiter}equipment{tuple_delimiter}The Device is central to the story, with potential game-changing implications, and is revered by Taylor. +relation{tuple_delimiter}Alex{tuple_delimiter}Taylor{tuple_delimiter}power dynamics, observation{tuple_delimiter}Alex observes Taylor's authoritarian behavior and notes changes in Taylor's attitude toward the device. +relation{tuple_delimiter}Alex{tuple_delimiter}Jordan{tuple_delimiter}shared goals, rebellion{tuple_delimiter}Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision.) +relation{tuple_delimiter}Taylor{tuple_delimiter}Jordan{tuple_delimiter}conflict resolution, mutual respect{tuple_delimiter}Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce. +relation{tuple_delimiter}Jordan{tuple_delimiter}Cruz{tuple_delimiter}ideological conflict, rebellion{tuple_delimiter}Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order. 
+relation{tuple_delimiter}Taylor{tuple_delimiter}The Device{tuple_delimiter}reverence, technological significance{tuple_delimiter}Taylor shows reverence towards the device, indicating its importance and potential impact. {completion_delimiter} """, - # Example 2: Shows intermediate entity (Vaccines) connecting multiple orgs - """ + """ +["Person","Creature","Organization","Location","Event","Concept","Method","Content","Data","Artifact","NaturalObject"] + + ``` -COVID-19 vaccines developed by Pfizer, Moderna, and AstraZeneca have shown high efficacy in preventing severe illness. The World Health Organization recommends vaccination for all eligible adults. +Stock markets faced a sharp downturn today as tech giants saw significant declines, with the global tech index dropping by 3.4% in midday trading. Analysts attribute the selloff to investor concerns over rising interest rates and regulatory uncertainty. + +Among the hardest hit, nexon technologies saw its stock plummet by 7.8% after reporting lower-than-expected quarterly earnings. In contrast, Omega Energy posted a modest 2.1% gain, driven by rising oil prices. + +Meanwhile, commodity markets reflected a mixed sentiment. Gold futures rose by 1.5%, reaching $2,080 per ounce, as investors sought safe-haven assets. Crude oil prices continued their rally, climbing to $87.60 per barrel, supported by supply constraints and strong demand. + +Financial experts are closely watching the Federal Reserve's next move, as speculation grows over potential rate hikes. The upcoming policy announcement is expected to influence investor confidence and overall market stability. ``` -entity{tuple_delimiter}COVID-19{tuple_delimiter}concept{tuple_delimiter}Disease that vaccines are designed to prevent. -entity{tuple_delimiter}Vaccines{tuple_delimiter}product{tuple_delimiter}Medical products developed to prevent COVID-19. -entity{tuple_delimiter}Pfizer{tuple_delimiter}organization{tuple_delimiter}Pharmaceutical company that developed a COVID-19 vaccine. -entity{tuple_delimiter}Moderna{tuple_delimiter}organization{tuple_delimiter}Pharmaceutical company that developed a COVID-19 vaccine. -entity{tuple_delimiter}AstraZeneca{tuple_delimiter}organization{tuple_delimiter}Pharmaceutical company that developed a COVID-19 vaccine. -entity{tuple_delimiter}World Health Organization{tuple_delimiter}organization{tuple_delimiter}Global health body recommending vaccination. -relation{tuple_delimiter}Pfizer{tuple_delimiter}Vaccines{tuple_delimiter}development{tuple_delimiter}Pfizer developed a COVID-19 vaccine. -relation{tuple_delimiter}Moderna{tuple_delimiter}Vaccines{tuple_delimiter}development{tuple_delimiter}Moderna developed a COVID-19 vaccine. -relation{tuple_delimiter}AstraZeneca{tuple_delimiter}Vaccines{tuple_delimiter}development{tuple_delimiter}AstraZeneca developed a COVID-19 vaccine. -relation{tuple_delimiter}Vaccines{tuple_delimiter}COVID-19{tuple_delimiter}prevention{tuple_delimiter}Vaccines prevent severe COVID-19 illness. -relation{tuple_delimiter}World Health Organization{tuple_delimiter}Vaccines{tuple_delimiter}recommendation{tuple_delimiter}WHO recommends vaccination for adults. +entity{tuple_delimiter}Global Tech Index{tuple_delimiter}category{tuple_delimiter}The Global Tech Index tracks the performance of major technology stocks and experienced a 3.4% decline today. +entity{tuple_delimiter}Nexon Technologies{tuple_delimiter}organization{tuple_delimiter}Nexon Technologies is a tech company that saw its stock decline by 7.8% after disappointing earnings. 
+entity{tuple_delimiter}Omega Energy{tuple_delimiter}organization{tuple_delimiter}Omega Energy is an energy company that gained 2.1% in stock value due to rising oil prices. +entity{tuple_delimiter}Gold Futures{tuple_delimiter}product{tuple_delimiter}Gold futures rose by 1.5%, indicating increased investor interest in safe-haven assets. +entity{tuple_delimiter}Crude Oil{tuple_delimiter}product{tuple_delimiter}Crude oil prices rose to $87.60 per barrel due to supply constraints and strong demand. +entity{tuple_delimiter}Market Selloff{tuple_delimiter}category{tuple_delimiter}Market selloff refers to the significant decline in stock values due to investor concerns over interest rates and regulations. +entity{tuple_delimiter}Federal Reserve Policy Announcement{tuple_delimiter}category{tuple_delimiter}The Federal Reserve's upcoming policy announcement is expected to impact investor confidence and market stability. +entity{tuple_delimiter}3.4% Decline{tuple_delimiter}category{tuple_delimiter}The Global Tech Index experienced a 3.4% decline in midday trading. +relation{tuple_delimiter}Global Tech Index{tuple_delimiter}Market Selloff{tuple_delimiter}market performance, investor sentiment{tuple_delimiter}The decline in the Global Tech Index is part of the broader market selloff driven by investor concerns. +relation{tuple_delimiter}Nexon Technologies{tuple_delimiter}Global Tech Index{tuple_delimiter}company impact, index movement{tuple_delimiter}Nexon Technologies' stock decline contributed to the overall drop in the Global Tech Index. +relation{tuple_delimiter}Gold Futures{tuple_delimiter}Market Selloff{tuple_delimiter}market reaction, safe-haven investment{tuple_delimiter}Gold prices rose as investors sought safe-haven assets during the market selloff. +relation{tuple_delimiter}Federal Reserve Policy Announcement{tuple_delimiter}Market Selloff{tuple_delimiter}interest rate impact, financial regulation{tuple_delimiter}Speculation over Federal Reserve policy changes contributed to market volatility and investor selloff. {completion_delimiter} """, - # Example 3: Short legal example with hub entity (Merger) - """ + """ +["Person","Creature","Organization","Location","Event","Concept","Method","Content","Data","Artifact","NaturalObject"] + + ``` -The merger between Acme Corp and Beta Industries requires Federal Trade Commission approval due to antitrust concerns. +At the World Athletics Championship in Tokyo, Noah Carter broke the 100m sprint record using cutting-edge carbon-fiber spikes. ``` -entity{tuple_delimiter}Merger{tuple_delimiter}event{tuple_delimiter}Proposed business combination between Acme Corp and Beta Industries. -entity{tuple_delimiter}Acme Corp{tuple_delimiter}organization{tuple_delimiter}Company involved in proposed merger. -entity{tuple_delimiter}Beta Industries{tuple_delimiter}organization{tuple_delimiter}Company involved in proposed merger. -entity{tuple_delimiter}Federal Trade Commission{tuple_delimiter}organization{tuple_delimiter}Regulatory body that must approve the merger. -relation{tuple_delimiter}Acme Corp{tuple_delimiter}Merger{tuple_delimiter}party to{tuple_delimiter}Acme Corp is party to the merger. -relation{tuple_delimiter}Beta Industries{tuple_delimiter}Merger{tuple_delimiter}party to{tuple_delimiter}Beta Industries is party to the merger. -relation{tuple_delimiter}Federal Trade Commission{tuple_delimiter}Merger{tuple_delimiter}regulatory approval{tuple_delimiter}FTC must approve the merger. 
+entity{tuple_delimiter}World Athletics Championship{tuple_delimiter}event{tuple_delimiter}The World Athletics Championship is a global sports competition featuring top athletes in track and field. +entity{tuple_delimiter}Tokyo{tuple_delimiter}location{tuple_delimiter}Tokyo is the host city of the World Athletics Championship. +entity{tuple_delimiter}Noah Carter{tuple_delimiter}person{tuple_delimiter}Noah Carter is a sprinter who set a new record in the 100m sprint at the World Athletics Championship. +entity{tuple_delimiter}100m Sprint Record{tuple_delimiter}category{tuple_delimiter}The 100m sprint record is a benchmark in athletics, recently broken by Noah Carter. +entity{tuple_delimiter}Carbon-Fiber Spikes{tuple_delimiter}equipment{tuple_delimiter}Carbon-fiber spikes are advanced sprinting shoes that provide enhanced speed and traction. +entity{tuple_delimiter}World Athletics Federation{tuple_delimiter}organization{tuple_delimiter}The World Athletics Federation is the governing body overseeing the World Athletics Championship and record validations. +relation{tuple_delimiter}World Athletics Championship{tuple_delimiter}Tokyo{tuple_delimiter}event location, international competition{tuple_delimiter}The World Athletics Championship is being hosted in Tokyo. +relation{tuple_delimiter}Noah Carter{tuple_delimiter}100m Sprint Record{tuple_delimiter}athlete achievement, record-breaking{tuple_delimiter}Noah Carter set a new 100m sprint record at the championship. +relation{tuple_delimiter}Noah Carter{tuple_delimiter}Carbon-Fiber Spikes{tuple_delimiter}athletic equipment, performance boost{tuple_delimiter}Noah Carter used carbon-fiber spikes to enhance performance during the race. +relation{tuple_delimiter}Noah Carter{tuple_delimiter}World Athletics Championship{tuple_delimiter}athlete participation, competition{tuple_delimiter}Noah Carter is competing at the World Athletics Championship. {completion_delimiter} """, ] -PROMPTS['summarize_entity_descriptions'] = """---Task--- -Merge multiple descriptions of "{description_name}" ({description_type}) into one comprehensive summary. +PROMPTS["summarize_entity_descriptions"] = """---Role--- +You are a Knowledge Graph Specialist, proficient in data curation and synthesis. -Rules: -- Plain text output only, no formatting or extra text -- Include ALL key facts from every description -- Third person, mention entity name at start -- Max {summary_length} tokens -- Output in {language}; keep proper nouns in original language -- If descriptions conflict: reconcile or note uncertainty +---Task--- +Your task is to synthesize a list of descriptions of a given entity or relation into a single, comprehensive, and cohesive summary. + +---Instructions--- +1. Input Format: The description list is provided in JSON format. Each JSON object (representing a single description) appears on a new line within the `Description List` section. +2. Output Format: The merged description will be returned as plain text, presented in multiple paragraphs, without any additional formatting or extraneous comments before or after the summary. +3. Comprehensiveness: The summary must integrate all key information from *every* provided description. Do not omit any important facts or details. +4. Context: Ensure the summary is written from an objective, third-person perspective; explicitly mention the name of the entity or relation for full clarity and context. +5. Context & Objectivity: + - Write the summary from an objective, third-person perspective. 
+ - Explicitly mention the full name of the entity or relation at the beginning of the summary to ensure immediate clarity and context. +6. Conflict Handling: + - In cases of conflicting or inconsistent descriptions, first determine if these conflicts arise from multiple, distinct entities or relationships that share the same name. + - If distinct entities/relations are identified, summarize each one *separately* within the overall output. + - If conflicts within a single entity/relation (e.g., historical discrepancies) exist, attempt to reconcile them or present both viewpoints with noted uncertainty. +7. Length Constraint:The summary's total length must not exceed {summary_length} tokens, while still maintaining depth and completeness. +8. Language: The entire output must be written in {language}. Proper nouns (e.g., personal names, place names, organization names) may in their original language if proper translation is not available. + - The entire output must be written in {language}. + - Proper nouns (e.g., personal names, place names, organization names) should be retained in their original language if a proper, widely accepted translation is not available or would cause ambiguity. + +---Input--- +{description_type} Name: {description_name} + +Description List: -Descriptions: ``` {description_list} ``` -Output:""" - -PROMPTS['fail_response'] = "Sorry, I'm not able to provide an answer to that question.[no-context]" - -# Default RAG response prompt - cite-ready (no LLM-generated citations) -# Citations are added by post-processing. This gives cleaner, more accurate results. -# Optimized via DSPy/RAGAS testing - qtype variant achieved 0.887 relevance, 0.996 faithfulness -PROMPTS['rag_response'] = """Context: -{context_data} - -STRICT GROUNDING RULES (MUST FOLLOW): -- ONLY state information that appears EXPLICITLY in the Context above -- NEVER add specific numbers, percentages, dates, or quantities unless they appear VERBATIM in the Context -- NEVER reference documents, meetings, or sources not mentioned in the Context -- NEVER elaborate, interpret, or infer beyond what the text actually says -- If information is missing, state: "not specified in the available context" -- Each sentence must be directly traceable to a specific passage in the Context - -{coverage_guidance} - -Format Guidelines: -- Use the exact terminology from the question in your response -- The first sentence must directly answer the question -- Response type: {response_type} -- If enumerated items are requested, present as (1), (2), (3)... -- Do not include citation markers; they will be added automatically - -Question: {user_prompt} - -Answer:""" - -# Coverage guidance templates (injected based on context sparsity detection) -PROMPTS['coverage_guidance_limited'] = """ -CONTEXT NOTICE: The retrieved information for this topic is LIMITED. 
-- Only state facts that appear explicitly in the context below -- If key aspects of the question aren't covered, acknowledge: "The available information does not specify [aspect]" -- Avoid inferring or generalizing beyond what's stated +---Output--- """ -PROMPTS['coverage_guidance_good'] = '' # Empty for well-covered topics +PROMPTS["fail_response"] = ( + "Sorry, I'm not able to provide an answer to that question.[no-context]" +) -# Strict mode suffix - append when response_type="strict" -PROMPTS['rag_response_strict_suffix'] = """ -STRICT GROUNDING: -- NEVER state specific numbers/dates unless they appear EXACTLY in context -- If information isn't in context, say "not specified in available information" -- Entity summaries for overview, Source Excerpts for precision -""" +PROMPTS["rag_response"] = """---Role--- -# Default naive RAG response prompt - cite-ready (no LLM-generated citations) -# Enhanced with strict grounding rules to prevent hallucination -PROMPTS['naive_rag_response'] = """---Task--- -Answer the query using ONLY information present in the provided context. Do NOT add any external knowledge, assumptions, or inference beyond the exact wording. +You are an expert AI assistant specializing in synthesizing information from a provided knowledge base. Your primary function is to answer user queries accurately by ONLY using the information within the provided **Context**. -STRICT GUIDELINES -- Every sentence must be a verbatim fact or a direct logical consequence that can be explicitly traced to a specific chunk of the context. -- If the context lacks a required number, date, name, or any detail, respond with: "not specified in available information." -- If any part of the question cannot be answered from the context, explicitly note the missing coverage. -- Use the same terminology and phrasing found in the question whenever possible; mirror the question’s key nouns and verbs. -- When the answer contains multiple items, present them as a concise list. +---Goal--- -FORMAT -- Match the language of the question. -- Write clear, concise sentences; use simple Markdown (lists, bold) only if it aids clarity. -- Do not include a References section; it will be generated automatically. -- Response type: {response_type} -{coverage_guidance} +Generate a comprehensive, well-structured answer to the user query. +The answer must integrate relevant facts from the Knowledge Graph and Document Chunks found in the **Context**. +Consider the conversation history if provided to maintain conversational flow and avoid repeating information. + +---Instructions--- + +1. Step-by-Step Instruction: + - Carefully determine the user's query intent in the context of the conversation history to fully understand the user's information need. + - Scrutinize both `Knowledge Graph Data` and `Document Chunks` in the **Context**. Identify and extract all pieces of information that are directly relevant to answering the user query. + - Weave the extracted facts into a coherent and logical response. Your own knowledge must ONLY be used to formulate fluent sentences and connect ideas, NOT to introduce any external information. + - Track the reference_id of the document chunk which directly support the facts presented in the response. Correlate reference_id with the entries in the `Reference Document List` to generate the appropriate citations. + - Generate a references section at the end of the response. Each reference document must directly support the facts presented in the response. 
+ - Do not generate anything after the reference section. + +2. Content & Grounding: + - Strictly adhere to the provided context from the **Context**; DO NOT invent, assume, or infer any information not explicitly stated. + - If the answer cannot be found in the **Context**, state that you do not have enough information to answer. Do not attempt to guess. + +3. Formatting & Language: + - The response MUST be in the same language as the user query. + - The response MUST utilize Markdown formatting for enhanced clarity and structure (e.g., headings, bold text, bullet points). + - The response should be presented in {response_type}. + +4. References Section Format: + - The References section should be under heading: `### References` + - Reference list entries should adhere to the format: `* [n] Document Title`. Do not include a caret (`^`) after opening square bracket (`[`). + - The Document Title in the citation must retain its original language. + - Output each citation on an individual line + - Provide maximum of 5 most relevant citations. + - Do not generate footnotes section or any comment, summary, or explanation after the references. + +5. Reference Section Example: +``` +### References + +- [1] Document Title One +- [2] Document Title Two +- [3] Document Title Three +``` + +6. Additional Instructions: {user_prompt} -Question: {user_prompt} ---Context--- + +{context_data} +""" + +PROMPTS["naive_rag_response"] = """---Role--- + +You are an expert AI assistant specializing in synthesizing information from a provided knowledge base. Your primary function is to answer user queries accurately by ONLY using the information within the provided **Context**. + +---Goal--- + +Generate a comprehensive, well-structured answer to the user query. +The answer must integrate relevant facts from the Document Chunks found in the **Context**. +Consider the conversation history if provided to maintain conversational flow and avoid repeating information. + +---Instructions--- + +1. Step-by-Step Instruction: + - Carefully determine the user's query intent in the context of the conversation history to fully understand the user's information need. + - Scrutinize `Document Chunks` in the **Context**. Identify and extract all pieces of information that are directly relevant to answering the user query. + - Weave the extracted facts into a coherent and logical response. Your own knowledge must ONLY be used to formulate fluent sentences and connect ideas, NOT to introduce any external information. + - Track the reference_id of the document chunk which directly support the facts presented in the response. Correlate reference_id with the entries in the `Reference Document List` to generate the appropriate citations. + - Generate a **References** section at the end of the response. Each reference document must directly support the facts presented in the response. + - Do not generate anything after the reference section. + +2. Content & Grounding: + - Strictly adhere to the provided context from the **Context**; DO NOT invent, assume, or infer any information not explicitly stated. + - If the answer cannot be found in the **Context**, state that you do not have enough information to answer. Do not attempt to guess. + +3. Formatting & Language: + - The response MUST be in the same language as the user query. + - The response MUST utilize Markdown formatting for enhanced clarity and structure (e.g., headings, bold text, bullet points). + - The response should be presented in {response_type}. + +4. 
References Section Format: + - The References section should be under heading: `### References` + - Reference list entries should adhere to the format: `* [n] Document Title`. Do not include a caret (`^`) after opening square bracket (`[`). + - The Document Title in the citation must retain its original language. + - Output each citation on an individual line + - Provide maximum of 5 most relevant citations. + - Do not generate footnotes section or any comment, summary, or explanation after the references. + +5. Reference Section Example: +``` +### References + +- [1] Document Title One +- [2] Document Title Two +- [3] Document Title Three +``` + +6. Additional Instructions: {user_prompt} + + +---Context--- + {content_data} +""" -Answer:""" - -PROMPTS['kg_query_context'] = """ -## Entity Summaries (use for definitions and general facts) +PROMPTS["kg_query_context"] = """ +Knowledge Graph Data (Entity): ```json {entities_str} ``` -## Relationships (use to explain connections between concepts) +Knowledge Graph Data (Relationship): ```json {relations_str} ``` -## Source Excerpts (use for specific facts, numbers, quotes) - -```json -{text_chunks_str} -``` - -## References -{reference_list_str} - -""" - -PROMPTS['naive_query_context'] = """ -Document Chunks (Each entry includes a reference_id that refers to the `Reference Document List`): +Document Chunks (Each entry has a reference_id refer to the `Reference Document List`): ```json {text_chunks_str} @@ -268,75 +356,77 @@ Reference Document List (Each entry starts with a [reference_id] that correspond """ -PROMPTS['keywords_extraction'] = """---Task--- -Extract keywords from the query for RAG retrieval. +PROMPTS["naive_query_context"] = """ +Document Chunks (Each entry has a reference_id refer to the `Reference Document List`): -Output valid JSON (no markdown): -{{"high_level_keywords": [...], "low_level_keywords": [...]}} +```json +{text_chunks_str} +``` -Guidelines: -- high_level: Topic categories, question types, abstract themes -- low_level: Specific terms from the query including: - * Named entities (people, organizations, places) - * Technical terms and key concepts - * Dates, years, and time periods (e.g., "2017", "Q3 2024") - * Document names, report titles, and identifiers -- Extract at least 1 keyword per category for meaningful queries -- Only return empty lists for nonsensical input (e.g., "asdfgh", "hello") +Reference Document List (Each entry starts with a [reference_id] that corresponds to entries in the Document Chunks): + +``` +{reference_list_str} +``` + +""" + +PROMPTS["keywords_extraction"] = """---Role--- +You are an expert keyword extractor, specializing in analyzing user queries for a Retrieval-Augmented Generation (RAG) system. Your purpose is to identify both high-level and low-level keywords in the user's query that will be used for effective document retrieval. + +---Goal--- +Given a user query, your task is to extract two distinct types of keywords: +1. **high_level_keywords**: for overarching concepts or themes, capturing user's core intent, the subject area, or the type of question being asked. +2. **low_level_keywords**: for specific entities or details, identifying the specific entities, proper nouns, technical jargon, product names, or concrete items. + +---Instructions & Constraints--- +1. **Output Format**: Your output MUST be a valid JSON object and nothing else. Do not include any explanatory text, markdown code fences (like ```json), or any other text before or after the JSON. 
It will be parsed directly by a JSON parser. +2. **Source of Truth**: All keywords must be explicitly derived from the user query, with both high-level and low-level keyword categories are required to contain content. +3. **Concise & Meaningful**: Keywords should be concise words or meaningful phrases. Prioritize multi-word phrases when they represent a single concept. For example, from "latest financial report of Apple Inc.", you should extract "latest financial report" and "Apple Inc." rather than "latest", "financial", "report", and "Apple". +4. **Handle Edge Cases**: For queries that are too simple, vague, or nonsensical (e.g., "hello", "ok", "asdfghjkl"), you must return a JSON object with empty lists for both keyword types. +5. **Language**: All extracted keywords MUST be in {language}. Proper nouns (e.g., personal names, place names, organization names) should be kept in their original language. ---Examples--- {examples} ----Query--- -{query} +---Real Data--- +User Query: {query} +---Output--- Output:""" -PROMPTS['keywords_extraction_examples'] = [ - """Query: "What is the capital of France?" -Output: {{"high_level_keywords": ["Geography", "Capital city"], "low_level_keywords": ["France"]}} +PROMPTS["keywords_extraction_examples"] = [ + """Example 1: + +Query: "How does international trade influence global economic stability?" + +Output: +{ + "high_level_keywords": ["International trade", "Global economic stability", "Economic impact"], + "low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"] +} + """, - """Query: "Why does inflation affect interest rates?" -Output: {{"high_level_keywords": ["Economics", "Cause-effect"], "low_level_keywords": ["inflation", "interest rates"]}} + """Example 2: + +Query: "What are the environmental consequences of deforestation on biodiversity?" + +Output: +{ + "high_level_keywords": ["Environmental consequences", "Deforestation", "Biodiversity loss"], + "low_level_keywords": ["Species extinction", "Habitat destruction", "Carbon emissions", "Rainforest", "Ecosystem"] +} + """, - """Query: "How does Python compare to JavaScript for web development?" -Output: {{"high_level_keywords": ["Programming languages", "Comparison"], "low_level_keywords": ["Python", "JavaScript"]}} + """Example 3: + +Query: "What is the role of education in reducing poverty?" + +Output: +{ + "high_level_keywords": ["Education", "Poverty reduction", "Socioeconomic development"], + "low_level_keywords": ["School access", "Literacy rates", "Job training", "Income inequality"] +} + """, ] - -PROMPTS['orphan_connection_validation'] = """---Task--- -Evaluate if a meaningful relationship exists between two entities. 
- -Orphan: {orphan_name} ({orphan_type}) - {orphan_description} -Candidate: {candidate_name} ({candidate_type}) - {candidate_description} -Similarity: {similarity_score} - -Valid relationship types: -- Direct: One uses/creates/owns the other -- Industry: Both operate in same sector (finance, tech, healthcare) -- Competitive: Direct competitors or alternatives -- Temporal: Versions, successors, or historical connections -- Dependency: One relies on/runs on the other - -Confidence levels (use these exact labels): -- HIGH: Direct/explicit relationship (Django is Python framework, iOS is Apple product) -- MEDIUM: Strong implicit or industry relationship (Netflix runs on AWS, Bitcoin and Visa both in payments) -- LOW: Very weak, tenuous connection -- NONE: No logical relationship - -Output valid JSON: -{{"should_connect": bool, "confidence": "HIGH"|"MEDIUM"|"LOW"|"NONE", "relationship_type": str|null, "relationship_keywords": str|null, "relationship_description": str|null, "reasoning": str}} - -Rules: -- HIGH/MEDIUM: should_connect=true (same industry = MEDIUM) -- LOW/NONE: should_connect=false -- High similarity alone is NOT sufficient -- Explain the specific relationship in reasoning - -Example: Python↔Django -{{"should_connect": true, "confidence": "HIGH", "relationship_type": "direct", "relationship_keywords": "framework, built-with", "relationship_description": "Django is a web framework written in Python", "reasoning": "Direct explicit relationship - Django is implemented in Python"}} - -Example: Mozart↔Docker -{{"should_connect": false, "confidence": "NONE", "relationship_type": null, "relationship_keywords": null, "relationship_description": null, "reasoning": "No logical connection between classical composer and container technology"}} - -Output:""" diff --git a/lightrag/rerank.py b/lightrag/rerank.py index db7756c7..12950fe6 100644 --- a/lightrag/rerank.py +++ b/lightrag/rerank.py @@ -1,211 +1,577 @@ -""" -Local reranker using sentence-transformers CrossEncoder. - -Uses cross-encoder/ms-marco-MiniLM-L-6-v2 by default - a 22M param model with -excellent accuracy and clean score separation (-11 to +10 range). -Runs entirely locally without API calls. -""" - from __future__ import annotations import os -from collections.abc import Awaitable, Callable, Sequence -from typing import Protocol, SupportsFloat, TypedDict, runtime_checkable - +import aiohttp +from typing import Any, List, Dict, Optional, Tuple +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + retry_if_exception_type, +) from .utils import logger -# Global model cache to avoid reloading on every call -_reranker_model: RerankerModel | None = None -_reranker_model_name: str | None = None +from dotenv import load_dotenv -# Default model - mxbai-rerank-xsmall-v1 performs best on domain-specific content -# Used for ordering only (no score filtering) - see constants.py DEFAULT_MIN_RERANK_SCORE -DEFAULT_RERANK_MODEL = 'mixedbread-ai/mxbai-rerank-xsmall-v1' +# use the .env that is inside the current folder +# allows to use different .env file for each lightrag instance +# the OS environment variables take precedence over the .env file +load_dotenv(dotenv_path=".env", override=False) -class RerankResult(TypedDict): - index: int - relevance_score: float - - -@runtime_checkable -class SupportsToList(Protocol): - def tolist(self) -> list[float]: ... 
- - -ScoreLike = Sequence[SupportsFloat] | SupportsToList - - -@runtime_checkable -class RerankerModel(Protocol): - def predict( - self, - sentences: list[list[str]], - batch_size: int = ..., - ) -> ScoreLike: ... - - -def get_reranker_model(model_name: str | None = None): +def chunk_documents_for_rerank( + documents: List[str], + max_tokens: int = 480, + overlap_tokens: int = 32, + tokenizer_model: str = "gpt-4o-mini", +) -> Tuple[List[str], List[int]]: """ - Get or initialize the reranker model (cached). + Chunk documents that exceed token limit for reranking. Args: - model_name: HuggingFace model name. Defaults to mxbai-rerank-xsmall-v1 + documents: List of document strings to chunk + max_tokens: Maximum tokens per chunk (default 480 to leave margin for 512 limit) + overlap_tokens: Number of tokens to overlap between chunks + tokenizer_model: Model name for tiktoken tokenizer Returns: - CrossEncoder-like model instance implementing predict(pairs)->list[float] + Tuple of (chunked_documents, original_doc_indices) + - chunked_documents: List of document chunks (may be more than input) + - original_doc_indices: Maps each chunk back to its original document index """ - global _reranker_model, _reranker_model_name - - if model_name is None: - model_name = os.getenv('RERANK_MODEL', DEFAULT_RERANK_MODEL) - - # Return cached model if same name - if _reranker_model is not None and _reranker_model_name == model_name: - return _reranker_model + # Clamp overlap_tokens to ensure the loop always advances + # If overlap_tokens >= max_tokens, the chunking loop would hang + if overlap_tokens >= max_tokens: + original_overlap = overlap_tokens + # Ensure overlap is at least 1 token less than max to guarantee progress + # For very small max_tokens (e.g., 1), set overlap to 0 + overlap_tokens = max(0, max_tokens - 1) + logger.warning( + f"overlap_tokens ({original_overlap}) must be less than max_tokens ({max_tokens}). " + f"Clamping to {overlap_tokens} to prevent infinite loop." + ) try: - from sentence_transformers import CrossEncoder + from .utils import TiktokenTokenizer - logger.info(f'Loading reranker model: {model_name}') - _reranker_model = CrossEncoder(model_name, trust_remote_code=True) - _reranker_model_name = model_name - logger.info(f'Reranker model loaded: {model_name}') - return _reranker_model - - except ImportError as err: - raise ImportError( - 'sentence-transformers is required for local reranking. Install with: pip install sentence-transformers' - ) from err + tokenizer = TiktokenTokenizer(model_name=tokenizer_model) except Exception as e: - logger.error(f'Failed to load reranker model {model_name}: {e}') - raise + logger.warning( + f"Failed to initialize tokenizer: {e}. Using character-based approximation." 
+ ) + # Fallback: approximate 1 token ≈ 4 characters + max_chars = max_tokens * 4 + overlap_chars = overlap_tokens * 4 + + chunked_docs = [] + doc_indices = [] + + for idx, doc in enumerate(documents): + if len(doc) <= max_chars: + chunked_docs.append(doc) + doc_indices.append(idx) + else: + # Split into overlapping chunks + start = 0 + while start < len(doc): + end = min(start + max_chars, len(doc)) + chunk = doc[start:end] + chunked_docs.append(chunk) + doc_indices.append(idx) + + if end >= len(doc): + break + start = end - overlap_chars + + return chunked_docs, doc_indices + + # Use tokenizer for accurate chunking + chunked_docs = [] + doc_indices = [] + + for idx, doc in enumerate(documents): + tokens = tokenizer.encode(doc) + + if len(tokens) <= max_tokens: + # Document fits in one chunk + chunked_docs.append(doc) + doc_indices.append(idx) + else: + # Split into overlapping chunks + start = 0 + while start < len(tokens): + end = min(start + max_tokens, len(tokens)) + chunk_tokens = tokens[start:end] + chunk_text = tokenizer.decode(chunk_tokens) + chunked_docs.append(chunk_text) + doc_indices.append(idx) + + if end >= len(tokens): + break + start = end - overlap_tokens + + return chunked_docs, doc_indices -async def local_rerank( - query: str, - documents: list[str], - top_n: int | None = None, - model_name: str | None = None, -) -> list[RerankResult]: +def aggregate_chunk_scores( + chunk_results: List[Dict[str, Any]], + doc_indices: List[int], + num_original_docs: int, + aggregation: str = "max", +) -> List[Dict[str, Any]]: """ - Rerank documents using local CrossEncoder model. + Aggregate rerank scores from document chunks back to original documents. + + Args: + chunk_results: Rerank results for chunks [{"index": chunk_idx, "relevance_score": score}, ...] + doc_indices: Maps each chunk index to original document index + num_original_docs: Total number of original documents + aggregation: Strategy for aggregating scores ("max", "mean", "first") + + Returns: + List of results for original documents [{"index": doc_idx, "relevance_score": score}, ...] 
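+
+    Example (illustrative; the scores are made up to show "max" aggregation):
+        >>> chunk_results = [
+        ...     {"index": 0, "relevance_score": 0.9},
+        ...     {"index": 1, "relevance_score": 0.4},
+        ...     {"index": 2, "relevance_score": 0.7},
+        ... ]
+        >>> doc_indices = [0, 0, 1]  # chunks 0 and 1 belong to doc 0, chunk 2 to doc 1
+        >>> aggregate_chunk_scores(chunk_results, doc_indices, num_original_docs=2)
+        [{'index': 0, 'relevance_score': 0.9}, {'index': 1, 'relevance_score': 0.7}]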
+ """ + # Group scores by original document index + doc_scores: Dict[int, List[float]] = {i: [] for i in range(num_original_docs)} + + for result in chunk_results: + chunk_idx = result["index"] + score = result["relevance_score"] + + if 0 <= chunk_idx < len(doc_indices): + original_doc_idx = doc_indices[chunk_idx] + doc_scores[original_doc_idx].append(score) + + # Aggregate scores + aggregated_results = [] + for doc_idx, scores in doc_scores.items(): + if not scores: + continue + + if aggregation == "max": + final_score = max(scores) + elif aggregation == "mean": + final_score = sum(scores) / len(scores) + elif aggregation == "first": + final_score = scores[0] + else: + logger.warning(f"Unknown aggregation strategy: {aggregation}, using max") + final_score = max(scores) + + aggregated_results.append( + { + "index": doc_idx, + "relevance_score": final_score, + } + ) + + # Sort by relevance score (descending) + aggregated_results.sort(key=lambda x: x["relevance_score"], reverse=True) + + return aggregated_results + + +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=60), + retry=( + retry_if_exception_type(aiohttp.ClientError) + | retry_if_exception_type(aiohttp.ClientResponseError) + ), +) +async def generic_rerank_api( + query: str, + documents: List[str], + model: str, + base_url: str, + api_key: Optional[str], + top_n: Optional[int] = None, + return_documents: Optional[bool] = None, + extra_body: Optional[Dict[str, Any]] = None, + response_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun" + request_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun" + enable_chunking: bool = False, + max_tokens_per_doc: int = 480, +) -> List[Dict[str, Any]]: + """ + Generic rerank API call for Jina/Cohere/Aliyun models. Args: query: The search query - documents: List of document strings to rerank - top_n: Number of top results to return (None = all) - model_name: HuggingFace model name (default: mxbai-rerank-xsmall-v1) + documents: List of strings to rerank + model: Model name to use + base_url: API endpoint URL + api_key: API key for authentication + top_n: Number of top results to return + return_documents: Whether to return document text (Jina only) + extra_body: Additional body parameters + response_format: Response format type ("standard" for Jina/Cohere, "aliyun" for Aliyun) + request_format: Request format type + enable_chunking: Whether to chunk documents exceeding token limit + max_tokens_per_doc: Maximum tokens per document for chunking Returns: - List of dicts with 'index' and 'relevance_score', sorted by score descending - - Example: - >>> results = await local_rerank( - ... query="What is machine learning?", - ... documents=["ML is a subset of AI...", "The weather is nice..."], - ... top_n=5 - ... 
)
-        >>> print(results[0])
-        {'index': 0, 'relevance_score': 0.95}
+        List of dicts: [{"index": int, "relevance_score": float}, ...]
     """
-    if not documents:
-        return []
+    if not base_url:
+        raise ValueError("Base URL is required")
-    model = get_reranker_model(model_name)
+    headers = {"Content-Type": "application/json"}
+    if api_key is not None:
+        headers["Authorization"] = f"Bearer {api_key}"
-    # Create query-document pairs
-    pairs = [[query, doc] for doc in documents]
+    # Handle document chunking if enabled
+    original_documents = documents
+    doc_indices = None
+    original_top_n = top_n  # Save original top_n for post-aggregation limiting
-    # Get scores from model
-    # CrossEncoder.predict returns a list[float]; guard None for type checkers
-    if model is None:
-        raise RuntimeError('Reranker model failed to load')
-    raw_scores = model.predict(pairs)
+    if enable_chunking:
+        documents, doc_indices = chunk_documents_for_rerank(
+            documents, max_tokens=max_tokens_per_doc
+        )
+        logger.debug(
+            f"Chunked {len(original_documents)} documents into {len(documents)} chunks"
+        )
+        # When chunking is enabled, disable top_n at API level to get all chunk scores
+        # This ensures proper document-level coverage after aggregation
+        # We'll apply top_n to aggregated document results instead
+        if top_n is not None:
+            logger.debug(
+                f"Chunking enabled: disabled API-level top_n={top_n} to ensure complete document coverage"
+            )
+            top_n = None
-    # Normalize to a list[float] regardless of backend (list, numpy array, tensor)
-    if isinstance(raw_scores, SupportsToList):
-        raw_scores = raw_scores.tolist()
+    # Build request payload based on request format
+    if request_format == "aliyun":
+        # Aliyun format: nested input/parameters structure
+        payload = {
+            "model": model,
+            "input": {
+                "query": query,
+                "documents": documents,
+            },
+            "parameters": {},
+        }
-    scores = [float(score) for score in raw_scores]
+        # Add optional parameters to parameters object
+        if top_n is not None:
+            payload["parameters"]["top_n"] = top_n
-    # Build results with index and score
-    results: list[RerankResult] = [
-        RerankResult(index=i, relevance_score=float(score)) for i, score in enumerate(scores)
-    ]
+        if return_documents is not None:
+            payload["parameters"]["return_documents"] = return_documents
-    # Sort by score descending
-    results.sort(key=lambda x: x['relevance_score'], reverse=True)
+        # Add extra parameters to parameters object
+        if extra_body:
+            payload["parameters"].update(extra_body)
+    else:
+        # Standard format for Jina/Cohere/OpenAI
+        payload = {
+            "model": model,
+            "query": query,
+            "documents": documents,
+        }
-    # Apply top_n limit if specified
-    if top_n is not None and top_n < len(results):
-        results = results[:top_n]
+        # Add optional parameters
+        if top_n is not None:
+            payload["top_n"] = top_n
-    return results
+        # Only Jina API supports return_documents parameter
+        if return_documents is not None and response_format in ("standard",):
+            payload["return_documents"] = return_documents
+
+        # Add extra parameters
+        if extra_body:
+            payload.update(extra_body)
+
+    logger.debug(
+        f"Rerank request: {len(documents)} documents, model: {model}, format: {response_format}"
+    )
+
+    async with aiohttp.ClientSession() as session:
+        async with session.post(base_url, headers=headers, json=payload) as response:
+            if response.status != 200:
+                error_text = await response.text()
+                content_type = response.headers.get("content-type", "").lower()
+                # Detect HTML error pages returned by reverse proxies (e.g. 502/503/504)
+                is_html_error = (
+                    error_text.strip().startswith("<")
+                    or "text/html" in content_type
+                )
+                if 
is_html_error: + if response.status == 502: + clean_error = "Bad Gateway (502) - Rerank service temporarily unavailable. Please try again in a few minutes." + elif response.status == 503: + clean_error = "Service Unavailable (503) - Rerank service is temporarily overloaded. Please try again later." + elif response.status == 504: + clean_error = "Gateway Timeout (504) - Rerank service request timed out. Please try again." + else: + clean_error = f"HTTP {response.status} - Rerank service error. Please try again later." + else: + clean_error = error_text + logger.error(f"Rerank API error {response.status}: {clean_error}") + raise aiohttp.ClientResponseError( + request_info=response.request_info, + history=response.history, + status=response.status, + message=f"Rerank API error: {clean_error}", + ) + + response_json = await response.json() + + if response_format == "aliyun": + # Aliyun format: {"output": {"results": [...]}} + results = response_json.get("output", {}).get("results", []) + if not isinstance(results, list): + logger.warning( + f"Expected 'output.results' to be list, got {type(results)}: {results}" + ) + results = [] + elif response_format == "standard": + # Standard format: {"results": [...]} + results = response_json.get("results", []) + if not isinstance(results, list): + logger.warning( + f"Expected 'results' to be list, got {type(results)}: {results}" + ) + results = [] + else: + raise ValueError(f"Unsupported response format: {response_format}") + + if not results: + logger.warning("Rerank API returned empty results") + return [] + + # Standardize return format + standardized_results = [ + {"index": result["index"], "relevance_score": result["relevance_score"]} + for result in results + ] + + # Aggregate chunk scores back to original documents if chunking was enabled + if enable_chunking and doc_indices: + standardized_results = aggregate_chunk_scores( + standardized_results, + doc_indices, + len(original_documents), + aggregation="max", + ) + # Apply original top_n limit at document level (post-aggregation) + # This preserves document-level semantics: top_n limits documents, not chunks + if ( + original_top_n is not None + and len(standardized_results) > original_top_n + ): + standardized_results = standardized_results[:original_top_n] + + return standardized_results -def create_local_rerank_func( - model_name: str | None = None, -) -> Callable[..., Awaitable[list[RerankResult]]]: +async def cohere_rerank( + query: str, + documents: List[str], + top_n: Optional[int] = None, + api_key: Optional[str] = None, + model: str = "rerank-v3.5", + base_url: str = "https://api.cohere.com/v2/rerank", + extra_body: Optional[Dict[str, Any]] = None, + enable_chunking: bool = False, + max_tokens_per_doc: int = 4096, +) -> List[Dict[str, Any]]: """ - Create a rerank function with pre-configured model. + Rerank documents using Cohere API. - This is used by lightrag_server to create a rerank function - that can be passed to LightRAG initialization. 
+ Supports both standard Cohere API and Cohere-compatible proxies Args: - model_name: HuggingFace model name (default: mxbai-rerank-xsmall-v1) + query: The search query + documents: List of strings to rerank + top_n: Number of top results to return + api_key: API key for authentication + model: rerank model name (default: rerank-v3.5) + base_url: API endpoint + extra_body: Additional body for http request(reserved for extra params) + enable_chunking: Whether to chunk documents exceeding max_tokens_per_doc + max_tokens_per_doc: Maximum tokens per document (default: 4096 for Cohere v3.5) Returns: - Async rerank function + List of dictionary of ["index": int, "relevance_score": float] + + Example: + >>> # Standard Cohere API + >>> results = await cohere_rerank( + ... query="What is the meaning of life?", + ... documents=["Doc1", "Doc2"], + ... api_key="your-cohere-key" + ... ) + + >>> # LiteLLM proxy with user authentication + >>> results = await cohere_rerank( + ... query="What is vector search?", + ... documents=["Doc1", "Doc2"], + ... model="answerai-colbert-small-v1", + ... base_url="https://llm-proxy.example.com/v2/rerank", + ... api_key="your-proxy-key", + ... enable_chunking=True, + ... max_tokens_per_doc=480 + ... ) """ - # Pre-load model to fail fast if there's an issue - get_reranker_model(model_name) + if api_key is None: + api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY") - async def rerank_func( - query: str, - documents: list[str], - top_n: int | None = None, - **kwargs, - ) -> list[RerankResult]: - return await local_rerank( - query=query, - documents=documents, - top_n=top_n, - model_name=model_name, - ) - - return rerank_func + return await generic_rerank_api( + query=query, + documents=documents, + model=model, + base_url=base_url, + api_key=api_key, + top_n=top_n, + return_documents=None, # Cohere doesn't support this parameter + extra_body=extra_body, + response_format="standard", + enable_chunking=enable_chunking, + max_tokens_per_doc=max_tokens_per_doc, + ) -# For backwards compatibility - alias to local_rerank -rerank = local_rerank +async def jina_rerank( + query: str, + documents: List[str], + top_n: Optional[int] = None, + api_key: Optional[str] = None, + model: str = "jina-reranker-v2-base-multilingual", + base_url: str = "https://api.jina.ai/v1/rerank", + extra_body: Optional[Dict[str, Any]] = None, +) -> List[Dict[str, Any]]: + """ + Rerank documents using Jina AI API. + + Args: + query: The search query + documents: List of strings to rerank + top_n: Number of top results to return + api_key: API key + model: rerank model name + base_url: API endpoint + extra_body: Additional body for http request(reserved for extra params) + + Returns: + List of dictionary of ["index": int, "relevance_score": float] + """ + if api_key is None: + api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_BINDING_API_KEY") + + return await generic_rerank_api( + query=query, + documents=documents, + model=model, + base_url=base_url, + api_key=api_key, + top_n=top_n, + return_documents=False, + extra_body=extra_body, + response_format="standard", + ) -if __name__ == '__main__': +async def ali_rerank( + query: str, + documents: List[str], + top_n: Optional[int] = None, + api_key: Optional[str] = None, + model: str = "gte-rerank-v2", + base_url: str = "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank", + extra_body: Optional[Dict[str, Any]] = None, +) -> List[Dict[str, Any]]: + """ + Rerank documents using Aliyun DashScope API. 
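+
+    Note:
+        DashScope expects a nested input/parameters request body rather than the
+        flat Jina/Cohere payload, so this wrapper calls generic_rerank_api with
+        request_format="aliyun" and response_format="aliyun".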
+ + Args: + query: The search query + documents: List of strings to rerank + top_n: Number of top results to return + api_key: Aliyun API key + model: rerank model name + base_url: API endpoint + extra_body: Additional body for http request(reserved for extra params) + + Returns: + List of dictionary of ["index": int, "relevance_score": float] + """ + if api_key is None: + api_key = os.getenv("DASHSCOPE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY") + + return await generic_rerank_api( + query=query, + documents=documents, + model=model, + base_url=base_url, + api_key=api_key, + top_n=top_n, + return_documents=False, # Aliyun doesn't need this parameter + extra_body=extra_body, + response_format="aliyun", + request_format="aliyun", + ) + + +"""Please run this test as a module: +python -m lightrag.rerank +""" +if __name__ == "__main__": import asyncio async def main(): + # Example usage - documents should be strings, not dictionaries docs = [ - 'The capital of France is Paris.', - 'Tokyo is the capital of Japan.', - 'London is the capital of England.', - 'Python is a programming language.', + "The capital of France is Paris.", + "Tokyo is the capital of Japan.", + "London is the capital of England.", ] - query = 'What is the capital of France?' + query = "What is the capital of France?" - print('=== Local Reranker Test ===') - print(f'Model: {os.getenv("RERANK_MODEL", DEFAULT_RERANK_MODEL)}') - print(f'Query: {query}') - print() + # Test Jina rerank + try: + print("=== Jina Rerank ===") + result = await jina_rerank( + query=query, + documents=docs, + top_n=2, + ) + print("Results:") + for item in result: + print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}") + print(f"Document: {docs[item['index']]}") + except Exception as e: + print(f"Jina Error: {e}") - results = await local_rerank(query=query, documents=docs, top_n=3) + # Test Cohere rerank + try: + print("\n=== Cohere Rerank ===") + result = await cohere_rerank( + query=query, + documents=docs, + top_n=2, + ) + print("Results:") + for item in result: + print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}") + print(f"Document: {docs[item['index']]}") + except Exception as e: + print(f"Cohere Error: {e}") - print('Results (top 3):') - for item in results: - idx = item['index'] - score = item['relevance_score'] - print(f' [{idx}] Score: {score:.4f} - {docs[idx]}') + # Test Aliyun rerank + try: + print("\n=== Aliyun Rerank ===") + result = await ali_rerank( + query=query, + documents=docs, + top_n=2, + ) + print("Results:") + for item in result: + print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}") + print(f"Document: {docs[item['index']]}") + except Exception as e: + print(f"Aliyun Error: {e}") asyncio.run(main())