From 6201c7fb56bf200c3fef2164cf28888ad8cd85ae Mon Sep 17 00:00:00 2001
From: yangdx
Date: Fri, 4 Jul 2025 21:42:10 +0800
Subject: [PATCH] Refactoring Milvus implementation

---
 lightrag/kg/milvus_impl.py | 652 +++++++++++++++++++++++++++++++++++--
 1 file changed, 619 insertions(+), 33 deletions(-)

diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py
index 70a793f7..6cffae88 100644
--- a/lightrag/kg/milvus_impl.py
+++ b/lightrag/kg/milvus_impl.py
@@ -15,7 +15,7 @@ if not pm.is_installed("pymilvus"):
     pm.install("pymilvus")
 
 import configparser
-from pymilvus import MilvusClient  # type: ignore
+from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema  # type: ignore
 
 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")
@@ -24,16 +24,605 @@ config.read("config.ini", "utf-8")
 
 @final
 @dataclass
 class MilvusVectorDBStorage(BaseVectorStorage):
-    @staticmethod
-    def create_collection_if_not_exist(
-        client: MilvusClient, collection_name: str, **kwargs
-    ):
-        if client.has_collection(collection_name):
-            return
-        client.create_collection(
-            collection_name, max_length=64, id_type="string", **kwargs
+    def _create_schema_for_namespace(self) -> CollectionSchema:
+        """Create schema based on the current instance's namespace"""
+
+        # Get vector dimension from embedding_func
+        dimension = self.embedding_func.embedding_dim
+
+        # Base fields (common to all collections)
+        base_fields = [
+            FieldSchema(
+                name="id", dtype=DataType.VARCHAR, max_length=64, is_primary=True
+            ),
+            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
+            FieldSchema(name="created_at", dtype=DataType.INT64),
+        ]
+
+        # Determine specific fields based on namespace
+        if "entities" in self.namespace.lower():
+            specific_fields = [
+                FieldSchema(
+                    name="entity_name",
+                    dtype=DataType.VARCHAR,
+                    max_length=256,
+                    nullable=True,
+                ),
+                FieldSchema(
+                    name="entity_type",
+                    dtype=DataType.VARCHAR,
+                    max_length=64,
+                    nullable=True,
+                ),
+                FieldSchema(
+                    name="file_path",
+                    dtype=DataType.VARCHAR,
+                    max_length=512,
+                    nullable=True,
+                ),
+            ]
+            description = "LightRAG entities vector storage"
+
+        elif "relationships" in self.namespace.lower():
+            specific_fields = [
+                FieldSchema(
+                    name="src_id", dtype=DataType.VARCHAR, max_length=256, nullable=True
+                ),
+                FieldSchema(
+                    name="tgt_id", dtype=DataType.VARCHAR, max_length=256, nullable=True
+                ),
+                FieldSchema(name="weight", dtype=DataType.DOUBLE, nullable=True),
+                FieldSchema(
+                    name="file_path",
+                    dtype=DataType.VARCHAR,
+                    max_length=512,
+                    nullable=True,
+                ),
+            ]
+            description = "LightRAG relationships vector storage"
+
+        elif "chunks" in self.namespace.lower():
+            specific_fields = [
+                FieldSchema(
+                    name="full_doc_id",
+                    dtype=DataType.VARCHAR,
+                    max_length=64,
+                    nullable=True,
+                ),
+                FieldSchema(
+                    name="file_path",
+                    dtype=DataType.VARCHAR,
+                    max_length=512,
+                    nullable=True,
+                ),
+            ]
+            description = "LightRAG chunks vector storage"
+
+        else:
+            # Default generic schema (backward compatibility)
+            specific_fields = [
+                FieldSchema(
+                    name="file_path",
+                    dtype=DataType.VARCHAR,
+                    max_length=512,
+                    nullable=True,
+                ),
+            ]
+            description = "LightRAG generic vector storage"
+
+        # Merge all fields
+        all_fields = base_fields + specific_fields
+
+        return CollectionSchema(
+            fields=all_fields,
+            description=description,
+            enable_dynamic_field=True,  # Support dynamic fields
         )
+
+    def _get_index_params(self):
+        """Get IndexParams in a version-compatible way"""
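+        # The home of IndexParams has moved between pymilvus releases, so probe
+        # the client helper first, then several known import locations.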
+        try:
+            # Try to use client's prepare_index_params method (most common)
+            if hasattr(self._client, "prepare_index_params"):
+                return self._client.prepare_index_params()
+        except Exception:
+            pass
+
+        try:
+            # Try to import IndexParams from different possible locations
+            from pymilvus.client.prepare import IndexParams
+
+            return IndexParams()
+        except ImportError:
+            pass
+
+        try:
+            from pymilvus.client.types import IndexParams
+
+            return IndexParams()
+        except ImportError:
+            pass
+
+        try:
+            from pymilvus import IndexParams
+
+            return IndexParams()
+        except ImportError:
+            pass
+
+        # If all else fails, return None to use fallback method
+        return None
+
+    def _create_vector_index_fallback(self):
+        """Fallback method to create vector index using direct API"""
+        try:
+            self._client.create_index(
+                collection_name=self.namespace,
+                field_name="vector",
+                index_params={
+                    "index_type": "HNSW",
+                    "metric_type": "COSINE",
+                    "params": {"M": 16, "efConstruction": 256},
+                },
+            )
+            logger.debug("Created vector index using fallback method")
+        except Exception as e:
+            logger.warning(f"Failed to create vector index using fallback method: {e}")
+
+    def _create_scalar_index_fallback(self, field_name: str, index_type: str):
+        """Fallback method to create scalar index using direct API"""
+        # Skip unsupported index types
+        if index_type == "SORTED":
+            logger.info(
+                f"Skipping SORTED index for {field_name} (not supported in this Milvus version)"
+            )
+            return
+
+        try:
+            self._client.create_index(
+                collection_name=self.namespace,
+                field_name=field_name,
+                index_params={"index_type": index_type},
+            )
+            logger.debug(f"Created {field_name} index using fallback method")
+        except Exception as e:
+            logger.info(
+                f"Could not create {field_name} index using fallback method: {e}"
+            )
+
+    def _create_indexes_after_collection(self):
+        """Create indexes after collection is created"""
+        try:
+            # Try to get IndexParams in a version-compatible way
+            IndexParamsClass = self._get_index_params()
+
+            if IndexParamsClass is not None:
+                # Use IndexParams approach if available
+                try:
+                    # Create vector index first (required for most operations)
+                    vector_index = IndexParamsClass
+                    vector_index.add_index(
+                        field_name="vector",
+                        index_type="HNSW",
+                        metric_type="COSINE",
+                        params={"M": 16, "efConstruction": 256},
+                    )
+                    self._client.create_index(
+                        collection_name=self.namespace, index_params=vector_index
+                    )
+                    logger.debug("Created vector index using IndexParams")
+                except Exception as e:
+                    logger.debug(f"IndexParams method failed for vector index: {e}")
+                    self._create_vector_index_fallback()
+
+                # Create scalar indexes based on namespace
+                if "entities" in self.namespace.lower():
+                    # Create indexes for entity fields
+                    try:
+                        entity_name_index = self._get_index_params()
+                        entity_name_index.add_index(
+                            field_name="entity_name", index_type="INVERTED"
+                        )
+                        self._client.create_index(
+                            collection_name=self.namespace,
+                            index_params=entity_name_index,
+                        )
+                    except Exception as e:
+                        logger.debug(f"IndexParams method failed for entity_name: {e}")
+                        self._create_scalar_index_fallback("entity_name", "INVERTED")
+
+                    try:
+                        entity_type_index = self._get_index_params()
+                        entity_type_index.add_index(
+                            field_name="entity_type", index_type="INVERTED"
+                        )
+                        self._client.create_index(
+                            collection_name=self.namespace,
+                            index_params=entity_type_index,
+                        )
+                    except Exception as e:
+                        logger.debug(f"IndexParams method failed for entity_type: {e}")
+                        self._create_scalar_index_fallback("entity_type", "INVERTED")
+
+                elif "relationships" in self.namespace.lower():
+                    # Create indexes for relationship fields
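+                    # INVERTED indexes accelerate scalar filtering on src_id and
+                    # tgt_id (e.g. expressions like 'src_id == "..."')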
+                    try:
+                        src_id_index = self._get_index_params()
+                        src_id_index.add_index(
+                            field_name="src_id", index_type="INVERTED"
+                        )
+                        self._client.create_index(
+                            collection_name=self.namespace, index_params=src_id_index
+                        )
+                    except Exception as e:
+                        logger.debug(f"IndexParams method failed for src_id: {e}")
+                        self._create_scalar_index_fallback("src_id", "INVERTED")
+
+                    try:
+                        tgt_id_index = self._get_index_params()
+                        tgt_id_index.add_index(
+                            field_name="tgt_id", index_type="INVERTED"
+                        )
+                        self._client.create_index(
+                            collection_name=self.namespace, index_params=tgt_id_index
+                        )
+                    except Exception as e:
+                        logger.debug(f"IndexParams method failed for tgt_id: {e}")
+                        self._create_scalar_index_fallback("tgt_id", "INVERTED")
+
+                elif "chunks" in self.namespace.lower():
+                    # Create indexes for chunk fields
+                    try:
+                        doc_id_index = self._get_index_params()
+                        doc_id_index.add_index(
+                            field_name="full_doc_id", index_type="INVERTED"
+                        )
+                        self._client.create_index(
+                            collection_name=self.namespace, index_params=doc_id_index
+                        )
+                    except Exception as e:
+                        logger.debug(f"IndexParams method failed for full_doc_id: {e}")
+                        self._create_scalar_index_fallback("full_doc_id", "INVERTED")
+
+                # No common indexes needed
+
+            else:
+                # Fallback to direct API calls if IndexParams is not available
+                logger.info(
+                    f"IndexParams not available, using fallback methods for {self.namespace}"
+                )
+
+                # Create vector index using fallback
+                self._create_vector_index_fallback()
+
+                # Create scalar indexes using fallback
+                if "entities" in self.namespace.lower():
+                    self._create_scalar_index_fallback("entity_name", "INVERTED")
+                    self._create_scalar_index_fallback("entity_type", "INVERTED")
+                elif "relationships" in self.namespace.lower():
+                    self._create_scalar_index_fallback("src_id", "INVERTED")
+                    self._create_scalar_index_fallback("tgt_id", "INVERTED")
+                elif "chunks" in self.namespace.lower():
+                    self._create_scalar_index_fallback("full_doc_id", "INVERTED")
+
+            logger.info(f"Created indexes for collection: {self.namespace}")
+
+        except Exception as e:
+            logger.warning(f"Failed to create some indexes for {self.namespace}: {e}")
+
+    def _get_required_fields_for_namespace(self) -> dict:
+        """Get required core field definitions for current namespace"""
+
+        # Base fields (common to all types)
+        base_fields = {
+            "id": {"type": "VarChar", "is_primary": True},
+            "vector": {"type": "FloatVector"},
+            "created_at": {"type": "Int64"},
+        }
+
+        # Add specific fields based on namespace
+        if "entities" in self.namespace.lower():
+            specific_fields = {
+                "entity_name": {"type": "VarChar"},
+                "entity_type": {"type": "VarChar"},
+                "file_path": {"type": "VarChar"},
+            }
+        elif "relationships" in self.namespace.lower():
+            specific_fields = {
+                "src_id": {"type": "VarChar"},
+                "tgt_id": {"type": "VarChar"},
+                "weight": {"type": "Double"},
+                "file_path": {"type": "VarChar"},
+            }
+        elif "chunks" in self.namespace.lower():
+            specific_fields = {
+                "full_doc_id": {"type": "VarChar"},
+                "file_path": {"type": "VarChar"},
+            }
+        else:
+            specific_fields = {
+                "file_path": {"type": "VarChar"},
+            }
+
+        return {**base_fields, **specific_fields}
+
+    def _is_field_compatible(self, existing_field: dict, expected_config: dict) -> bool:
+        """Check compatibility of a single field"""
+        field_name = existing_field.get("name", "unknown")
+        existing_type = existing_field.get("type")
+        expected_type = expected_config.get("type")
+
+        logger.debug(
+            f"Checking field '{field_name}': existing_type={existing_type} (type={type(existing_type)}), expected_type={expected_type}"
+        )
+
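+        # describe_collection may report a field type as a DataType enum, a raw
+        # integer code, or a string depending on the pymilvus version, so all
+        # three forms are normalized before comparison.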
+        # Convert DataType enum values to string names if needed
+        original_existing_type = existing_type
+        if hasattr(existing_type, "name"):
+            existing_type = existing_type.name
+            logger.debug(
+                f"Converted enum to name: {original_existing_type} -> {existing_type}"
+            )
+        elif isinstance(existing_type, int):
+            # Map common Milvus internal type codes to type names for backward compatibility
+            type_mapping = {
+                21: "VarChar",
+                101: "FloatVector",
+                5: "Int64",
+                11: "Double",
+            }
+            mapped_type = type_mapping.get(existing_type, str(existing_type))
+            logger.debug(f"Mapped numeric type: {existing_type} -> {mapped_type}")
+            existing_type = mapped_type
+
+        # Normalize type names for comparison
+        type_aliases = {
+            "VARCHAR": "VarChar",
+            "String": "VarChar",
+            "FLOAT_VECTOR": "FloatVector",
+            "INT64": "Int64",
+            "BigInt": "Int64",
+            "DOUBLE": "Double",
+            "Float": "Double",
+        }
+
+        original_existing = existing_type
+        original_expected = expected_type
+        existing_type = type_aliases.get(existing_type, existing_type)
+        expected_type = type_aliases.get(expected_type, expected_type)
+
+        if original_existing != existing_type or original_expected != expected_type:
+            logger.debug(
+                f"Applied aliases: {original_existing} -> {existing_type}, {original_expected} -> {expected_type}"
+            )
+
+        # Basic type compatibility check
+        type_compatible = existing_type == expected_type
+        logger.debug(
+            f"Type compatibility for '{field_name}': {existing_type} == {expected_type} -> {type_compatible}"
+        )
+
+        if not type_compatible:
+            logger.warning(
+                f"Type mismatch for field '{field_name}': expected {expected_type}, got {existing_type}"
+            )
+            return False
+
+        # Primary key check - be more flexible about primary key detection
+        if expected_config.get("is_primary"):
+            # Check multiple possible field names for primary key status
+            is_primary = (
+                existing_field.get("is_primary_key", False)
+                or existing_field.get("is_primary", False)
+                or existing_field.get("primary_key", False)
+            )
+            logger.debug(
+                f"Primary key check for '{field_name}': expected=True, actual={is_primary}"
+            )
+            logger.debug(f"Raw field data for '{field_name}': {existing_field}")
+
+            # For ID field, be more lenient - if it's the ID field, assume it should be primary
+            if field_name == "id" and not is_primary:
+                logger.info(
+                    f"ID field '{field_name}' not marked as primary in existing collection, but treating as compatible"
+                )
+                # Don't fail for ID field primary key mismatch
+            elif not is_primary:
+                logger.warning(
+                    f"Primary key mismatch for field '{field_name}': expected primary key, but field is not primary"
+                )
+                return False
+
+        logger.debug(f"Field '{field_name}' is compatible")
+        return True
+
+    def _check_vector_dimension(self, collection_info: dict):
+        """Check vector dimension compatibility"""
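+        # A collection's vector dimension is fixed when it is created, so a
+        # mismatch here usually means the embedding model has changed.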
+        current_dimension = self.embedding_func.embedding_dim
+
+        # Find vector field dimension
+        for field in collection_info.get("fields", []):
+            if field.get("name") == "vector":
+                field_type = field.get("type")
+                if field_type in ["FloatVector", "FLOAT_VECTOR"]:
+                    existing_dimension = field.get("params", {}).get("dim")
+
+                    if existing_dimension != current_dimension:
+                        raise ValueError(
+                            f"Vector dimension mismatch for collection '{self.namespace}': "
+                            f"existing={existing_dimension}, current={current_dimension}"
+                        )
+
+                    logger.debug(f"Vector dimension check passed: {current_dimension}")
+                    return
+
+        # If no vector field found, this might be an old collection created with simple schema
+        logger.warning(
+            f"Vector field not found in collection '{self.namespace}'. "
+            "This might be an old collection created with simple schema."
+        )
+        logger.warning("Consider recreating the collection for optimal performance.")
+        return
+
+    def _check_schema_compatibility(self, collection_info: dict):
+        """Check schema field compatibility"""
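+        # Lenient policy: only the primary "id" field is treated as critical;
+        # missing namespace-specific fields are logged but do not fail validation.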
+        existing_fields = {
+            field["name"]: field for field in collection_info.get("fields", [])
+        }
+
+        # Check if this is an old collection created with simple schema
+        has_vector_field = any(
+            field.get("name") == "vector"
+            for field in collection_info.get("fields", [])
+        )
+
+        if not has_vector_field:
+            logger.warning(
+                f"Collection {self.namespace} appears to be created with old simple schema (no vector field)"
+            )
+            logger.warning(
+                "This collection will work but may have suboptimal performance"
+            )
+            logger.warning("Consider recreating the collection for optimal performance")
+            return
+
+        # For collections with vector field, check basic compatibility
+        # Only check for critical incompatibilities, not missing optional fields
+        critical_fields = {"id": {"type": "VarChar", "is_primary": True}}
+
+        incompatible_fields = []
+
+        for field_name, expected_config in critical_fields.items():
+            if field_name in existing_fields:
+                existing_field = existing_fields[field_name]
+                if not self._is_field_compatible(existing_field, expected_config):
+                    incompatible_fields.append(
+                        f"{field_name}: expected {expected_config['type']}, "
+                        f"got {existing_field.get('type')}"
+                    )
+
+        if incompatible_fields:
+            raise ValueError(
+                f"Critical schema incompatibility in collection '{self.namespace}': {incompatible_fields}"
+            )
+
+        # Get all expected fields for informational purposes
+        expected_fields = self._get_required_fields_for_namespace()
+        missing_fields = [
+            field for field in expected_fields if field not in existing_fields
+        ]
+
+        if missing_fields:
+            logger.info(
+                f"Collection {self.namespace} missing optional fields: {missing_fields}"
+            )
+            logger.info(
+                "These fields would be available in a newly created collection for better performance"
+            )
+
+        logger.debug(f"Schema compatibility check passed for {self.namespace}")
+
+    def _validate_collection_compatibility(self):
+        """Validate existing collection's dimension and schema compatibility"""
+        try:
+            collection_info = self._client.describe_collection(self.namespace)
+
+            # 1. Check vector dimension
+            self._check_vector_dimension(collection_info)
+
+            # 2. Check schema compatibility
+            self._check_schema_compatibility(collection_info)
+
+            logger.info(f"Collection {self.namespace} compatibility validation passed")
+
+        except Exception as e:
+            logger.error(
+                f"Collection compatibility validation failed for {self.namespace}: {e}"
+            )
+            raise
+
+    def _create_collection_if_not_exist(self):
+        """Create collection if not exists and check existing collection compatibility"""
+
+        try:
+            # First, list all collections to see what actually exists
+            try:
+                all_collections = self._client.list_collections()
+                logger.debug(f"All collections in database: {all_collections}")
+            except Exception as list_error:
+                logger.warning(f"Could not list collections: {list_error}")
+                all_collections = []
+
+            # Check if our specific collection exists
+            collection_exists = self._client.has_collection(self.namespace)
+            logger.info(
+                f"Collection '{self.namespace}' exists check: {collection_exists}"
+            )
+
+            if collection_exists:
+                # Double-check by trying to describe the collection
+                try:
+                    self._client.describe_collection(self.namespace)
+                    logger.info(
+                        f"Collection '{self.namespace}' confirmed to exist, validating compatibility..."
+                    )
+                    self._validate_collection_compatibility()
+                    return
+                except Exception as describe_error:
+                    logger.warning(
+                        f"Collection '{self.namespace}' exists but cannot be described: {describe_error}"
+                    )
+                    logger.info(
+                        "Treating as if collection doesn't exist and creating new one..."
+                    )
+                    # Fall through to creation logic
+
+            # Collection doesn't exist, create new collection
+            logger.info(f"Creating new collection: {self.namespace}")
+            schema = self._create_schema_for_namespace()
+
+            # Create collection with schema only first
+            self._client.create_collection(
+                collection_name=self.namespace, schema=schema
+            )
+
+            # Then create indexes
+            self._create_indexes_after_collection()
+
+            logger.info(f"Successfully created Milvus collection: {self.namespace}")
+
+        except Exception as e:
+            logger.error(
+                f"Error in _create_collection_if_not_exist for {self.namespace}: {e}"
+            )
+
+            # If there's any error, try to force create the collection
+            logger.info(f"Attempting to force create collection {self.namespace}...")
+            try:
+                # Try to drop the collection first if it exists in a bad state
+                try:
+                    if self._client.has_collection(self.namespace):
+                        logger.info(
+                            f"Dropping potentially corrupted collection {self.namespace}"
+                        )
+                        self._client.drop_collection(self.namespace)
+                except Exception as drop_error:
+                    logger.warning(
+                        f"Could not drop collection {self.namespace}: {drop_error}"
+                    )
+
+                # Create fresh collection
+                schema = self._create_schema_for_namespace()
+                self._client.create_collection(
+                    collection_name=self.namespace, schema=schema
+                )
+                self._create_indexes_after_collection()
+                logger.info(f"Successfully force-created collection {self.namespace}")
+
+            except Exception as create_error:
+                logger.error(
+                    f"Failed to force-create collection {self.namespace}: {create_error}"
+                )
+                raise
+
     def __post_init__(self):
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
         cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -43,6 +632,10 @@ class MilvusVectorDBStorage(BaseVectorStorage):
         )
         self.cosine_better_than_threshold = cosine_threshold
 
+        # Ensure created_at is in meta_fields
+        if "created_at" not in self.meta_fields:
+            self.meta_fields.add("created_at")
+
         self._client = MilvusClient(
             uri=os.environ.get(
                 "MILVUS_URI",
@@ -68,11 +661,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             ),
         )
         self._max_batch_size = self.global_config["embedding_batch_num"]
-        MilvusVectorDBStorage.create_collection_if_not_exist(
-            self._client,
-            self.namespace,
-            dimension=self.embedding_func.embedding_dim,
-        )
+
+        # Create collection and check compatibility
+        self._create_collection_if_not_exist()
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.debug(f"Inserting {len(data)} to {self.namespace}")
@@ -112,23 +703,25 @@ class MilvusVectorDBStorage(BaseVectorStorage):
         embedding = await self.embedding_func(
             [query], _priority=5
         )  # higher priority for query
+
+        # Include all meta_fields (created_at is now always included)
+        output_fields = list(self.meta_fields)
+
         results = self._client.search(
             collection_name=self.namespace,
             data=embedding,
             limit=top_k,
-            output_fields=list(self.meta_fields) + ["created_at"],
+            output_fields=output_fields,
             search_params={
                 "metric_type": "COSINE",
                 "params": {"radius": self.cosine_better_than_threshold},
             },
         )
-        print(results)
         return [
             {
                 **dp["entity"],
                 "id": dp["id"],
                 "distance": dp["distance"],
-                # created_at is requested in output_fields, so it should be a top-level key in the result dict (dp)
                 "created_at": dp.get("created_at"),
             }
             for dp in results[0]
@@ -232,20 +825,19 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             The vector data if found, or None if not found
         """
         try:
+            # Include all meta_fields (created_at is now always included) plus id
+            output_fields = list(self.meta_fields) + ["id"]
+
             # Query Milvus for a specific ID
             result = self._client.query(
                 collection_name=self.namespace,
                 filter=f'id == "{id}"',
-                output_fields=list(self.meta_fields) + ["id", "created_at"],
+                output_fields=output_fields,
             )
 
             if not result or len(result) == 0:
                 return None
 
-            # Ensure the result contains created_at field
-            if "created_at" not in result[0]:
-                result[0]["created_at"] = None
-
             return result[0]
         except Exception as e:
             logger.error(f"Error retrieving vector data for ID {id}: {e}")
@@ -264,6 +856,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             return []
 
         try:
+            # Include all meta_fields (created_at is now always included) plus id
+            output_fields = list(self.meta_fields) + ["id"]
+
             # Prepare the ID filter expression
             id_list = '", "'.join(ids)
             filter_expr = f'id in ["{id_list}"]'
@@ -272,14 +867,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             result = self._client.query(
                 collection_name=self.namespace,
                 filter=filter_expr,
-                output_fields=list(self.meta_fields) + ["id", "created_at"],
+                output_fields=output_fields,
             )
 
-            # Ensure each result contains created_at field
-            for item in result:
-                if "created_at" not in item:
-                    item["created_at"] = None
-
             return result or []
         except Exception as e:
             logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
@@ -301,11 +891,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
             self._client.drop_collection(self.namespace)
 
             # Recreate the collection
-            MilvusVectorDBStorage.create_collection_if_not_exist(
-                self._client,
-                self.namespace,
-                dimension=self.embedding_func.embedding_dim,
-            )
+            self._create_collection_if_not_exist()
 
             logger.info(
                 f"Process {os.getpid()} drop Milvus collection {self.namespace}"