Raphaël MANSUY 2025-12-04 19:18:16 +08:00
parent 0a6e4616b2
commit d85c5a5875
3 changed files with 446 additions and 1359 deletions


@@ -73,6 +73,8 @@ ENABLE_LLM_CACHE=true
# MAX_RELATION_TOKENS=8000
### control the maximum tokens sent to LLM (including entities, relations and chunks)
# MAX_TOTAL_TOKENS=30000
### control the maximum chunk_ids stored in vector db
# MAX_CHUNK_IDS_PER_ENTITY=500
### maximum number of related chunks per source entity or relation
### The chunk picker uses this value to determine the total number of chunks selected from the KG (knowledge graph)
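For reference, a minimal sketch of how settings like these are typically picked up at runtime, using plain os.getenv with the documented defaults (the real project reads them through its own get_env_value helper; variable handling here is illustrative only):

import os

# Illustrative only: parse the documented settings with their commented defaults.
MAX_TOTAL_TOKENS = int(os.getenv("MAX_TOTAL_TOKENS", "30000"))
MAX_CHUNK_IDS_PER_ENTITY = int(os.getenv("MAX_CHUNK_IDS_PER_ENTITY", "500"))

print(MAX_TOTAL_TOKENS, MAX_CHUNK_IDS_PER_ENTITY)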

File diff suppressed because it is too large


@@ -15,17 +15,7 @@ from dataclasses import dataclass
from datetime import datetime
from functools import wraps
from hashlib import md5
from typing import (
Any,
Protocol,
Callable,
TYPE_CHECKING,
List,
Optional,
Iterable,
Sequence,
Collection,
)
from typing import Any, Protocol, Callable, TYPE_CHECKING, List, Optional
import numpy as np
from dotenv import load_dotenv
@@ -35,9 +25,8 @@ from lightrag.constants import (
DEFAULT_LOG_FILENAME,
GRAPH_FIELD_SEP,
DEFAULT_MAX_TOTAL_TOKENS,
DEFAULT_SOURCE_IDS_LIMIT_METHOD,
VALID_SOURCE_IDS_LIMIT_METHODS,
SOURCE_IDS_LIMIT_METHOD_FIFO,
DEFAULT_MAX_FILE_PATH_LENGTH,
DEFAULT_MAX_CHUNK_IDS_PER_ENTITY,
)
# Initialize logger with basic configuration
@@ -353,29 +342,8 @@ class EmbeddingFunc:
embedding_dim: int
func: callable
max_token_size: int | None = None  # deprecated, kept for compatibility only
send_dimensions: bool = (
False # Control whether to send embedding_dim to the function
)
async def __call__(self, *args, **kwargs) -> np.ndarray:
# Only inject embedding_dim when send_dimensions is True
if self.send_dimensions:
# Check if user provided embedding_dim parameter
if "embedding_dim" in kwargs:
user_provided_dim = kwargs["embedding_dim"]
# If user's value differs from class attribute, output warning
if (
user_provided_dim is not None
and user_provided_dim != self.embedding_dim
):
logger.warning(
f"Ignoring user-provided embedding_dim={user_provided_dim}, "
f"using declared embedding_dim={self.embedding_dim} from decorator"
)
# Inject embedding_dim from decorator
kwargs["embedding_dim"] = self.embedding_dim
return await self.func(*args, **kwargs)
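As an aside, a minimal usage sketch of the simplified wrapper, assuming EmbeddingFunc is the dataclass shown in this hunk; the backend function and dimension are made up for illustration. With send_dimensions gone, whatever the caller passes reaches the wrapped function unchanged:

import asyncio
import numpy as np

async def fake_embed(texts, embedding_dim=1024):
    # Hypothetical backend: return zero vectors of the requested size.
    return np.zeros((len(texts), embedding_dim))

async def main():
    func = EmbeddingFunc(embedding_dim=1024, func=fake_embed)
    vectors = await func(["hello world"])  # no embedding_dim injection anymore
    print(vectors.shape)  # (1, 1024)

asyncio.run(main())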
@@ -927,45 +895,9 @@ def load_json(file_name):
return json.load(f)
def _sanitize_json_data(data: Any) -> Any:
"""Recursively sanitize all string values in data structure for safe UTF-8 encoding
Handles all JSON-serializable types including:
- Dictionary keys and values
- Lists and tuples (preserves type)
- Nested structures
- Strings at any level
Args:
data: Data to sanitize (dict, list, tuple, str, or other types)
Returns:
Sanitized data with all strings cleaned of problematic characters
"""
if isinstance(data, dict):
# Sanitize both keys and values
return {
_sanitize_string_for_json(k)
if isinstance(k, str)
else k: _sanitize_json_data(v)
for k, v in data.items()
}
elif isinstance(data, (list, tuple)):
# Handle both lists and tuples, preserve original type
sanitized = [_sanitize_json_data(item) for item in data]
return type(data)(sanitized)
elif isinstance(data, str):
return sanitize_text_for_encoding(data, replacement_char="")
else:
# Numbers, booleans, None, etc. - return as-is
return data
def write_json(json_obj, file_name):
# Sanitize data before writing to prevent UTF-8 encoding errors
sanitized_obj = _sanitize_json_data(json_obj)
with open(file_name, "w", encoding="utf-8") as f:
json.dump(sanitized_obj, f, indent=2, ensure_ascii=False)
json.dump(json_obj, f, indent=2, ensure_ascii=False)
class TokenizerInterface(Protocol):
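For context, a self-contained sketch of what the removed sanitizer did; the names sanitize_json_data and _clean are stand-ins, and _clean only drops lone surrogates here, whereas the real sanitize_text_for_encoding does more:

from typing import Any

def _clean(text: str) -> str:
    # Simplified stand-in: drop lone surrogate code points that break UTF-8 encoding.
    return "".join(ch for ch in text if not 0xD800 <= ord(ch) <= 0xDFFF)

def sanitize_json_data(data: Any) -> Any:
    if isinstance(data, dict):
        # Sanitize both keys and values
        return {(_clean(k) if isinstance(k, str) else k): sanitize_json_data(v)
                for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        # Preserve the original container type
        return type(data)(sanitize_json_data(item) for item in data)
    if isinstance(data, str):
        return _clean(data)
    return data  # numbers, booleans, None

print(sanitize_json_data({"k\ud800ey": ["bad\udfffvalue", 42]}))
# {'key': ['badvalue', 42]}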
@@ -1852,7 +1784,7 @@ def normalize_extracted_info(name: str, remove_inner_quotes=False) -> str:
- Filter out short numeric-only text (length < 3 and only digits/dots)
- remove_inner_quotes = True
remove Chinese quotes
remove English quotes in and around chinese
remove English queotes in and around chinese
Convert non-breaking spaces to regular spaces
Convert narrow non-breaking spaces after non-digits to regular spaces
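A tiny illustration of the whitespace normalization described in that docstring (a sketch only, not the real function body):

import re

text = "price\u00a01000 and note\u202fhere"
text = text.replace("\u00a0", " ")           # non-breaking space -> regular space
text = re.sub(r"(?<=\D)\u202f", " ", text)   # narrow NBSP after a non-digit -> space
print(text)  # price 1000 and note here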
@@ -2533,157 +2465,80 @@ async def process_chunks_unified(
return final_chunks
def truncate_entity_source_id(chunk_ids: set, entity_name: str) -> set:
    """Limit chunk_ids, for entities that appear a HUGE no of times (To not break VDB hard upper limits)"""
    already_len: int = len(chunk_ids)
    max_chunk_ids_per_entity = get_env_value(
        "MAX_CHUNK_IDS_PER_ENTITY", DEFAULT_MAX_CHUNK_IDS_PER_ENTITY, int
    )

    if already_len >= max_chunk_ids_per_entity:
        logger.warning(
            f"Chunk Ids already exceeds {max_chunk_ids_per_entity} for {entity_name}, "
            f"current size: {already_len} entries."
        )

    truncated_chunk_ids = set(list(chunk_ids)[0:max_chunk_ids_per_entity])

    return truncated_chunk_ids

def normalize_source_ids_limit_method(method: str | None) -> str:
    """Normalize the source ID limiting strategy and fall back to default when invalid."""
    if not method:
        return DEFAULT_SOURCE_IDS_LIMIT_METHOD
    normalized = method.upper()
    if normalized not in VALID_SOURCE_IDS_LIMIT_METHODS:
        logger.warning(
            "Unknown SOURCE_IDS_LIMIT_METHOD '%s', falling back to %s",
            method,
            DEFAULT_SOURCE_IDS_LIMIT_METHOD,
        )
        return DEFAULT_SOURCE_IDS_LIMIT_METHOD
    return normalized
def merge_source_ids(
    existing_ids: Iterable[str] | None, new_ids: Iterable[str] | None
) -> list[str]:
    """Merge two iterables of source IDs while preserving order and removing duplicates."""
    merged: list[str] = []
    seen: set[str] = set()
    for sequence in (existing_ids, new_ids):
        if not sequence:
            continue
        for source_id in sequence:
            if not source_id:
                continue
            if source_id not in seen:
                seen.add(source_id)
                merged.append(source_id)
    return merged

def apply_source_ids_limit(
    source_ids: Sequence[str],
    limit: int,
    method: str,
    *,
    identifier: str | None = None,
) -> list[str]:
    """Apply a limit strategy to a sequence of source IDs."""
    if limit <= 0:
        return []
    source_ids_list = list(source_ids)
    if len(source_ids_list) <= limit:
        return source_ids_list
    normalized_method = normalize_source_ids_limit_method(method)
    if normalized_method == SOURCE_IDS_LIMIT_METHOD_FIFO:
        truncated = source_ids_list[-limit:]
    else:  # IGNORE_NEW
        truncated = source_ids_list[:limit]
    if identifier and len(truncated) < len(source_ids_list):
        logger.debug(
            "Source_id truncated: %s | %s keeping %s of %s entries",
            identifier,
            normalized_method,
            len(truncated),
            len(source_ids_list),
        )
    return truncated
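To make the two limit strategies concrete, here is a self-contained sketch of the behaviour implemented above; the exact constant strings are assumptions, since the real code resolves them through normalize_source_ids_limit_method and the constants module:

def limit_source_ids(source_ids: list[str], limit: int, method: str) -> list[str]:
    # Assumed stand-in for apply_source_ids_limit with method strings "FIFO"/"IGNORE_NEW".
    if limit <= 0:
        return []
    if len(source_ids) <= limit:
        return list(source_ids)
    if method.upper() == "FIFO":       # drop the oldest entries, keep the most recent
        return source_ids[-limit:]
    return source_ids[:limit]          # IGNORE_NEW: keep what was stored first

ids = ["c1", "c2", "c3", "c4", "c5"]
print(limit_source_ids(ids, 3, "FIFO"))        # ['c3', 'c4', 'c5']
print(limit_source_ids(ids, 3, "IGNORE_NEW"))  # ['c1', 'c2', 'c3']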
def compute_incremental_chunk_ids(
    existing_full_chunk_ids: list[str],
    old_chunk_ids: list[str],
    new_chunk_ids: list[str],
) -> list[str]:
    """
    Compute incrementally updated chunk IDs based on changes.

    This function applies delta changes (additions and removals) to an existing
    list of chunk IDs while maintaining order and ensuring deduplication.
    Delta additions from new_chunk_ids are placed at the end.

    Args:
        existing_full_chunk_ids: Complete list of existing chunk IDs from storage
        old_chunk_ids: Previous chunk IDs from source_id (chunks being replaced)
        new_chunk_ids: New chunk IDs from updated source_id (chunks being added)

    Returns:
        Updated list of chunk IDs with deduplication

    Example:
        >>> existing = ['chunk-1', 'chunk-2', 'chunk-3']
        >>> old = ['chunk-1', 'chunk-2']
        >>> new = ['chunk-2', 'chunk-4']
        >>> compute_incremental_chunk_ids(existing, old, new)
        ['chunk-3', 'chunk-2', 'chunk-4']
    """
    # Calculate changes
    chunks_to_remove = set(old_chunk_ids) - set(new_chunk_ids)
    chunks_to_add = set(new_chunk_ids) - set(old_chunk_ids)

    # Apply changes to full chunk_ids
    # Step 1: Remove chunks that are no longer needed
    updated_chunk_ids = [
        cid for cid in existing_full_chunk_ids if cid not in chunks_to_remove
    ]

    # Step 2: Add new chunks (preserving order from new_chunk_ids)
    # Note: 'cid not in updated_chunk_ids' check ensures deduplication
    for cid in new_chunk_ids:
        if cid in chunks_to_add and cid not in updated_chunk_ids:
            updated_chunk_ids.append(cid)

    return updated_chunk_ids

def subtract_source_ids(
    source_ids: Iterable[str],
    ids_to_remove: Collection[str],
) -> list[str]:
    """Remove a collection of IDs from an ordered iterable while preserving order."""
    removal_set = set(ids_to_remove)
    if not removal_set:
        return [source_id for source_id in source_ids if source_id]
    return [
        source_id
        for source_id in source_ids
        if source_id and source_id not in removal_set
    ]

def make_relation_chunk_key(src: str, tgt: str) -> str:
    """Create a deterministic storage key for relation chunk tracking."""
    return GRAPH_FIELD_SEP.join(sorted((src, tgt)))

def parse_relation_chunk_key(key: str) -> tuple[str, str]:
    """Parse a relation chunk storage key back into its entity pair."""
    parts = key.split(GRAPH_FIELD_SEP)
    if len(parts) != 2:
        raise ValueError(f"Invalid relation chunk key: {key}")
    return parts[0], parts[1]

def build_file_path(already_file_paths, data_list, target):
    """Build file path string with UTF-8 byte length limit and deduplication

    Args:
        already_file_paths: List of existing file paths
        data_list: List of data items containing file_path
        target: Target name for logging warnings

    Returns:
        str: Combined file paths separated by GRAPH_FIELD_SEP
    """
    # set: deduplication
    file_paths_set = {fp for fp in already_file_paths if fp}

    # string: filter empty value and keep file order in already_file_paths
    file_paths = GRAPH_FIELD_SEP.join(fp for fp in already_file_paths if fp)

    # Check if initial file_paths already exceeds byte length limit
    if len(file_paths.encode("utf-8")) >= DEFAULT_MAX_FILE_PATH_LENGTH:
        logger.warning(
            f"Initial file_paths already exceeds {DEFAULT_MAX_FILE_PATH_LENGTH} bytes for {target}, "
            f"current size: {len(file_paths.encode('utf-8'))} bytes"
        )

    # ignored file_paths
    file_paths_ignore = ""

    # add file_paths
    for dp in data_list:
        cur_file_path = dp.get("file_path")

        # empty
        if not cur_file_path:
            continue

        # skip duplicate item
        if cur_file_path in file_paths_set:
            continue

        # add
        file_paths_set.add(cur_file_path)

        # check the UTF-8 byte length
        new_addition = GRAPH_FIELD_SEP + cur_file_path if file_paths else cur_file_path
        if (
            len(file_paths.encode("utf-8")) + len(new_addition.encode("utf-8"))
            < DEFAULT_MAX_FILE_PATH_LENGTH - 5
        ):
            # append
            file_paths += new_addition
        else:
            # ignore
            file_paths_ignore += GRAPH_FIELD_SEP + cur_file_path

    if file_paths_ignore:
        logger.warning(
            f"File paths exceed {DEFAULT_MAX_FILE_PATH_LENGTH} bytes for {target}, "
            f"ignoring file path: {file_paths_ignore}"
        )

    return file_paths
def generate_track_id(prefix: str = "upload") -> str:
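A condensed, self-contained sketch of the byte-budget join that build_file_path performs; SEP and MAX_BYTES are stand-ins for GRAPH_FIELD_SEP and DEFAULT_MAX_FILE_PATH_LENGTH, and the 5-byte margin mirrors the code above:

SEP = "<SEP>"
MAX_BYTES = 64  # illustrative only; the real default is much larger

def join_paths_within_budget(existing: list[str], new_paths: list[str]) -> str:
    seen = {p for p in existing if p}
    joined = SEP.join(p for p in existing if p)
    for path in new_paths:
        if not path or path in seen:      # skip empty and duplicate paths
            continue
        seen.add(path)
        addition = SEP + path if joined else path
        # Only append while the UTF-8 encoded result stays under the budget.
        if len(joined.encode("utf-8")) + len(addition.encode("utf-8")) < MAX_BYTES - 5:
            joined += addition
    return joined

print(join_paths_within_budget(["a.md"], ["b.md", "a.md", "c.md"]))
# a.md<SEP>b.md<SEP>c.md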
@@ -2773,9 +2628,9 @@ def fix_tuple_delimiter_corruption(
record,
)
# Fix: <X|#|> -> <|#|>, <|#|Y> -> <|#|>, <X|#|Y> -> <|#|>, <||#||> -> <|#|> (one extra characters outside pipes)
# Fix: <X|#|> -> <|#|>, <|#|Y> -> <|#|>, <X|#|Y> -> <|#|>, <||#||> -> <|#|>, <||#> -> <|#|> (one extra characters outside pipes)
record = re.sub(
rf"<.?\|{escaped_delimiter_core}\|.?>",
rf"<.?\|{escaped_delimiter_core}\|*?>",
tuple_delimiter,
record,
)
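A quick illustration of the widened pattern; it assumes the tuple delimiter is "<|#|>" (so escaped_delimiter_core is "#", matching the examples in the comment), and shows that the "<||#>" corruption is now caught as well:

import re

tuple_delimiter = "<|#|>"
core = re.escape("#")

record = "entity<X|#|>name<||#>type<||#||>desc"
fixed = re.sub(rf"<.?\|{core}\|*?>", tuple_delimiter, record)
print(fixed)  # entity<|#|>name<|#|>type<|#|>desc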
@@ -2795,6 +2650,7 @@ def fix_tuple_delimiter_corruption(
)
# Fix: <|#| -> <|#|>, <|#|| -> <|#|> (missing closing >)
record = re.sub(
rf"<\|{escaped_delimiter_core}\|+(?!>)",
tuple_delimiter,
@ -2808,13 +2664,6 @@ def fix_tuple_delimiter_corruption(
record,
)
# Fix: <||#> -> <|#|> (double pipe at start, missing pipe at end)
record = re.sub(
rf"<\|+{escaped_delimiter_core}>",
tuple_delimiter,
record,
)
# Fix: <|| -> <|#|>
record = re.sub(
r"<\|\|(?!>)",