Add chunk tracking support to entity merge functionality

- Pass chunk storages to merge function
- Merge relation chunk tracking data
- Merge entity chunk tracking data
- Delete old chunk tracking records
- Persist chunk storage updates

(cherry picked from commit 2c09adb8d3)
Authored by yangdx on 2025-10-27 02:06:21 +08:00; committed by Raphaël MANSUY
parent 450f969430
commit 17a9771cfb
2 changed files with 348 additions and 257 deletions
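The recurring pattern in this commit is an order-preserving, de-duplicating merge of chunk-ID lists: source entities contribute first, the target entity last, and the first occurrence of each ID wins. A minimal standalone sketch of that pattern (hypothetical helper, not part of the diff):

```python
# Order-preserving deduplication, as used for both entity and relation
# chunk tracking in the merge path below (hypothetical standalone helper).
def merge_chunk_ids(chunk_id_lists: list[list[str]]) -> list[str]:
    merged: list[str] = []
    seen: set[str] = set()
    for chunk_id_list in chunk_id_lists:
        for chunk_id in chunk_id_list:
            if chunk_id not in seen:
                seen.add(chunk_id)
                merged.append(chunk_id)
    return merged


# First occurrence wins; later duplicates are dropped.
assert merge_chunk_ids([["c1", "c2"], ["c2", "c3"]]) == ["c1", "c2", "c3"]
```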


@@ -3575,6 +3575,8 @@ class LightRAG:
             target_entity,
             merge_strategy,
             target_entity_data,
+            self.entity_chunks,
+            self.relation_chunks,
         )

     def merge_entities(
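With the two extra arguments above, merging through the public API now also persists the chunk-tracking records. A hedged usage sketch (assuming the `amerge_entities` wrapper keeps its public signature; the entity names are made up):

```python
# Hypothetical usage; after this commit the merge also updates
# rag.entity_chunks and rag.relation_chunks under the hood.
async def consolidate(rag) -> None:
    await rag.amerge_entities(
        source_entities=["OpenAI Inc.", "OpenAI, Inc."],
        target_entity="OpenAI",
    )
```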


@@ -11,6 +11,49 @@ from .utils import compute_mdhash_id, logger
 from .base import StorageNameSpace


+async def _persist_graph_updates(
+    entities_vdb=None,
+    relationships_vdb=None,
+    chunk_entity_relation_graph=None,
+    entity_chunks_storage=None,
+    relation_chunks_storage=None,
+) -> None:
+    """Unified callback to persist updates after graph operations.
+
+    Ensures all relevant storage instances are properly persisted after
+    operations like delete, edit, create, or merge.
+
+    Args:
+        entities_vdb: Entity vector database storage (optional)
+        relationships_vdb: Relationship vector database storage (optional)
+        chunk_entity_relation_graph: Graph storage instance (optional)
+        entity_chunks_storage: Entity-chunk tracking storage (optional)
+        relation_chunks_storage: Relation-chunk tracking storage (optional)
+    """
+    storages = []
+
+    # Collect all non-None storage instances
+    if entities_vdb is not None:
+        storages.append(entities_vdb)
+    if relationships_vdb is not None:
+        storages.append(relationships_vdb)
+    if chunk_entity_relation_graph is not None:
+        storages.append(chunk_entity_relation_graph)
+    if entity_chunks_storage is not None:
+        storages.append(entity_chunks_storage)
+    if relation_chunks_storage is not None:
+        storages.append(relation_chunks_storage)
+
+    # Persist all storage instances in parallel
+    if storages:
+        await asyncio.gather(
+            *[
+                cast(StorageNameSpace, storage_inst).index_done_callback()
+                for storage_inst in storages  # type: ignore
+            ]
+        )
+
+
 async def adelete_by_entity(
     chunk_entity_relation_graph,
     entities_vdb,
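`_persist_graph_updates` only assumes each storage exposes an awaitable `index_done_callback()`; omitted arguments stay `None` and are skipped. A sketch with stub storages (test doubles, not real LightRAG classes):

```python
import asyncio


class StubStorage:
    """Minimal stand-in for a StorageNameSpace implementation."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.persisted = False

    async def index_done_callback(self) -> None:
        self.persisted = True


async def main() -> None:
    entities = StubStorage("entities_vdb")
    graph = StubStorage("graph")
    # Only the storages passed in are flushed, in parallel.
    await _persist_graph_updates(
        entities_vdb=entities,
        chunk_entity_relation_graph=graph,
    )
    assert entities.persisted and graph.persisted


asyncio.run(main())
```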
@@ -64,7 +107,9 @@ async def adelete_by_entity(
                 for src, tgt in edges:
                     # Normalize entity order for consistent key generation
                     normalized_src, normalized_tgt = sorted([src, tgt])
-                    storage_key = make_relation_chunk_key(normalized_src, normalized_tgt)
+                    storage_key = make_relation_chunk_key(
+                        normalized_src, normalized_tgt
+                    )
                     relation_keys_to_delete.append(storage_key)

             if relation_keys_to_delete:
@@ -79,12 +124,12 @@ async def adelete_by_entity(
             message = f"Entity Delete: remove '{entity_name}' and its {related_relations_count} relations"
             logger.info(message)

-            await _delete_by_entity_done(
-                entities_vdb,
-                relationships_vdb,
-                chunk_entity_relation_graph,
-                entity_chunks_storage,
-                relation_chunks_storage,
+            await _persist_graph_updates(
+                entities_vdb=entities_vdb,
+                relationships_vdb=relationships_vdb,
+                chunk_entity_relation_graph=chunk_entity_relation_graph,
+                entity_chunks_storage=entity_chunks_storage,
+                relation_chunks_storage=relation_chunks_storage,
             )

             return DeletionResult(
                 status="success",
@@ -103,28 +148,6 @@ async def adelete_by_entity(
         )


-async def _delete_by_entity_done(
-    entities_vdb,
-    relationships_vdb,
-    chunk_entity_relation_graph,
-    entity_chunks_storage=None,
-    relation_chunks_storage=None,
-) -> None:
-    """Callback after entity deletion is complete, ensures updates are persisted"""
-    storages = [entities_vdb, relationships_vdb, chunk_entity_relation_graph]
-    if entity_chunks_storage is not None:
-        storages.append(entity_chunks_storage)
-    if relation_chunks_storage is not None:
-        storages.append(relation_chunks_storage)
-
-    await asyncio.gather(
-        *[
-            cast(StorageNameSpace, storage_inst).index_done_callback()
-            for storage_inst in storages  # type: ignore
-        ]
-    )
-
-
 async def adelete_by_relation(
     chunk_entity_relation_graph,
     relationships_vdb,
@@ -148,6 +171,10 @@ async def adelete_by_relation(
     # Use graph database lock to ensure atomic graph and vector db operations
     async with graph_db_lock:
         try:
+            # Normalize entity order for undirected graph (ensures consistent key generation)
+            if source_entity > target_entity:
+                source_entity, target_entity = target_entity, source_entity
+
             # Check if the relation exists
             edge_exists = await chunk_entity_relation_graph.has_edge(
                 source_entity, target_entity
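Relations are stored undirected, so the swap above (and the `sorted` calls elsewhere in this diff) guarantees that both argument orders resolve to the same tracking record. A sketch of the idea with a hypothetical key helper mirroring `make_relation_chunk_key`:

```python
# Hypothetical mirror of the normalization; the real key format lives in
# lightrag.utils.make_relation_chunk_key and may differ.
def relation_chunk_key(src: str, tgt: str) -> str:
    normalized_src, normalized_tgt = sorted([src, tgt])
    return f"{normalized_src}|{normalized_tgt}"


# Both directions of an undirected relation map to the same record.
assert relation_chunk_key("B", "A") == relation_chunk_key("A", "B")
```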
@@ -169,7 +196,7 @@ async def adelete_by_relation(
                 # Normalize entity order for consistent key generation
                 normalized_src, normalized_tgt = sorted([source_entity, target_entity])
                 storage_key = make_relation_chunk_key(normalized_src, normalized_tgt)
                 await relation_chunks_storage.delete([storage_key])

                 logger.info(
                     f"Relation Delete: removed chunk tracking for `{source_entity}`~`{target_entity}`"
@@ -190,8 +217,10 @@ async def adelete_by_relation(
             message = f"Relation Delete: `{source_entity}`~`{target_entity}` deleted successfully"
             logger.info(message)

-            await _delete_relation_done(
-                relationships_vdb, chunk_entity_relation_graph, relation_chunks_storage
+            await _persist_graph_updates(
+                relationships_vdb=relationships_vdb,
+                chunk_entity_relation_graph=chunk_entity_relation_graph,
+                relation_chunks_storage=relation_chunks_storage,
             )

             return DeletionResult(
                 status="success",
@@ -210,22 +239,6 @@ async def adelete_by_relation(
         )


-async def _delete_relation_done(
-    relationships_vdb, chunk_entity_relation_graph, relation_chunks_storage=None
-) -> None:
-    """Callback after relation deletion is complete, ensures updates are persisted"""
-    storages = [relationships_vdb, chunk_entity_relation_graph]
-    if relation_chunks_storage is not None:
-        storages.append(relation_chunks_storage)
-
-    await asyncio.gather(
-        *[
-            cast(StorageNameSpace, storage_inst).index_done_callback()
-            for storage_inst in storages  # type: ignore
-        ]
-    )
-
-
 async def aedit_entity(
     chunk_entity_relation_graph,
     entities_vdb,
@@ -560,12 +573,12 @@ async def aedit_entity(
             )

             # 5. Save changes
-            await _edit_entity_done(
-                entities_vdb,
-                relationships_vdb,
-                chunk_entity_relation_graph,
-                entity_chunks_storage,
-                relation_chunks_storage,
+            await _persist_graph_updates(
+                entities_vdb=entities_vdb,
+                relationships_vdb=relationships_vdb,
+                chunk_entity_relation_graph=chunk_entity_relation_graph,
+                entity_chunks_storage=entity_chunks_storage,
+                relation_chunks_storage=relation_chunks_storage,
             )

             logger.info(f"Entity Edit: `{entity_name}` successfully updated")
@@ -580,28 +593,6 @@ async def aedit_entity(
             raise


-async def _edit_entity_done(
-    entities_vdb,
-    relationships_vdb,
-    chunk_entity_relation_graph,
-    entity_chunks_storage=None,
-    relation_chunks_storage=None,
-) -> None:
-    """Callback after entity editing is complete, ensures updates are persisted"""
-    storages = [entities_vdb, relationships_vdb, chunk_entity_relation_graph]
-    if entity_chunks_storage is not None:
-        storages.append(entity_chunks_storage)
-    if relation_chunks_storage is not None:
-        storages.append(relation_chunks_storage)
-
-    await asyncio.gather(
-        *[
-            cast(StorageNameSpace, storage_inst).index_done_callback()
-            for storage_inst in storages  # type: ignore
-        ]
-    )
-
-
 async def aedit_relation(
     chunk_entity_relation_graph,
     entities_vdb,
@@ -759,8 +750,10 @@ async def aedit_relation(
             )

             # 5. Save changes
-            await _edit_relation_done(
-                relationships_vdb, chunk_entity_relation_graph, relation_chunks_storage
+            await _persist_graph_updates(
+                relationships_vdb=relationships_vdb,
+                chunk_entity_relation_graph=chunk_entity_relation_graph,
+                relation_chunks_storage=relation_chunks_storage,
             )

             logger.info(
@@ -780,22 +773,6 @@ async def aedit_relation(
             raise


-async def _edit_relation_done(
-    relationships_vdb, chunk_entity_relation_graph, relation_chunks_storage=None
-) -> None:
-    """Callback after relation editing is complete, ensures updates are persisted"""
-    storages = [relationships_vdb, chunk_entity_relation_graph]
-    if relation_chunks_storage is not None:
-        storages.append(relation_chunks_storage)
-
-    await asyncio.gather(
-        *[
-            cast(StorageNameSpace, storage_inst).index_done_callback()
-            for storage_inst in storages  # type: ignore
-        ]
-    )
-
-
 async def acreate_entity(
     chunk_entity_relation_graph,
     entities_vdb,
@@ -872,7 +849,7 @@ async def acreate_entity(
             if entity_chunks_storage is not None:
                 source_id = node_data.get("source_id", "")
                 chunk_ids = [cid for cid in source_id.split(GRAPH_FIELD_SEP) if cid]
                 if chunk_ids:
                     await entity_chunks_storage.upsert(
                         {
@@ -887,12 +864,12 @@ async def acreate_entity(
             )

             # Save changes
-            await _edit_entity_done(
-                entities_vdb,
-                relationships_vdb,
-                chunk_entity_relation_graph,
-                entity_chunks_storage,
-                relation_chunks_storage,
+            await _persist_graph_updates(
+                entities_vdb=entities_vdb,
+                relationships_vdb=relationships_vdb,
+                chunk_entity_relation_graph=chunk_entity_relation_graph,
+                entity_chunks_storage=entity_chunks_storage,
+                relation_chunks_storage=relation_chunks_storage,
             )

             logger.info(f"Entity Create: '{entity_name}' successfully created")
@@ -970,6 +947,10 @@ async def acreate_relation(
                     source_entity, target_entity, edge_data
                 )

+            # Normalize entity order for undirected relation vector (ensures consistent key generation)
+            if source_entity > target_entity:
+                source_entity, target_entity = target_entity, source_entity
+
             # Prepare content for embedding
             description = edge_data.get("description", "")
             keywords = edge_data.get("keywords", "")
@@ -1008,10 +989,10 @@ async def acreate_relation(
                 # Normalize entity order for consistent key generation
                 normalized_src, normalized_tgt = sorted([source_entity, target_entity])
                 storage_key = make_relation_chunk_key(normalized_src, normalized_tgt)
                 source_id = edge_data.get("source_id", "")
                 chunk_ids = [cid for cid in source_id.split(GRAPH_FIELD_SEP) if cid]
                 if chunk_ids:
                     await relation_chunks_storage.upsert(
                         {
@@ -1026,8 +1007,10 @@ async def acreate_relation(
             )

             # Save changes
-            await _edit_relation_done(
-                relationships_vdb, chunk_entity_relation_graph, relation_chunks_storage
+            await _persist_graph_updates(
+                relationships_vdb=relationships_vdb,
+                chunk_entity_relation_graph=chunk_entity_relation_graph,
+                relation_chunks_storage=relation_chunks_storage,
             )

             logger.info(
@@ -1055,11 +1038,14 @@ async def amerge_entities(
     target_entity: str,
     merge_strategy: dict[str, str] = None,
     target_entity_data: dict[str, Any] = None,
+    entity_chunks_storage=None,
+    relation_chunks_storage=None,
 ) -> dict[str, Any]:
     """Asynchronously merge multiple entities into one entity.

     Merges multiple source entities into a target entity, handling all relationships,
     and updating both the knowledge graph and vector database.
+    Also merges chunk tracking information from entity_chunks_storage and relation_chunks_storage.

     Args:
         chunk_entity_relation_graph: Graph storage instance
@@ -1067,14 +1053,12 @@ async def amerge_entities(
         relationships_vdb: Vector database storage for relationships
         source_entities: List of source entity names to merge
         target_entity: Name of the target entity after merging
-        merge_strategy: Merge strategy configuration, e.g. {"description": "concatenate", "entity_type": "keep_first"}
-            Supported strategies:
-            - "concatenate": Concatenate all values (for text fields)
-            - "keep_first": Keep the first non-empty value
-            - "keep_last": Keep the last non-empty value
-            - "join_unique": Join all unique values (for fields separated by delimiter)
+        merge_strategy: Deprecated (Each field uses its own default strategy). If provided,
+            customizations are applied but a warning is logged.
         target_entity_data: Dictionary of specific values to set for the target entity,
             overriding any merged values, e.g. {"description": "custom description", "entity_type": "PERSON"}
+        entity_chunks_storage: Optional KV storage for tracking chunks that reference entities
+        relation_chunks_storage: Optional KV storage for tracking chunks that reference relations

     Returns:
         Dictionary containing the merged entity information
@@ -1083,18 +1067,22 @@ async def amerge_entities(
     # Use graph database lock to ensure atomic graph and vector db operations
     async with graph_db_lock:
         try:
-            # Default merge strategy
-            default_strategy = {
+            # Default merge strategy for entities
+            default_entity_merge_strategy = {
                 "description": "concatenate",
                 "entity_type": "keep_first",
                 "source_id": "join_unique",
+                "file_path": "join_unique",
             }
+            effective_entity_merge_strategy = default_entity_merge_strategy

-            merge_strategy = (
-                default_strategy
-                if merge_strategy is None
-                else {**default_strategy, **merge_strategy}
-            )
+            if merge_strategy:
+                logger.warning(
+                    "Entity Merge: merge_strategy parameter is deprecated and will be ignored in a future release."
+                )
+                effective_entity_merge_strategy = {
+                    **default_entity_merge_strategy,
+                    **merge_strategy,
+                }

             target_entity_data = (
                 {} if target_entity_data is None else target_entity_data
             )
@@ -1116,24 +1104,31 @@ async def amerge_entities(
                     await chunk_entity_relation_graph.get_node(target_entity)
                 )
                 logger.info(
-                    f"Target entity '{target_entity}' already exists, will merge data"
+                    "Entity Merge: target entity already exists, source and target entities will be merged"
                 )

             # 3. Merge entity data
-            merged_entity_data = _merge_entity_attributes(
+            merged_entity_data = _merge_attributes(
                 list(source_entities_data.values())
                 + ([existing_target_entity_data] if target_exists else []),
-                merge_strategy,
+                effective_entity_merge_strategy,
+                filter_none_only=False,  # Use entity behavior: filter falsy values
             )

             # Apply any explicitly provided target entity data (overrides merged data)
             for key, value in target_entity_data.items():
                 merged_entity_data[key] = value

-            # 4. Get all relationships of the source entities
+            # 4. Get all relationships of the source entities and target entity (if exists)
             all_relations = []
-            for entity_name in source_entities:
-                # Get all relationships of the source entities
+            entities_to_collect = source_entities.copy()
+
+            # If target entity exists and not already in source_entities, add it
+            if target_exists and target_entity not in source_entities:
+                entities_to_collect.append(target_entity)
+
+            for entity_name in entities_to_collect:
+                # Get all relationships of the entities
                 edges = await chunk_entity_relation_graph.get_node_edges(entity_name)
                 if edges:
                     for src, tgt in edges:
@@ -1150,71 +1145,169 @@ async def amerge_entities(
                 await chunk_entity_relation_graph.upsert_node(
                     target_entity, merged_entity_data
                 )
-                logger.info(f"Created new target entity '{target_entity}'")
+                logger.info(f"Entity Merge: created target '{target_entity}'")
             else:
                 await chunk_entity_relation_graph.upsert_node(
                     target_entity, merged_entity_data
                 )
-                logger.info(f"Updated existing target entity '{target_entity}'")
+                logger.info(f"Entity Merge: Updated target '{target_entity}'")

-            # 6. Recreate all relationships, pointing to the target entity
+            # 6. Recreate all relations pointing to the target entity in KG
+            # Also collect chunk tracking information in the same loop
             relation_updates = {}  # Track relationships that need to be merged
             relations_to_delete = []

+            # Initialize chunk tracking variables
+            relation_chunk_tracking = {}  # key: storage_key, value: list of chunk_ids
+            old_relation_keys_to_delete = []
+
             for src, tgt, edge_data in all_relations:
                 relations_to_delete.append(compute_mdhash_id(src + tgt, prefix="rel-"))
                 relations_to_delete.append(compute_mdhash_id(tgt + src, prefix="rel-"))

+                # Collect old chunk tracking key for deletion
+                if relation_chunks_storage is not None:
+                    from .utils import make_relation_chunk_key
+
+                    old_storage_key = make_relation_chunk_key(src, tgt)
+                    old_relation_keys_to_delete.append(old_storage_key)
+
                 new_src = target_entity if src in source_entities else src
                 new_tgt = target_entity if tgt in source_entities else tgt

                 # Skip relationships between source entities to avoid self-loops
                 if new_src == new_tgt:
                     logger.info(
-                        f"Skipping relationship between source entities: {src} -> {tgt} to avoid self-loop"
+                        f"Entity Merge: skipping `{src}`~`{tgt}` to avoid self-loop"
                     )
                     continue

-                # Check if the same relationship already exists
-                relation_key = f"{new_src}|{new_tgt}"
+                # Normalize entity order for consistent duplicate detection (undirected relationships)
+                normalized_src, normalized_tgt = sorted([new_src, new_tgt])
+                relation_key = f"{normalized_src}|{normalized_tgt}"
+
+                # Process chunk tracking for this relation
+                if relation_chunks_storage is not None:
+                    storage_key = make_relation_chunk_key(
+                        normalized_src, normalized_tgt
+                    )
+
+                    # Get chunk_ids from storage for this original relation
+                    stored = await relation_chunks_storage.get_by_id(old_storage_key)
+                    if stored is not None and isinstance(stored, dict):
+                        chunk_ids = [cid for cid in stored.get("chunk_ids", []) if cid]
+                    else:
+                        # Fallback to source_id from graph
+                        source_id = edge_data.get("source_id", "")
+                        chunk_ids = [
+                            cid for cid in source_id.split(GRAPH_FIELD_SEP) if cid
+                        ]
+
+                    # Accumulate chunk_ids with ordered deduplication
+                    if storage_key not in relation_chunk_tracking:
+                        relation_chunk_tracking[storage_key] = []
+                    existing_chunks = set(relation_chunk_tracking[storage_key])
+                    for chunk_id in chunk_ids:
+                        if chunk_id not in existing_chunks:
+                            existing_chunks.add(chunk_id)
+                            relation_chunk_tracking[storage_key].append(chunk_id)
+
                 if relation_key in relation_updates:
                     # Merge relationship data
                     existing_data = relation_updates[relation_key]["data"]
-                    merged_relation = _merge_relation_attributes(
+                    merged_relation = _merge_attributes(
                         [existing_data, edge_data],
                         {
                             "description": "concatenate",
-                            "keywords": "join_unique",
+                            "keywords": "join_unique_comma",
                             "source_id": "join_unique",
+                            "file_path": "join_unique",
                             "weight": "max",
                         },
+                        filter_none_only=True,  # Use relation behavior: only filter None
                     )
                     relation_updates[relation_key]["data"] = merged_relation
                     logger.info(
-                        f"Merged duplicate relationship: {new_src} -> {new_tgt}"
+                        f"Entity Merge: deduplicating relation `{normalized_src}`~`{normalized_tgt}`"
                     )
                 else:
                     relation_updates[relation_key] = {
-                        "src": new_src,
-                        "tgt": new_tgt,
+                        "graph_src": new_src,
+                        "graph_tgt": new_tgt,
+                        "norm_src": normalized_src,
+                        "norm_tgt": normalized_tgt,
                         "data": edge_data.copy(),
                     }

             # Apply relationship updates
             for rel_data in relation_updates.values():
                 await chunk_entity_relation_graph.upsert_edge(
-                    rel_data["src"], rel_data["tgt"], rel_data["data"]
+                    rel_data["graph_src"], rel_data["graph_tgt"], rel_data["data"]
                 )
                 logger.info(
-                    f"Created or updated relationship: {rel_data['src']} -> {rel_data['tgt']}"
+                    f"Entity Merge: updating relation `{rel_data['graph_src']}`->`{rel_data['graph_tgt']}`"
                 )

-            # Delete relationships records from vector database
-            await relationships_vdb.delete(relations_to_delete)
-            logger.info(
-                f"Deleted {len(relations_to_delete)} relation records for entity from vector database"
-            )
+            # Update relation chunk tracking storage
+            if relation_chunks_storage is not None and all_relations:
+                if old_relation_keys_to_delete:
+                    await relation_chunks_storage.delete(old_relation_keys_to_delete)
+                if relation_chunk_tracking:
+                    updates = {}
+                    for storage_key, chunk_ids in relation_chunk_tracking.items():
+                        updates[storage_key] = {
+                            "chunk_ids": chunk_ids,
+                            "count": len(chunk_ids),
+                        }
+                    await relation_chunks_storage.upsert(updates)
+                    logger.info(
+                        f"Entity Merge: merged chunk tracking for {len(updates)} relations"
+                    )
+
+            # 7. Update relationship vector representations
+            logger.info(
+                f"Entity Merge: deleting {len(relations_to_delete)} relations from vdb"
+            )
+            await relationships_vdb.delete(relations_to_delete)
+
+            for rel_data in relation_updates.values():
+                edge_data = rel_data["data"]
+                normalized_src = rel_data["norm_src"]
+                normalized_tgt = rel_data["norm_tgt"]
+
+                description = edge_data.get("description", "")
+                keywords = edge_data.get("keywords", "")
+                source_id = edge_data.get("source_id", "")
+                weight = float(edge_data.get("weight", 1.0))
+
+                # Use normalized order for content and relation ID
+                content = (
+                    f"{keywords}\t{normalized_src}\n{normalized_tgt}\n{description}"
+                )
+                relation_id = compute_mdhash_id(
+                    normalized_src + normalized_tgt, prefix="rel-"
+                )
+
+                relation_data_for_vdb = {
+                    relation_id: {
+                        "content": content,
+                        "src_id": normalized_src,
+                        "tgt_id": normalized_tgt,
+                        "source_id": source_id,
+                        "description": description,
+                        "keywords": keywords,
+                        "weight": weight,
+                    }
+                }
+                await relationships_vdb.upsert(relation_data_for_vdb)
+                logger.info(
+                    f"Entity Merge: updating vdb `{normalized_src}`~`{normalized_tgt}`"
+                )

-            # 7. Update entity vector representation
+            # 8. Update entity vector representation
             description = merged_entity_data.get("description", "")
             source_id = merged_entity_data.get("source_id", "")
             entity_type = merged_entity_data.get("entity_type", "")
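Rebuilding the relation vectors from the normalized pair keeps the `rel-` IDs direction-independent. A sketch of that property, assuming `compute_mdhash_id` is an MD5-of-content helper (an assumption based on the calls in this diff, not on the helper's source):

```python
from hashlib import md5


# Assumed shape of compute_mdhash_id(content, prefix="rel-"): prefix + MD5 hex.
def rel_id(src: str, tgt: str) -> str:
    normalized_src, normalized_tgt = sorted([src, tgt])
    return "rel-" + md5((normalized_src + normalized_tgt).encode()).hexdigest()


# Normalization makes the vector-db ID independent of edge direction.
assert rel_id("B", "A") == rel_id("A", "B")
```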
@@ -1230,63 +1323,91 @@ async def amerge_entities(
                     "entity_type": entity_type,
                 }
             }
             await entities_vdb.upsert(entity_data_for_vdb)
+            logger.info(f"Entity Merge: updating vdb `{target_entity}`")

-            # 8. Update relationship vector representations
-            for rel_data in relation_updates.values():
-                src = rel_data["src"]
-                tgt = rel_data["tgt"]
-                edge_data = rel_data["data"]
-
-                description = edge_data.get("description", "")
-                keywords = edge_data.get("keywords", "")
-                source_id = edge_data.get("source_id", "")
-                weight = float(edge_data.get("weight", 1.0))
-
-                content = f"{keywords}\t{src}\n{tgt}\n{description}"
-                relation_id = compute_mdhash_id(src + tgt, prefix="rel-")
-
-                relation_data_for_vdb = {
-                    relation_id: {
-                        "content": content,
-                        "src_id": src,
-                        "tgt_id": tgt,
-                        "source_id": source_id,
-                        "description": description,
-                        "keywords": keywords,
-                        "weight": weight,
-                    }
-                }
-
-                await relationships_vdb.upsert(relation_data_for_vdb)
+            # 9. Merge entity chunk tracking (source entities first, then target entity)
+            if entity_chunks_storage is not None:
+                all_chunk_id_lists = []
+
+                # Build list of entities to process (source entities first, then target entity)
+                entities_to_process = []
+
+                # Add source entities first (excluding target if it's already in source list)
+                for entity_name in source_entities:
+                    if entity_name != target_entity:
+                        entities_to_process.append(entity_name)
+
+                # Add target entity last (if it exists)
+                if target_exists:
+                    entities_to_process.append(target_entity)
+
+                # Process all entities in order with unified logic
+                for entity_name in entities_to_process:
+                    stored = await entity_chunks_storage.get_by_id(entity_name)
+                    if stored and isinstance(stored, dict):
+                        chunk_ids = [cid for cid in stored.get("chunk_ids", []) if cid]
+                        if chunk_ids:
+                            all_chunk_id_lists.append(chunk_ids)
+
+                # Merge chunk_ids with ordered deduplication (preserves order, source entities first)
+                merged_chunk_ids = []
+                seen = set()
+                for chunk_id_list in all_chunk_id_lists:
+                    for chunk_id in chunk_id_list:
+                        if chunk_id not in seen:
+                            seen.add(chunk_id)
+                            merged_chunk_ids.append(chunk_id)
+
+                # Delete source entities' chunk tracking records
+                entity_keys_to_delete = [
+                    e for e in source_entities if e != target_entity
+                ]
+                if entity_keys_to_delete:
+                    await entity_chunks_storage.delete(entity_keys_to_delete)
+
+                # Update target entity's chunk tracking
+                if merged_chunk_ids:
+                    await entity_chunks_storage.upsert(
+                        {
+                            target_entity: {
+                                "chunk_ids": merged_chunk_ids,
+                                "count": len(merged_chunk_ids),
+                            }
+                        }
+                    )
+                    logger.info(
+                        f"Entity Merge: find {len(merged_chunk_ids)} chunks related to '{target_entity}'"
+                    )

-            # 9. Delete source entities
+            # 10. Delete source entities
             for entity_name in source_entities:
                 if entity_name == target_entity:
-                    logger.info(
-                        f"Skipping deletion of '{entity_name}' as it's also the target entity"
+                    logger.warning(
+                        f"Entity Merge: source entity'{entity_name}' is same as target entity"
                     )
                     continue

-                # Delete entity node from knowledge graph
+                logger.info(f"Entity Merge: deleting '{entity_name}' from KG and vdb")
+
+                # Delete entity node and related edges from knowledge graph
                 await chunk_entity_relation_graph.delete_node(entity_name)

                 # Delete entity record from vector database
                 entity_id = compute_mdhash_id(entity_name, prefix="ent-")
                 await entities_vdb.delete([entity_id])

-                logger.info(
-                    f"Deleted source entity '{entity_name}' and its vector embedding from database"
-                )
-
-            # 10. Save changes
-            await _merge_entities_done(
-                entities_vdb, relationships_vdb, chunk_entity_relation_graph
+            # 11. Save changes
+            await _persist_graph_updates(
+                entities_vdb=entities_vdb,
+                relationships_vdb=relationships_vdb,
+                chunk_entity_relation_graph=chunk_entity_relation_graph,
+                entity_chunks_storage=entity_chunks_storage,
+                relation_chunks_storage=relation_chunks_storage,
             )

             logger.info(
-                f"Successfully merged {len(source_entities)} entities into '{target_entity}'"
+                f"Entity Merge: successfully merged {len(source_entities)} entities into '{target_entity}'"
             )

             return await get_entity_info(
                 chunk_entity_relation_graph,
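Step 9 encodes a deliberate precedence: source entities are read before the pre-existing target record, so their chunk IDs lead the merged list and the target contributes only IDs not already seen. A small sketch of the resulting processing order, with made-up names:

```python
# Hypothetical data illustrating the processing order of step 9.
source_entities = ["Acme Corp", "ACME"]
target_entity = "Acme"
target_exists = True

entities_to_process = [e for e in source_entities if e != target_entity]
if target_exists:
    entities_to_process.append(target_entity)

# Sources first, pre-existing target record last.
assert entities_to_process == ["Acme Corp", "ACME", "Acme"]
```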
@@ -1300,81 +1421,45 @@ async def amerge_entities(
         raise


-def _merge_entity_attributes(
-    entity_data_list: list[dict[str, Any]], merge_strategy: dict[str, str]
+def _merge_attributes(
+    data_list: list[dict[str, Any]],
+    merge_strategy: dict[str, str],
+    filter_none_only: bool = False,
 ) -> dict[str, Any]:
-    """Merge attributes from multiple entities.
+    """Merge attributes from multiple entities or relationships.
+
+    This unified function handles merging of both entity and relationship attributes,
+    applying different merge strategies per field.

     Args:
-        entity_data_list: List of dictionaries containing entity data
-        merge_strategy: Merge strategy for each field
+        data_list: List of dictionaries containing entity or relationship data
+        merge_strategy: Merge strategy for each field. Supported strategies:
+            - "concatenate": Join all values with GRAPH_FIELD_SEP
+            - "keep_first": Keep the first non-empty value
+            - "keep_last": Keep the last non-empty value
+            - "join_unique": Join unique items separated by GRAPH_FIELD_SEP
+            - "join_unique_comma": Join unique items separated by comma and space
+            - "max": Keep the maximum numeric value (for numeric fields)
+        filter_none_only: If True, only filter None values (keep empty strings, 0, etc.).
+            If False, filter all falsy values. Default is False for backward compatibility.

     Returns:
-        Dictionary containing merged entity data
+        Dictionary containing merged data
     """
     merged_data = {}

     # Collect all possible keys
     all_keys = set()
-    for data in entity_data_list:
+    for data in data_list:
         all_keys.update(data.keys())

     # Merge values for each key
     for key in all_keys:
-        # Get all values for this key
-        values = [data.get(key) for data in entity_data_list if data.get(key)]
-        if not values:
-            continue
-
-        # Merge values according to strategy
-        strategy = merge_strategy.get(key, "keep_first")
-        if strategy == "concatenate":
-            merged_data[key] = "\n\n".join(values)
-        elif strategy == "keep_first":
-            merged_data[key] = values[0]
-        elif strategy == "keep_last":
-            merged_data[key] = values[-1]
-        elif strategy == "join_unique":
-            # Handle fields separated by GRAPH_FIELD_SEP
-            unique_items = set()
-            for value in values:
-                items = value.split(GRAPH_FIELD_SEP)
-                unique_items.update(items)
-            merged_data[key] = GRAPH_FIELD_SEP.join(unique_items)
-        else:
-            # Default strategy
-            merged_data[key] = values[0]
-
-    return merged_data
-
-
-def _merge_relation_attributes(
-    relation_data_list: list[dict[str, Any]], merge_strategy: dict[str, str]
-) -> dict[str, Any]:
-    """Merge attributes from multiple relationships.
-
-    Args:
-        relation_data_list: List of dictionaries containing relationship data
-        merge_strategy: Merge strategy for each field
-
-    Returns:
-        Dictionary containing merged relationship data
-    """
-    merged_data = {}
-
-    # Collect all possible keys
-    all_keys = set()
-    for data in relation_data_list:
-        all_keys.update(data.keys())
-
-    # Merge values for each key
-    for key in all_keys:
-        # Get all values for this key
-        values = [
-            data.get(key) for data in relation_data_list if data.get(key) is not None
-        ]
+        # Get all values for this key based on filtering mode
+        if filter_none_only:
+            values = [data.get(key) for data in data_list if data.get(key) is not None]
+        else:
+            values = [data.get(key) for data in data_list if data.get(key)]

         if not values:
             continue
@@ -1383,7 +1468,8 @@ def _merge_relation_attributes(
         strategy = merge_strategy.get(key, "keep_first")

         if strategy == "concatenate":
-            merged_data[key] = "\n\n".join(str(v) for v in values)
+            # Convert all values to strings and join with GRAPH_FIELD_SEP
+            merged_data[key] = GRAPH_FIELD_SEP.join(str(v) for v in values)
         elif strategy == "keep_first":
             merged_data[key] = values[0]
         elif strategy == "keep_last":
@@ -1395,35 +1481,27 @@ def _merge_relation_attributes(
                 items = str(value).split(GRAPH_FIELD_SEP)
                 unique_items.update(items)
             merged_data[key] = GRAPH_FIELD_SEP.join(unique_items)
+        elif strategy == "join_unique_comma":
+            # Handle fields separated by comma, join unique items with comma
+            unique_items = set()
+            for value in values:
+                items = str(value).split(",")
+                unique_items.update(item.strip() for item in items if item.strip())
+            merged_data[key] = ",".join(sorted(unique_items))
         elif strategy == "max":
             # For numeric fields like weight
             try:
                 merged_data[key] = max(float(v) for v in values)
             except (ValueError, TypeError):
+                # Fallback to first value if conversion fails
                 merged_data[key] = values[0]
         else:
-            # Default strategy
+            # Default strategy: keep first value
             merged_data[key] = values[0]

     return merged_data


-async def _merge_entities_done(
-    entities_vdb, relationships_vdb, chunk_entity_relation_graph
-) -> None:
-    """Callback after entity merging is complete, ensures updates are persisted"""
-    await asyncio.gather(
-        *[
-            cast(StorageNameSpace, storage_inst).index_done_callback()
-            for storage_inst in [  # type: ignore
-                entities_vdb,
-                relationships_vdb,
-                chunk_entity_relation_graph,
-            ]
-        ]
-    )
-
-
 async def get_entity_info(
     chunk_entity_relation_graph,
     entities_vdb,
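The `filter_none_only` flag is what lets one function serve both call sites: entity merging drops all falsy values, while relation merging keeps meaningful falsy values such as a `0.0` weight. A quick check of both modes against the unified function, using made-up rows:

```python
# Entity behavior (filter_none_only=False): the falsy "" is filtered out,
# so keep_first returns the first non-empty value.
merged_entity = _merge_attributes(
    [{"entity_type": ""}, {"entity_type": "PERSON"}],
    {"entity_type": "keep_first"},
    filter_none_only=False,
)
assert merged_entity["entity_type"] == "PERSON"

# Relation behavior (filter_none_only=True): 0.0 is not None, so it survives
# filtering and the "max" strategy still compares it against 2.0.
merged_relation = _merge_attributes(
    [{"weight": 0.0}, {"weight": 2.0}],
    {"weight": "max"},
    filter_none_only=True,
)
assert merged_relation["weight"] == 2.0
```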
@@ -1458,7 +1536,18 @@ async def get_relation_info(
     tgt_entity: str,
     include_vector_data: bool = False,
 ) -> dict[str, str | None | dict[str, str]]:
-    """Get detailed information of a relationship"""
+    """
+    Get detailed information of a relationship between two entities.
+    Relationship is unidirectional, swap src_entity and tgt_entity does not change the relationship.
+
+    Args:
+        src_entity: Source entity name
+        tgt_entity: Target entity name
+        include_vector_data: Whether to include vector database information
+
+    Returns:
+        Dictionary containing relationship information
+    """
     # Get information from the graph
     edge_data = await chunk_entity_relation_graph.get_edge(src_entity, tgt_entity)