diff --git a/lightrag/base.py b/lightrag/base.py
index 9a5fbeb6..9f82c6eb 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -654,6 +654,22 @@ class BaseGraphStorage(StorageNameSpace, ABC):
             indicating whether the graph was truncated due to max_nodes limit
         """
 
+    @abstractmethod
+    async def get_all_nodes(self) -> list[dict]:
+        """Get all nodes in the graph.
+
+        Returns:
+            A list of all nodes, where each node is a dictionary of its properties
+        """
+
+    @abstractmethod
+    async def get_all_edges(self) -> list[dict]:
+        """Get all edges in the graph.
+
+        Returns:
+            A list of all edges, where each edge is a dictionary of its properties
+        """
+
 
 class DocStatus(str, Enum):
     """Document processing status"""
diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py
index eeca53ed..f6f77597 100644
--- a/lightrag/kg/networkx_impl.py
+++ b/lightrag/kg/networkx_impl.py
@@ -393,6 +393,35 @@ class NetworkXStorage(BaseGraphStorage):
                 matching_edges.append(edge_data_with_nodes)
         return matching_edges
 
+    async def get_all_nodes(self) -> list[dict]:
+        """Get all nodes in the graph.
+
+        Returns:
+            A list of all nodes, where each node is a dictionary of its properties
+        """
+        graph = await self._get_graph()
+        all_nodes = []
+        for node_id, node_data in graph.nodes(data=True):
+            node_data_with_id = node_data.copy()
+            node_data_with_id["id"] = node_id
+            all_nodes.append(node_data_with_id)
+        return all_nodes
+
+    async def get_all_edges(self) -> list[dict]:
+        """Get all edges in the graph.
+
+        Returns:
+            A list of all edges, where each edge is a dictionary of its properties
+        """
+        graph = await self._get_graph()
+        all_edges = []
+        for u, v, edge_data in graph.edges(data=True):
+            edge_data_with_nodes = edge_data.copy()
+            edge_data_with_nodes["source"] = u
+            edge_data_with_nodes["target"] = v
+            all_edges.append(edge_data_with_nodes)
+        return all_edges
+
     async def index_done_callback(self) -> bool:
         """Save data to disk"""
         async with self._storage_lock:
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index ae42e5df..73386b0b 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -453,14 +453,26 @@ class LightRAG:
             embedding_func=self.embedding_func,
         )
 
+        self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
+            namespace=NameSpace.KV_STORE_TEXT_CHUNKS,
+            workspace=self.workspace,
+            embedding_func=self.embedding_func,
+        )
+
         self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
             namespace=NameSpace.KV_STORE_FULL_DOCS,
             workspace=self.workspace,
             embedding_func=self.embedding_func,
         )
 
-        self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
-            namespace=NameSpace.KV_STORE_TEXT_CHUNKS,
+        self.full_entities: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
+            namespace=NameSpace.KV_STORE_FULL_ENTITIES,
+            workspace=self.workspace,
+            embedding_func=self.embedding_func,
+        )
+
+        self.full_relations: BaseKVStorage = self.key_string_value_json_storage_cls(  # type: ignore
+            namespace=NameSpace.KV_STORE_FULL_RELATIONS,
             workspace=self.workspace,
             embedding_func=self.embedding_func,
         )
@@ -553,6 +565,8 @@
             for storage in (
                 self.full_docs,
                 self.text_chunks,
+                self.full_entities,
+                self.full_relations,
                 self.entities_vdb,
                 self.relationships_vdb,
                 self.chunks_vdb,
@@ -576,6 +590,8 @@
             for storage in (
                 self.full_docs,
                 self.text_chunks,
+                self.full_entities,
+                self.full_relations,
                 self.entities_vdb,
                 self.relationships_vdb,
                 self.chunks_vdb,
@@ -591,6 +607,159 @@ class LightRAG:
         self._storages_status = StoragesStatus.FINALIZED
         logger.debug("Finalized Storages")
 
+    async def _check_and_migrate_data(self):
+        """Check if data migration is needed and perform migration if necessary"""
+        try:
+            # Check if migration is needed:
+            # 1. chunk_entity_relation_graph has entities and relations (count > 0)
+            # 2. full_entities and full_relations are empty
+
+            # Get all entity labels from graph
+            all_entity_labels = await self.chunk_entity_relation_graph.get_all_labels()
+
+            if not all_entity_labels:
+                logger.debug("No entities found in graph, skipping migration check")
+                return
+
+            # Check if full_entities and full_relations are empty
+            # Get all processed documents to check their entity/relation data
+            try:
+                processed_docs = await self.doc_status.get_docs_by_status(
+                    DocStatus.PROCESSED
+                )
+
+                if not processed_docs:
+                    logger.debug("No processed documents found, skipping migration")
+                    return
+
+                # Check first few documents to see if they have full_entities/full_relations data
+                migration_needed = True
+                checked_count = 0
+                max_check = min(5, len(processed_docs))  # Check up to 5 documents
+
+                for doc_id in list(processed_docs.keys())[:max_check]:
+                    checked_count += 1
+                    entity_data = await self.full_entities.get_by_id(doc_id)
+                    relation_data = await self.full_relations.get_by_id(doc_id)
+
+                    if entity_data or relation_data:
+                        migration_needed = False
+                        break
+
+                if not migration_needed:
+                    logger.debug(
+                        "Full entities/relations data already exists, no migration needed"
+                    )
+                    return
+
+                logger.info(
+                    f"Data migration needed: found {len(all_entity_labels)} entities in graph but no full_entities/full_relations data"
+                )
+
+                # Perform migration
+                await self._migrate_entity_relation_data(processed_docs)
+
+            except Exception as e:
+                logger.error(f"Error during migration check: {e}")
+                # Don't raise the error, just log it to avoid breaking initialization
+
+        except Exception as e:
+            logger.error(f"Error in data migration check: {e}")
+            # Don't raise the error to avoid breaking initialization
+
+    async def _migrate_entity_relation_data(self, processed_docs: dict):
+        """Migrate existing entity and relation data to full_entities and full_relations storage"""
+        logger.info(f"Starting data migration for {len(processed_docs)} documents")
+
+        # Create mapping from chunk_id to doc_id
+        chunk_to_doc = {}
+        for doc_id, doc_status in processed_docs.items():
+            chunk_ids = (
+                doc_status.chunks_list
+                if hasattr(doc_status, "chunks_list") and doc_status.chunks_list
+                else []
+            )
+            for chunk_id in chunk_ids:
+                chunk_to_doc[chunk_id] = doc_id
+
+        # Initialize document entity and relation mappings
+        doc_entities = {}  # doc_id -> set of entity_names
+        doc_relations = {}  # doc_id -> set of relation_pairs (as tuples)
+
+        # Get all nodes and edges from graph
+        all_nodes = await self.chunk_entity_relation_graph.get_all_nodes()
+        all_edges = await self.chunk_entity_relation_graph.get_all_edges()
+
+        # Process all nodes once
+        for node in all_nodes:
+            if "source_id" in node:
+                entity_id = node.get("entity_id") or node.get("id")
+                if not entity_id:
+                    continue
+
+                # Get chunk IDs from source_id
+                source_ids = node["source_id"].split(GRAPH_FIELD_SEP)
+
+                # Find which documents this entity belongs to
+                for chunk_id in source_ids:
+                    doc_id = chunk_to_doc.get(chunk_id)
+                    if doc_id:
+                        if doc_id not in doc_entities:
+                            doc_entities[doc_id] = set()
+                        doc_entities[doc_id].add(entity_id)
+
+        # Process all edges once
+        for edge in all_edges:
+            if "source_id" in edge:
+                src = edge.get("source")
+                tgt = edge.get("target")
edge.get("target") + if not src or not tgt: + continue + + # Get chunk IDs from source_id + source_ids = edge["source_id"].split(GRAPH_FIELD_SEP) + + # Find which documents this relation belongs to + for chunk_id in source_ids: + doc_id = chunk_to_doc.get(chunk_id) + if doc_id: + if doc_id not in doc_relations: + doc_relations[doc_id] = set() + # Use tuple for set operations, convert to list later + doc_relations[doc_id].add((src, tgt)) + + # Store the results in full_entities and full_relations + migration_count = 0 + + # Store entities + if doc_entities: + entities_data = {} + for doc_id, entity_set in doc_entities.items(): + entities_data[doc_id] = {"entity_names": list(entity_set)} + await self.full_entities.upsert(entities_data) + + # Store relations + if doc_relations: + relations_data = {} + for doc_id, relation_set in doc_relations.items(): + # Convert tuples back to lists + relations_data[doc_id] = { + "relation_pairs": [list(pair) for pair in relation_set] + } + await self.full_relations.upsert(relations_data) + + migration_count = len( + set(list(doc_entities.keys()) + list(doc_relations.keys())) + ) + + # Persist the migrated data + await self.full_entities.index_done_callback() + await self.full_relations.index_done_callback() + + logger.info( + f"Data migration completed: migrated {migration_count} documents with entities/relations" + ) + async def get_graph_labels(self): text = await self.chunk_entity_relation_graph.get_all_labels() return text @@ -1229,6 +1398,9 @@ class LightRAG: entity_vdb=self.entities_vdb, relationships_vdb=self.relationships_vdb, global_config=asdict(self), + full_entities_storage=self.full_entities, + full_relations_storage=self.full_relations, + doc_id=doc_id, pipeline_status=pipeline_status, pipeline_status_lock=pipeline_status_lock, llm_response_cache=self.llm_response_cache, @@ -1401,6 +1573,8 @@ class LightRAG: self.full_docs, self.doc_status, self.text_chunks, + self.full_entities, + self.full_relations, self.llm_response_cache, self.entities_vdb, self.relationships_vdb, @@ -1959,21 +2133,54 @@ class LightRAG: graph_db_lock = get_graph_db_lock(enable_logging=False) async with graph_db_lock: try: - # Get all affected nodes and edges in batch - # logger.info( - # f"Analyzing affected entities and relationships for {len(chunk_ids)} chunks" - # ) - affected_nodes = ( - await self.chunk_entity_relation_graph.get_nodes_by_chunk_ids( - list(chunk_ids) - ) - ) + # Get affected entities and relations from full_entities and full_relations storage + doc_entities_data = await self.full_entities.get_by_id(doc_id) + doc_relations_data = await self.full_relations.get_by_id(doc_id) - affected_edges = ( - await self.chunk_entity_relation_graph.get_edges_by_chunk_ids( - list(chunk_ids) + affected_nodes = [] + affected_edges = [] + + # Get entity data from graph storage using entity names from full_entities + if doc_entities_data and "entity_names" in doc_entities_data: + entity_names = doc_entities_data["entity_names"] + # get_nodes_batch returns dict[str, dict], need to convert to list[dict] + nodes_dict = ( + await self.chunk_entity_relation_graph.get_nodes_batch( + entity_names + ) ) - ) + for entity_name in entity_names: + node_data = nodes_dict.get(entity_name) + if node_data: + # Ensure compatibility with existing logic that expects "id" field + if "id" not in node_data: + node_data["id"] = entity_name + affected_nodes.append(node_data) + + # Get relation data from graph storage using relation pairs from full_relations + if doc_relations_data and 
"relation_pairs" in doc_relations_data: + relation_pairs = doc_relations_data["relation_pairs"] + edge_pairs_dicts = [ + {"src": pair[0], "tgt": pair[1]} for pair in relation_pairs + ] + # get_edges_batch returns dict[tuple[str, str], dict], need to convert to list[dict] + edges_dict = ( + await self.chunk_entity_relation_graph.get_edges_batch( + edge_pairs_dicts + ) + ) + + for pair in relation_pairs: + src, tgt = pair[0], pair[1] + edge_key = (src, tgt) + edge_data = edges_dict.get(edge_key) + if edge_data: + # Ensure compatibility with existing logic that expects "source" and "target" fields + if "source" not in edge_data: + edge_data["source"] = src + if "target" not in edge_data: + edge_data["target"] = tgt + affected_edges.append(edge_data) except Exception as e: logger.error(f"Failed to analyze affected graph elements: {e}") @@ -2125,7 +2332,17 @@ class LightRAG: f"Failed to rebuild knowledge graph: {e}" ) from e - # 9. Delete original document and status + # 9. Delete from full_entities and full_relations storage + try: + await self.full_entities.delete([doc_id]) + await self.full_relations.delete([doc_id]) + except Exception as e: + logger.error(f"Failed to delete from full_entities/full_relations: {e}") + raise Exception( + f"Failed to delete from full_entities/full_relations: {e}" + ) from e + + # 10. Delete original document and status try: await self.full_docs.delete([doc_id]) await self.doc_status.delete([doc_id]) diff --git a/lightrag/namespace.py b/lightrag/namespace.py index 657d65ac..5c042713 100644 --- a/lightrag/namespace.py +++ b/lightrag/namespace.py @@ -7,6 +7,8 @@ class NameSpace: KV_STORE_FULL_DOCS = "full_docs" KV_STORE_TEXT_CHUNKS = "text_chunks" KV_STORE_LLM_RESPONSE_CACHE = "llm_response_cache" + KV_STORE_FULL_ENTITIES = "full_entities" + KV_STORE_FULL_RELATIONS = "full_relations" VECTOR_STORE_ENTITIES = "entities" VECTOR_STORE_RELATIONSHIPS = "relationships" diff --git a/lightrag/operate.py b/lightrag/operate.py index 45153bc5..ca21881b 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -504,9 +504,6 @@ async def _rebuild_knowledge_from_chunks( # Re-raise the exception to notify the caller raise task.exception() - # If all tasks completed successfully, collect results - # (No need to collect results since these tasks don't return values) - # Final status report status_message = f"KG rebuild completed: {rebuilt_entities_count} entities and {rebuilt_relationships_count} relationships rebuilt successfully." 
     if failed_entities_count > 0 or failed_relationships_count > 0:
@@ -1024,6 +1021,7 @@ async def _merge_edges_then_upsert(
     pipeline_status: dict = None,
     pipeline_status_lock=None,
     llm_response_cache: BaseKVStorage | None = None,
+    added_entities: list = None,  # New parameter to track entities added during edge processing
 ):
     if src_id == tgt_id:
         return None
@@ -1105,17 +1103,27 @@
     for need_insert_id in [src_id, tgt_id]:
         if not (await knowledge_graph_inst.has_node(need_insert_id)):
-            await knowledge_graph_inst.upsert_node(
-                need_insert_id,
-                node_data={
-                    "entity_id": need_insert_id,
-                    "source_id": source_id,
-                    "description": description,
+            node_data = {
+                "entity_id": need_insert_id,
+                "source_id": source_id,
+                "description": description,
+                "entity_type": "UNKNOWN",
+                "file_path": file_path,
+                "created_at": int(time.time()),
+            }
+            await knowledge_graph_inst.upsert_node(need_insert_id, node_data=node_data)
+
+            # Track entities added during edge processing
+            if added_entities is not None:
+                entity_data = {
+                    "entity_name": need_insert_id,
                     "entity_type": "UNKNOWN",
+                    "description": description,
+                    "source_id": source_id,
                     "file_path": file_path,
                     "created_at": int(time.time()),
-                },
-            )
+                }
+                added_entities.append(entity_data)
 
     force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
@@ -1178,6 +1186,9 @@ async def merge_nodes_and_edges(
     entity_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
     global_config: dict[str, str],
+    full_entities_storage: BaseKVStorage = None,
+    full_relations_storage: BaseKVStorage = None,
+    doc_id: str = None,
     pipeline_status: dict = None,
     pipeline_status_lock=None,
     llm_response_cache: BaseKVStorage | None = None,
@@ -1185,7 +1196,12 @@
     total_files: int = 0,
     file_path: str = "unknown_source",
 ) -> None:
-    """Merge nodes and edges from extraction results
+    """Two-phase merge: process all entities first, then all relationships
+
+    This approach ensures data consistency by:
+    1. Phase 1: Process all entities concurrently
+    2. Phase 2: Process all relationships concurrently (may add missing entities)
+    3. Phase 3: Update full_entities and full_relations storage with final results
 
     Args:
         chunk_results: List of tuples (maybe_nodes, maybe_edges) containing extracted entities and relationships
@@ -1193,9 +1209,15 @@
         entity_vdb: Entity vector database
         relationships_vdb: Relationship vector database
        global_config: Global configuration
+        full_entities_storage: Storage for document entity lists
+        full_relations_storage: Storage for document relation lists
+        doc_id: Document ID for storage indexing
         pipeline_status: Pipeline status dictionary
         pipeline_status_lock: Lock for pipeline status
         llm_response_cache: LLM response cache
+        current_file_number: Current file number for logging
+        total_files: Total files for logging
+        file_path: File path for logging
     """
 
     # Collect all nodes and edges from all chunks
@@ -1212,11 +1234,9 @@
             sorted_edge_key = tuple(sorted(edge_key))
             all_edges[sorted_edge_key].extend(edges)
 
-    # Centralized processing of all nodes and edges
     total_entities_count = len(all_nodes)
     total_relations_count = len(all_edges)
 
-    # Merge nodes and edges
     log_message = f"Merging stage {current_file_number}/{total_files}: {file_path}"
     logger.info(log_message)
     async with pipeline_status_lock:
@@ -1227,8 +1247,8 @@
     graph_max_async = global_config.get("llm_model_max_async", 4) * 2
     semaphore = asyncio.Semaphore(graph_max_async)
 
-    # Process and update all entities and relationships in parallel
-    log_message = f"Processing: {total_entities_count} entities and {total_relations_count} relations (async: {graph_max_async})"
+    # ===== Phase 1: Process all entities concurrently =====
+    log_message = f"Phase 1: Processing {total_entities_count} entities (async: {graph_max_async})"
     logger.info(log_message)
     async with pipeline_status_lock:
         pipeline_status["latest_message"] = log_message
@@ -1263,18 +1283,53 @@
                 await entity_vdb.upsert(data_for_vdb)
             return entity_data
 
+    # Create entity processing tasks
+    entity_tasks = []
+    for entity_name, entities in all_nodes.items():
+        task = asyncio.create_task(_locked_process_entity_name(entity_name, entities))
+        entity_tasks.append(task)
+
+    # Execute entity tasks with error handling
+    processed_entities = []
+    if entity_tasks:
+        done, pending = await asyncio.wait(
+            entity_tasks, return_when=asyncio.FIRST_EXCEPTION
+        )
+
+        # Check if any task raised an exception
+        for task in done:
+            if task.exception():
+                # If a task failed, cancel all pending tasks
+                for pending_task in pending:
+                    pending_task.cancel()
+                # Wait for cancellation to complete
+                if pending:
+                    await asyncio.wait(pending)
+                # Re-raise the exception to notify the caller
+                raise task.exception()
+
+        # If all tasks completed successfully, collect results
+        processed_entities = [task.result() for task in entity_tasks]
+
+    # ===== Phase 2: Process all relationships concurrently =====
+    log_message = f"Phase 2: Processing {total_relations_count} relations (async: {graph_max_async})"
+    logger.info(log_message)
+    async with pipeline_status_lock:
+        pipeline_status["latest_message"] = log_message
+        pipeline_status["history_messages"].append(log_message)
+
     async def _locked_process_edges(edge_key, edges):
         async with semaphore:
             workspace = global_config.get("workspace", "")
             namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
-            # Sort the edge_key components to ensure consistent lock key generation
             sorted_edge_key = sorted([edge_key[0], edge_key[1]])
-            # logger.info(f"Processing edge: {sorted_edge_key[0]} - {sorted_edge_key[1]}")
+
             async with get_storage_keyed_lock(
                 sorted_edge_key,
                 namespace=namespace,
                 enable_logging=False,
             ):
+                added_entities = []  # Track entities added during edge processing
                 edge_data = await _merge_edges_then_upsert(
                     edge_key[0],
                     edge_key[1],
@@ -1284,9 +1339,11 @@
                     pipeline_status,
                     pipeline_status_lock,
                     llm_response_cache,
+                    added_entities,  # Pass list to collect added entities
                 )
+
                 if edge_data is None:
-                    return None
+                    return None, []
 
                 if relationships_vdb is not None:
                     data_for_vdb = {
@@ -1303,50 +1360,106 @@
                         }
                     }
                     await relationships_vdb.upsert(data_for_vdb)
-                return edge_data
+                return edge_data, added_entities
 
-    # Create a single task queue for both entities and edges
-    tasks = []
+    # Create relationship processing tasks
+    edge_tasks = []
+    for edge_key, edges in all_edges.items():
+        task = asyncio.create_task(_locked_process_edges(edge_key, edges))
+        edge_tasks.append(task)
 
-    # Add entity processing tasks
-    for entity_name, entities in all_nodes.items():
-        tasks.append(
-            asyncio.create_task(_locked_process_entity_name(entity_name, entities))
+    # Execute relationship tasks with error handling
+    processed_edges = []
+    all_added_entities = []
+
+    if edge_tasks:
+        done, pending = await asyncio.wait(
+            edge_tasks, return_when=asyncio.FIRST_EXCEPTION
         )
 
-    # Add edge processing tasks
-    for edge_key, edges in all_edges.items():
-        tasks.append(asyncio.create_task(_locked_process_edges(edge_key, edges)))
+        # Check if any task raised an exception
+        for task in done:
+            if task.exception():
+                # If a task failed, cancel all pending tasks
+                for pending_task in pending:
+                    pending_task.cancel()
+                # Wait for cancellation to complete
+                if pending:
+                    await asyncio.wait(pending)
+                # Re-raise the exception to notify the caller
+                raise task.exception()
 
-    # Check if there are any tasks to process
-    if not tasks:
-        log_message = f"No entities or relationships to process for {file_path}"
-        logger.info(log_message)
-        if pipeline_status_lock is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = log_message
-                pipeline_status["history_messages"].append(log_message)
-        return
+        # If all tasks completed successfully, collect results
+        for task in edge_tasks:
+            edge_data, added_entities = task.result()
+            if edge_data is not None:
+                processed_edges.append(edge_data)
+                all_added_entities.extend(added_entities)
 
-    # Execute all tasks in parallel with semaphore control and early failure detection
-    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
+    # ===== Phase 3: Update full_entities and full_relations storage =====
+    if full_entities_storage and full_relations_storage and doc_id:
+        try:
+            # Merge all entities: original entities + entities added during edge processing
+            final_entity_names = set()
 
-    # Check if any task raised an exception
-    for task in done:
-        if task.exception():
-            # If a task failed, cancel all pending tasks
-            for pending_task in pending:
-                pending_task.cancel()
+            # Add original processed entities
+            for entity_data in processed_entities:
+                if entity_data and entity_data.get("entity_name"):
+                    final_entity_names.add(entity_data["entity_name"])
 
-            # Wait for cancellation to complete
-            if pending:
-                await asyncio.wait(pending)
+            # Add entities that were added during relationship processing
+            for added_entity in all_added_entities:
+                if added_entity and added_entity.get("entity_name"):
+                    final_entity_names.add(added_entity["entity_name"])
 
-            # Re-raise the exception to notify the caller
-            raise task.exception()
+            # Collect all relation pairs
+            final_relation_pairs = set()
+            for edge_data in processed_edges:
+                if edge_data:
+                    src_id = edge_data.get("src_id")
+                    tgt_id = edge_data.get("tgt_id")
+                    if src_id and tgt_id:
+                        relation_pair = tuple(sorted([src_id, tgt_id]))
+                        final_relation_pairs.add(relation_pair)
 
-    # If all tasks completed successfully, collect results
-    # (No need to collect results since these tasks don't return values)
+            # Update storage
+            if final_entity_names:
+                await full_entities_storage.upsert(
+                    {
+                        doc_id: {
+                            "entity_names": list(final_entity_names),
+                            "count": len(final_entity_names),
+                        }
+                    }
+                )
+
+            if final_relation_pairs:
+                await full_relations_storage.upsert(
+                    {
+                        doc_id: {
+                            "relation_pairs": [
+                                list(pair) for pair in final_relation_pairs
+                            ],
+                            "count": len(final_relation_pairs),
+                        }
+                    }
+                )
+
+            logger.debug(
+                f"Updated entity-relation index for document {doc_id}: {len(final_entity_names)} entities (original: {len(processed_entities)}, added: {len(all_added_entities)}), {len(final_relation_pairs)} relations"
+            )
+
+        except Exception as e:
+            logger.error(
+                f"Failed to update entity-relation index for document {doc_id}: {e}"
+            )
+            # Don't raise exception to avoid affecting main flow
+
+    log_message = f"Completed merging: {len(processed_entities)} entities, {len(all_added_entities)} added entities, {len(processed_edges)} relations"
+    logger.info(log_message)
+    async with pipeline_status_lock:
+        pipeline_status["latest_message"] = log_message
+        pipeline_status["history_messages"].append(log_message)
 
 
 async def extract_entities(
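# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the diff above). A minimal example of
# how the new BaseGraphStorage.get_all_nodes() / get_all_edges() methods
# introduced in this change can be used to rebuild a doc_id -> entities and
# relations index, mirroring the migration logic in
# _migrate_entity_relation_data(). The helper name build_doc_index, the `sep`
# parameter, and its default value are assumptions for illustration; in
# LightRAG the separator is the GRAPH_FIELD_SEP constant. `graph` is assumed to
# be any BaseGraphStorage implementation (e.g. NetworkXStorage) and
# `chunk_to_doc` a precomputed chunk_id -> doc_id mapping.
from collections import defaultdict


async def build_doc_index(graph, chunk_to_doc: dict[str, str], sep: str = "<SEP>"):
    """Group graph entities and relation pairs by the document they came from."""
    doc_entities: dict[str, set[str]] = defaultdict(set)
    doc_relations: dict[str, set[tuple[str, str]]] = defaultdict(set)

    # Each node dict carries its properties plus the "id" key added by get_all_nodes()
    for node in await graph.get_all_nodes():
        entity_id = node.get("entity_id") or node.get("id")
        for chunk_id in node.get("source_id", "").split(sep):
            doc_id = chunk_to_doc.get(chunk_id)
            if doc_id and entity_id:
                doc_entities[doc_id].add(entity_id)

    # Each edge dict carries the "source"/"target" keys added by get_all_edges()
    for edge in await graph.get_all_edges():
        src, tgt = edge.get("source"), edge.get("target")
        for chunk_id in edge.get("source_id", "").split(sep):
            doc_id = chunk_to_doc.get(chunk_id)
            if doc_id and src and tgt:
                doc_relations[doc_id].add((src, tgt))

    return doc_entities, doc_relations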