From 9878dc7f51dc7c6dc1a448779d333943dabdbc90 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 5 Jul 2025 00:32:55 +0800
Subject: [PATCH] fix: ensure Milvus collections are loaded before operations

- Resolves "collection not loaded" MilvusException errors
---
 lightrag/kg/milvus_impl.py | 44 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py
index 6cffae88..eecf679a 100644
--- a/lightrag/kg/milvus_impl.py
+++ b/lightrag/kg/milvus_impl.py
@@ -539,6 +539,23 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             )
             raise
 
+    def _ensure_collection_loaded(self):
+        """Ensure the collection is loaded into memory for search operations"""
+        try:
+            # Check if collection exists first
+            if not self._client.has_collection(self.namespace):
+                logger.error(f"Collection {self.namespace} does not exist")
+                raise ValueError(f"Collection {self.namespace} does not exist")
+
+            # Load the collection if it's not already loaded
+            # In Milvus, collections need to be loaded before they can be searched
+            self._client.load_collection(self.namespace)
+            logger.debug(f"Collection {self.namespace} loaded successfully")
+
+        except Exception as e:
+            logger.error(f"Failed to load collection {self.namespace}: {e}")
+            raise
+
     def _create_collection_if_not_exist(self):
         """Create collection if not exists and check existing collection compatibility"""
 
@@ -565,6 +582,8 @@ class MilvusVectorDBStorage(BaseVectorStorage):
                     f"Collection '{self.namespace}' confirmed to exist, validating compatibility..."
                 )
                 self._validate_collection_compatibility()
+                # Ensure the collection is loaded after validation
+                self._ensure_collection_loaded()
                 return
             except Exception as describe_error:
                 logger.warning(
@@ -587,6 +606,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             # Then create indexes
             self._create_indexes_after_collection()
 
+            # Load the newly created collection
+            self._ensure_collection_loaded()
+
             logger.info(f"Successfully created Milvus collection: {self.namespace}")
 
         except Exception as e:
@@ -615,6 +637,10 @@ class MilvusVectorDBStorage(BaseVectorStorage):
                 collection_name=self.namespace, schema=schema
             )
             self._create_indexes_after_collection()
+
+            # Load the newly created collection
+            self._ensure_collection_loaded()
+
             logger.info(f"Successfully force-created collection {self.namespace}")
 
         except Exception as create_error:
@@ -670,6 +696,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
         if not data:
             return
 
+        # Ensure collection is loaded before upserting
+        self._ensure_collection_loaded()
+
         import time
 
         current_time = int(time.time())
@@ -700,6 +729,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
     async def query(
         self, query: str, top_k: int, ids: list[str] | None = None
     ) -> list[dict[str, Any]]:
+        # Ensure collection is loaded before querying
+        self._ensure_collection_loaded()
+
         embedding = await self.embedding_func(
             [query], _priority=5
         )  # higher priority for query
@@ -764,6 +796,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             entity_name: The name of the entity whose relations should be deleted
         """
         try:
+            # Ensure collection is loaded before querying
+            self._ensure_collection_loaded()
+
             # Search for relations where entity is either source or target
             expr = f'src_id == "{entity_name}" or tgt_id == "{entity_name}"'
 
@@ -802,6 +837,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             ids: List of vector IDs to be deleted
         """
         try:
+            # Ensure collection is loaded before deleting
+            self._ensure_collection_loaded()
+
             # Delete vectors by IDs
             result = self._client.delete(collection_name=self.namespace, pks=ids)
 
@@ -825,6 +863,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             The vector data if found, or None if not found
         """
         try:
+            # Ensure collection is loaded before querying
+            self._ensure_collection_loaded()
+
             # Include all meta_fields (created_at is now always included) plus id
             output_fields = list(self.meta_fields) + ["id"]
 
@@ -856,6 +897,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             return []
 
         try:
+            # Ensure collection is loaded before querying
+            self._ensure_collection_loaded()
+
             # Include all meta_fields (created_at is now always included) plus id
             output_fields = list(self.meta_fields) + ["id"]
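
For context, the load-before-search requirement this patch addresses can be reproduced outside LightRAG with a few lines of pymilvus. The sketch below is illustrative only: the server URI, collection name, and vector dimension are placeholder assumptions, and ensure_collection_loaded merely mirrors (it is not) the _ensure_collection_loaded method added in the diff above.

# Minimal standalone sketch of the load-before-search pattern, assuming a
# Milvus server at http://localhost:19530 and an existing collection named
# "demo_vectors" with a 4-dimensional float vector field (both hypothetical).
from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530")
collection = "demo_vectors"


def ensure_collection_loaded(client: MilvusClient, name: str) -> None:
    """Fail fast if the collection is missing, then load it into memory."""
    if not client.has_collection(name):
        raise ValueError(f"Collection {name} does not exist")
    # Loading an already-loaded collection is harmless, so this can be
    # called defensively before every search/query/delete operation.
    client.load_collection(name)


ensure_collection_loaded(client, collection)

# Once loaded, search no longer raises the "collection not loaded" MilvusException.
results = client.search(
    collection_name=collection,
    data=[[0.1, 0.2, 0.3, 0.4]],  # query vector matching the hypothetical 4-dim schema
    limit=3,
    output_fields=["id"],
)
print(results)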