Refactoring Milvus implementation
This commit is contained in:
parent
cb15883130
commit
6201c7fb56
1 changed files with 619 additions and 33 deletions
|
|
@ -15,7 +15,7 @@ if not pm.is_installed("pymilvus"):
|
||||||
pm.install("pymilvus")
|
pm.install("pymilvus")
|
||||||
|
|
||||||
import configparser
|
import configparser
|
||||||
from pymilvus import MilvusClient # type: ignore
|
from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema # type: ignore
|
||||||
|
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
config.read("config.ini", "utf-8")
|
config.read("config.ini", "utf-8")
|
||||||
|
|
@ -24,16 +24,605 @@ config.read("config.ini", "utf-8")
|
||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class MilvusVectorDBStorage(BaseVectorStorage):
|
class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
@staticmethod
|
def _create_schema_for_namespace(self) -> CollectionSchema:
|
||||||
def create_collection_if_not_exist(
|
"""Create schema based on the current instance's namespace"""
|
||||||
client: MilvusClient, collection_name: str, **kwargs
|
|
||||||
):
|
# Get vector dimension from embedding_func
|
||||||
if client.has_collection(collection_name):
|
dimension = self.embedding_func.embedding_dim
|
||||||
return
|
|
||||||
client.create_collection(
|
# Base fields (common to all collections)
|
||||||
collection_name, max_length=64, id_type="string", **kwargs
|
base_fields = [
|
||||||
|
FieldSchema(
|
||||||
|
name="id", dtype=DataType.VARCHAR, max_length=64, is_primary=True
|
||||||
|
),
|
||||||
|
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
|
||||||
|
FieldSchema(name="created_at", dtype=DataType.INT64),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Determine specific fields based on namespace
|
||||||
|
if "entities" in self.namespace.lower():
|
||||||
|
specific_fields = [
|
||||||
|
FieldSchema(
|
||||||
|
name="entity_name",
|
||||||
|
dtype=DataType.VARCHAR,
|
||||||
|
max_length=256,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
FieldSchema(
|
||||||
|
name="entity_type",
|
||||||
|
dtype=DataType.VARCHAR,
|
||||||
|
max_length=64,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
FieldSchema(
|
||||||
|
name="file_path",
|
||||||
|
dtype=DataType.VARCHAR,
|
||||||
|
max_length=512,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
description = "LightRAG entities vector storage"
|
||||||
|
|
||||||
|
elif "relationships" in self.namespace.lower():
|
||||||
|
specific_fields = [
|
||||||
|
FieldSchema(
|
||||||
|
name="src_id", dtype=DataType.VARCHAR, max_length=256, nullable=True
|
||||||
|
),
|
||||||
|
FieldSchema(
|
||||||
|
name="tgt_id", dtype=DataType.VARCHAR, max_length=256, nullable=True
|
||||||
|
),
|
||||||
|
FieldSchema(name="weight", dtype=DataType.DOUBLE, nullable=True),
|
||||||
|
FieldSchema(
|
||||||
|
name="file_path",
|
||||||
|
dtype=DataType.VARCHAR,
|
||||||
|
max_length=512,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
description = "LightRAG relationships vector storage"
|
||||||
|
|
||||||
|
elif "chunks" in self.namespace.lower():
|
||||||
|
specific_fields = [
|
||||||
|
FieldSchema(
|
||||||
|
name="full_doc_id",
|
||||||
|
dtype=DataType.VARCHAR,
|
||||||
|
max_length=64,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
FieldSchema(
|
||||||
|
name="file_path",
|
||||||
|
dtype=DataType.VARCHAR,
|
||||||
|
max_length=512,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
description = "LightRAG chunks vector storage"
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Default generic schema (backward compatibility)
|
||||||
|
specific_fields = [
|
||||||
|
FieldSchema(
|
||||||
|
name="file_path",
|
||||||
|
dtype=DataType.VARCHAR,
|
||||||
|
max_length=512,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
description = "LightRAG generic vector storage"
|
||||||
|
|
||||||
|
# Merge all fields
|
||||||
|
all_fields = base_fields + specific_fields
|
||||||
|
|
||||||
|
return CollectionSchema(
|
||||||
|
fields=all_fields,
|
||||||
|
description=description,
|
||||||
|
enable_dynamic_field=True, # Support dynamic fields
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_index_params(self):
|
||||||
|
"""Get IndexParams in a version-compatible way"""
|
||||||
|
try:
|
||||||
|
# Try to use client's prepare_index_params method (most common)
|
||||||
|
if hasattr(self._client, "prepare_index_params"):
|
||||||
|
return self._client.prepare_index_params()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try to import IndexParams from different possible locations
|
||||||
|
from pymilvus.client.prepare import IndexParams
|
||||||
|
|
||||||
|
return IndexParams()
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pymilvus.client.types import IndexParams
|
||||||
|
|
||||||
|
return IndexParams()
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pymilvus import IndexParams
|
||||||
|
|
||||||
|
return IndexParams()
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# If all else fails, return None to use fallback method
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _create_vector_index_fallback(self):
|
||||||
|
"""Fallback method to create vector index using direct API"""
|
||||||
|
try:
|
||||||
|
self._client.create_index(
|
||||||
|
collection_name=self.namespace,
|
||||||
|
field_name="vector",
|
||||||
|
index_params={
|
||||||
|
"index_type": "HNSW",
|
||||||
|
"metric_type": "COSINE",
|
||||||
|
"params": {"M": 16, "efConstruction": 256},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logger.debug("Created vector index using fallback method")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to create vector index using fallback method: {e}")
|
||||||
|
|
||||||
|
def _create_scalar_index_fallback(self, field_name: str, index_type: str):
|
||||||
|
"""Fallback method to create scalar index using direct API"""
|
||||||
|
# Skip unsupported index types
|
||||||
|
if index_type == "SORTED":
|
||||||
|
logger.info(
|
||||||
|
f"Skipping SORTED index for {field_name} (not supported in this Milvus version)"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._client.create_index(
|
||||||
|
collection_name=self.namespace,
|
||||||
|
field_name=field_name,
|
||||||
|
index_params={"index_type": index_type},
|
||||||
|
)
|
||||||
|
logger.debug(f"Created {field_name} index using fallback method")
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(
|
||||||
|
f"Could not create {field_name} index using fallback method: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_indexes_after_collection(self):
|
||||||
|
"""Create indexes after collection is created"""
|
||||||
|
try:
|
||||||
|
# Try to get IndexParams in a version-compatible way
|
||||||
|
IndexParamsClass = self._get_index_params()
|
||||||
|
|
||||||
|
if IndexParamsClass is not None:
|
||||||
|
# Use IndexParams approach if available
|
||||||
|
try:
|
||||||
|
# Create vector index first (required for most operations)
|
||||||
|
vector_index = IndexParamsClass
|
||||||
|
vector_index.add_index(
|
||||||
|
field_name="vector",
|
||||||
|
index_type="HNSW",
|
||||||
|
metric_type="COSINE",
|
||||||
|
params={"M": 16, "efConstruction": 256},
|
||||||
|
)
|
||||||
|
self._client.create_index(
|
||||||
|
collection_name=self.namespace, index_params=vector_index
|
||||||
|
)
|
||||||
|
logger.debug("Created vector index using IndexParams")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"IndexParams method failed for vector index: {e}")
|
||||||
|
self._create_vector_index_fallback()
|
||||||
|
|
||||||
|
# Create scalar indexes based on namespace
|
||||||
|
if "entities" in self.namespace.lower():
|
||||||
|
# Create indexes for entity fields
|
||||||
|
try:
|
||||||
|
entity_name_index = self._get_index_params()
|
||||||
|
entity_name_index.add_index(
|
||||||
|
field_name="entity_name", index_type="INVERTED"
|
||||||
|
)
|
||||||
|
self._client.create_index(
|
||||||
|
collection_name=self.namespace,
|
||||||
|
index_params=entity_name_index,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"IndexParams method failed for entity_name: {e}")
|
||||||
|
self._create_scalar_index_fallback("entity_name", "INVERTED")
|
||||||
|
|
||||||
|
try:
|
||||||
|
entity_type_index = self._get_index_params()
|
||||||
|
entity_type_index.add_index(
|
||||||
|
field_name="entity_type", index_type="INVERTED"
|
||||||
|
)
|
||||||
|
self._client.create_index(
|
||||||
|
collection_name=self.namespace,
|
||||||
|
index_params=entity_type_index,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"IndexParams method failed for entity_type: {e}")
|
||||||
|
self._create_scalar_index_fallback("entity_type", "INVERTED")
|
||||||
|
|
||||||
|
elif "relationships" in self.namespace.lower():
|
||||||
|
# Create indexes for relationship fields
|
||||||
|
try:
|
||||||
|
src_id_index = self._get_index_params()
|
||||||
|
src_id_index.add_index(
|
||||||
|
field_name="src_id", index_type="INVERTED"
|
||||||
|
)
|
||||||
|
self._client.create_index(
|
||||||
|
collection_name=self.namespace, index_params=src_id_index
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"IndexParams method failed for src_id: {e}")
|
||||||
|
self._create_scalar_index_fallback("src_id", "INVERTED")
|
||||||
|
|
||||||
|
try:
|
||||||
|
tgt_id_index = self._get_index_params()
|
||||||
|
tgt_id_index.add_index(
|
||||||
|
field_name="tgt_id", index_type="INVERTED"
|
||||||
|
)
|
||||||
|
self._client.create_index(
|
||||||
|
collection_name=self.namespace, index_params=tgt_id_index
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"IndexParams method failed for tgt_id: {e}")
|
||||||
|
self._create_scalar_index_fallback("tgt_id", "INVERTED")
|
||||||
|
|
||||||
|
elif "chunks" in self.namespace.lower():
|
||||||
|
# Create indexes for chunk fields
|
||||||
|
try:
|
||||||
|
doc_id_index = self._get_index_params()
|
||||||
|
doc_id_index.add_index(
|
||||||
|
field_name="full_doc_id", index_type="INVERTED"
|
||||||
|
)
|
||||||
|
self._client.create_index(
|
||||||
|
collection_name=self.namespace, index_params=doc_id_index
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"IndexParams method failed for full_doc_id: {e}")
|
||||||
|
self._create_scalar_index_fallback("full_doc_id", "INVERTED")
|
||||||
|
|
||||||
|
# No common indexes needed
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Fallback to direct API calls if IndexParams is not available
|
||||||
|
logger.info(
|
||||||
|
f"IndexParams not available, using fallback methods for {self.namespace}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create vector index using fallback
|
||||||
|
self._create_vector_index_fallback()
|
||||||
|
|
||||||
|
# Create scalar indexes using fallback
|
||||||
|
if "entities" in self.namespace.lower():
|
||||||
|
self._create_scalar_index_fallback("entity_name", "INVERTED")
|
||||||
|
self._create_scalar_index_fallback("entity_type", "INVERTED")
|
||||||
|
elif "relationships" in self.namespace.lower():
|
||||||
|
self._create_scalar_index_fallback("src_id", "INVERTED")
|
||||||
|
self._create_scalar_index_fallback("tgt_id", "INVERTED")
|
||||||
|
elif "chunks" in self.namespace.lower():
|
||||||
|
self._create_scalar_index_fallback("full_doc_id", "INVERTED")
|
||||||
|
|
||||||
|
logger.info(f"Created indexes for collection: {self.namespace}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to create some indexes for {self.namespace}: {e}")
|
||||||
|
|
||||||
|
def _get_required_fields_for_namespace(self) -> dict:
|
||||||
|
"""Get required core field definitions for current namespace"""
|
||||||
|
|
||||||
|
# Base fields (common to all types)
|
||||||
|
base_fields = {
|
||||||
|
"id": {"type": "VarChar", "is_primary": True},
|
||||||
|
"vector": {"type": "FloatVector"},
|
||||||
|
"created_at": {"type": "Int64"},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add specific fields based on namespace
|
||||||
|
if "entities" in self.namespace.lower():
|
||||||
|
specific_fields = {
|
||||||
|
"entity_name": {"type": "VarChar"},
|
||||||
|
"entity_type": {"type": "VarChar"},
|
||||||
|
"file_path": {"type": "VarChar"},
|
||||||
|
}
|
||||||
|
elif "relationships" in self.namespace.lower():
|
||||||
|
specific_fields = {
|
||||||
|
"src_id": {"type": "VarChar"},
|
||||||
|
"tgt_id": {"type": "VarChar"},
|
||||||
|
"weight": {"type": "Double"},
|
||||||
|
"file_path": {"type": "VarChar"},
|
||||||
|
}
|
||||||
|
elif "chunks" in self.namespace.lower():
|
||||||
|
specific_fields = {
|
||||||
|
"full_doc_id": {"type": "VarChar"},
|
||||||
|
"file_path": {"type": "VarChar"},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
specific_fields = {
|
||||||
|
"file_path": {"type": "VarChar"},
|
||||||
|
}
|
||||||
|
|
||||||
|
return {**base_fields, **specific_fields}
|
||||||
|
|
||||||
|
def _is_field_compatible(self, existing_field: dict, expected_config: dict) -> bool:
|
||||||
|
"""Check compatibility of a single field"""
|
||||||
|
field_name = existing_field.get("name", "unknown")
|
||||||
|
existing_type = existing_field.get("type")
|
||||||
|
expected_type = expected_config.get("type")
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Checking field '{field_name}': existing_type={existing_type} (type={type(existing_type)}), expected_type={expected_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert DataType enum values to string names if needed
|
||||||
|
original_existing_type = existing_type
|
||||||
|
if hasattr(existing_type, "name"):
|
||||||
|
existing_type = existing_type.name
|
||||||
|
logger.debug(
|
||||||
|
f"Converted enum to name: {original_existing_type} -> {existing_type}"
|
||||||
|
)
|
||||||
|
elif isinstance(existing_type, int):
|
||||||
|
# Map common Milvus internal type codes to type names for backward compatibility
|
||||||
|
type_mapping = {
|
||||||
|
21: "VarChar",
|
||||||
|
101: "FloatVector",
|
||||||
|
5: "Int64",
|
||||||
|
9: "Double",
|
||||||
|
}
|
||||||
|
mapped_type = type_mapping.get(existing_type, str(existing_type))
|
||||||
|
logger.debug(f"Mapped numeric type: {existing_type} -> {mapped_type}")
|
||||||
|
existing_type = mapped_type
|
||||||
|
|
||||||
|
# Normalize type names for comparison
|
||||||
|
type_aliases = {
|
||||||
|
"VARCHAR": "VarChar",
|
||||||
|
"String": "VarChar",
|
||||||
|
"FLOAT_VECTOR": "FloatVector",
|
||||||
|
"INT64": "Int64",
|
||||||
|
"BigInt": "Int64",
|
||||||
|
"DOUBLE": "Double",
|
||||||
|
"Float": "Double",
|
||||||
|
}
|
||||||
|
|
||||||
|
original_existing = existing_type
|
||||||
|
original_expected = expected_type
|
||||||
|
existing_type = type_aliases.get(existing_type, existing_type)
|
||||||
|
expected_type = type_aliases.get(expected_type, expected_type)
|
||||||
|
|
||||||
|
if original_existing != existing_type or original_expected != expected_type:
|
||||||
|
logger.debug(
|
||||||
|
f"Applied aliases: {original_existing} -> {existing_type}, {original_expected} -> {expected_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Basic type compatibility check
|
||||||
|
type_compatible = existing_type == expected_type
|
||||||
|
logger.debug(
|
||||||
|
f"Type compatibility for '{field_name}': {existing_type} == {expected_type} -> {type_compatible}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not type_compatible:
|
||||||
|
logger.warning(
|
||||||
|
f"Type mismatch for field '{field_name}': expected {expected_type}, got {existing_type}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Primary key check - be more flexible about primary key detection
|
||||||
|
if expected_config.get("is_primary"):
|
||||||
|
# Check multiple possible field names for primary key status
|
||||||
|
is_primary = (
|
||||||
|
existing_field.get("is_primary_key", False)
|
||||||
|
or existing_field.get("is_primary", False)
|
||||||
|
or existing_field.get("primary_key", False)
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"Primary key check for '{field_name}': expected=True, actual={is_primary}"
|
||||||
|
)
|
||||||
|
logger.debug(f"Raw field data for '{field_name}': {existing_field}")
|
||||||
|
|
||||||
|
# For ID field, be more lenient - if it's the ID field, assume it should be primary
|
||||||
|
if field_name == "id" and not is_primary:
|
||||||
|
logger.info(
|
||||||
|
f"ID field '{field_name}' not marked as primary in existing collection, but treating as compatible"
|
||||||
|
)
|
||||||
|
# Don't fail for ID field primary key mismatch
|
||||||
|
elif not is_primary:
|
||||||
|
logger.warning(
|
||||||
|
f"Primary key mismatch for field '{field_name}': expected primary key, but field is not primary"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.debug(f"Field '{field_name}' is compatible")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _check_vector_dimension(self, collection_info: dict):
|
||||||
|
"""Check vector dimension compatibility"""
|
||||||
|
current_dimension = self.embedding_func.embedding_dim
|
||||||
|
|
||||||
|
# Find vector field dimension
|
||||||
|
for field in collection_info.get("fields", []):
|
||||||
|
if field.get("name") == "vector":
|
||||||
|
field_type = field.get("type")
|
||||||
|
if field_type in ["FloatVector", "FLOAT_VECTOR"]:
|
||||||
|
existing_dimension = field.get("params", {}).get("dim")
|
||||||
|
|
||||||
|
if existing_dimension != current_dimension:
|
||||||
|
raise ValueError(
|
||||||
|
f"Vector dimension mismatch for collection '{self.namespace}': "
|
||||||
|
f"existing={existing_dimension}, current={current_dimension}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Vector dimension check passed: {current_dimension}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# If no vector field found, this might be an old collection created with simple schema
|
||||||
|
logger.warning(
|
||||||
|
f"Vector field not found in collection '{self.namespace}'. This might be an old collection created with simple schema."
|
||||||
|
)
|
||||||
|
logger.warning("Consider recreating the collection for optimal performance.")
|
||||||
|
return
|
||||||
|
|
||||||
|
def _check_schema_compatibility(self, collection_info: dict):
|
||||||
|
"""Check schema field compatibility"""
|
||||||
|
existing_fields = {
|
||||||
|
field["name"]: field for field in collection_info.get("fields", [])
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if this is an old collection created with simple schema
|
||||||
|
has_vector_field = any(
|
||||||
|
field.get("name") == "vector" for field in collection_info.get("fields", [])
|
||||||
|
)
|
||||||
|
|
||||||
|
if not has_vector_field:
|
||||||
|
logger.warning(
|
||||||
|
f"Collection {self.namespace} appears to be created with old simple schema (no vector field)"
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
"This collection will work but may have suboptimal performance"
|
||||||
|
)
|
||||||
|
logger.warning("Consider recreating the collection for optimal performance")
|
||||||
|
return
|
||||||
|
|
||||||
|
# For collections with vector field, check basic compatibility
|
||||||
|
# Only check for critical incompatibilities, not missing optional fields
|
||||||
|
critical_fields = {"id": {"type": "VarChar", "is_primary": True}}
|
||||||
|
|
||||||
|
incompatible_fields = []
|
||||||
|
|
||||||
|
for field_name, expected_config in critical_fields.items():
|
||||||
|
if field_name in existing_fields:
|
||||||
|
existing_field = existing_fields[field_name]
|
||||||
|
if not self._is_field_compatible(existing_field, expected_config):
|
||||||
|
incompatible_fields.append(
|
||||||
|
f"{field_name}: expected {expected_config['type']}, "
|
||||||
|
f"got {existing_field.get('type')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if incompatible_fields:
|
||||||
|
raise ValueError(
|
||||||
|
f"Critical schema incompatibility in collection '{self.namespace}': {incompatible_fields}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get all expected fields for informational purposes
|
||||||
|
expected_fields = self._get_required_fields_for_namespace()
|
||||||
|
missing_fields = [
|
||||||
|
field for field in expected_fields if field not in existing_fields
|
||||||
|
]
|
||||||
|
|
||||||
|
if missing_fields:
|
||||||
|
logger.info(
|
||||||
|
f"Collection {self.namespace} missing optional fields: {missing_fields}"
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"These fields would be available in a newly created collection for better performance"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Schema compatibility check passed for {self.namespace}")
|
||||||
|
|
||||||
|
def _validate_collection_compatibility(self):
|
||||||
|
"""Validate existing collection's dimension and schema compatibility"""
|
||||||
|
try:
|
||||||
|
collection_info = self._client.describe_collection(self.namespace)
|
||||||
|
|
||||||
|
# 1. Check vector dimension
|
||||||
|
self._check_vector_dimension(collection_info)
|
||||||
|
|
||||||
|
# 2. Check schema compatibility
|
||||||
|
self._check_schema_compatibility(collection_info)
|
||||||
|
|
||||||
|
logger.info(f"Collection {self.namespace} compatibility validation passed")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Collection compatibility validation failed for {self.namespace}: {e}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _create_collection_if_not_exist(self):
|
||||||
|
"""Create collection if not exists and check existing collection compatibility"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# First, list all collections to see what actually exists
|
||||||
|
try:
|
||||||
|
all_collections = self._client.list_collections()
|
||||||
|
logger.debug(f"All collections in database: {all_collections}")
|
||||||
|
except Exception as list_error:
|
||||||
|
logger.warning(f"Could not list collections: {list_error}")
|
||||||
|
all_collections = []
|
||||||
|
|
||||||
|
# Check if our specific collection exists
|
||||||
|
collection_exists = self._client.has_collection(self.namespace)
|
||||||
|
logger.info(
|
||||||
|
f"Collection '{self.namespace}' exists check: {collection_exists}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if collection_exists:
|
||||||
|
# Double-check by trying to describe the collection
|
||||||
|
try:
|
||||||
|
self._client.describe_collection(self.namespace)
|
||||||
|
logger.info(
|
||||||
|
f"Collection '{self.namespace}' confirmed to exist, validating compatibility..."
|
||||||
|
)
|
||||||
|
self._validate_collection_compatibility()
|
||||||
|
return
|
||||||
|
except Exception as describe_error:
|
||||||
|
logger.warning(
|
||||||
|
f"Collection '{self.namespace}' exists but cannot be described: {describe_error}"
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Treating as if collection doesn't exist and creating new one..."
|
||||||
|
)
|
||||||
|
# Fall through to creation logic
|
||||||
|
|
||||||
|
# Collection doesn't exist, create new collection
|
||||||
|
logger.info(f"Creating new collection: {self.namespace}")
|
||||||
|
schema = self._create_schema_for_namespace()
|
||||||
|
|
||||||
|
# Create collection with schema only first
|
||||||
|
self._client.create_collection(
|
||||||
|
collection_name=self.namespace, schema=schema
|
||||||
|
)
|
||||||
|
|
||||||
|
# Then create indexes
|
||||||
|
self._create_indexes_after_collection()
|
||||||
|
|
||||||
|
logger.info(f"Successfully created Milvus collection: {self.namespace}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error in _create_collection_if_not_exist for {self.namespace}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If there's any error, try to force create the collection
|
||||||
|
logger.info(f"Attempting to force create collection {self.namespace}...")
|
||||||
|
try:
|
||||||
|
# Try to drop the collection first if it exists in a bad state
|
||||||
|
try:
|
||||||
|
if self._client.has_collection(self.namespace):
|
||||||
|
logger.info(
|
||||||
|
f"Dropping potentially corrupted collection {self.namespace}"
|
||||||
|
)
|
||||||
|
self._client.drop_collection(self.namespace)
|
||||||
|
except Exception as drop_error:
|
||||||
|
logger.warning(
|
||||||
|
f"Could not drop collection {self.namespace}: {drop_error}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create fresh collection
|
||||||
|
schema = self._create_schema_for_namespace()
|
||||||
|
self._client.create_collection(
|
||||||
|
collection_name=self.namespace, schema=schema
|
||||||
|
)
|
||||||
|
self._create_indexes_after_collection()
|
||||||
|
logger.info(f"Successfully force-created collection {self.namespace}")
|
||||||
|
|
||||||
|
except Exception as create_error:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to force-create collection {self.namespace}: {create_error}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
cosine_threshold = kwargs.get("cosine_better_than_threshold")
|
cosine_threshold = kwargs.get("cosine_better_than_threshold")
|
||||||
|
|
@ -43,6 +632,10 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
)
|
)
|
||||||
self.cosine_better_than_threshold = cosine_threshold
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
|
# Ensure created_at is in meta_fields
|
||||||
|
if "created_at" not in self.meta_fields:
|
||||||
|
self.meta_fields.add("created_at")
|
||||||
|
|
||||||
self._client = MilvusClient(
|
self._client = MilvusClient(
|
||||||
uri=os.environ.get(
|
uri=os.environ.get(
|
||||||
"MILVUS_URI",
|
"MILVUS_URI",
|
||||||
|
|
@ -68,11 +661,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
MilvusVectorDBStorage.create_collection_if_not_exist(
|
|
||||||
self._client,
|
# Create collection and check compatibility
|
||||||
self.namespace,
|
self._create_collection_if_not_exist()
|
||||||
dimension=self.embedding_func.embedding_dim,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
logger.debug(f"Inserting {len(data)} to {self.namespace}")
|
logger.debug(f"Inserting {len(data)} to {self.namespace}")
|
||||||
|
|
@ -112,23 +703,25 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
embedding = await self.embedding_func(
|
embedding = await self.embedding_func(
|
||||||
[query], _priority=5
|
[query], _priority=5
|
||||||
) # higher priority for query
|
) # higher priority for query
|
||||||
|
|
||||||
|
# Include all meta_fields (created_at is now always included)
|
||||||
|
output_fields = list(self.meta_fields)
|
||||||
|
|
||||||
results = self._client.search(
|
results = self._client.search(
|
||||||
collection_name=self.namespace,
|
collection_name=self.namespace,
|
||||||
data=embedding,
|
data=embedding,
|
||||||
limit=top_k,
|
limit=top_k,
|
||||||
output_fields=list(self.meta_fields) + ["created_at"],
|
output_fields=output_fields,
|
||||||
search_params={
|
search_params={
|
||||||
"metric_type": "COSINE",
|
"metric_type": "COSINE",
|
||||||
"params": {"radius": self.cosine_better_than_threshold},
|
"params": {"radius": self.cosine_better_than_threshold},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
print(results)
|
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
**dp["entity"],
|
**dp["entity"],
|
||||||
"id": dp["id"],
|
"id": dp["id"],
|
||||||
"distance": dp["distance"],
|
"distance": dp["distance"],
|
||||||
# created_at is requested in output_fields, so it should be a top-level key in the result dict (dp)
|
|
||||||
"created_at": dp.get("created_at"),
|
"created_at": dp.get("created_at"),
|
||||||
}
|
}
|
||||||
for dp in results[0]
|
for dp in results[0]
|
||||||
|
|
@ -232,20 +825,19 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
The vector data if found, or None if not found
|
The vector data if found, or None if not found
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Include all meta_fields (created_at is now always included) plus id
|
||||||
|
output_fields = list(self.meta_fields) + ["id"]
|
||||||
|
|
||||||
# Query Milvus for a specific ID
|
# Query Milvus for a specific ID
|
||||||
result = self._client.query(
|
result = self._client.query(
|
||||||
collection_name=self.namespace,
|
collection_name=self.namespace,
|
||||||
filter=f'id == "{id}"',
|
filter=f'id == "{id}"',
|
||||||
output_fields=list(self.meta_fields) + ["id", "created_at"],
|
output_fields=output_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result or len(result) == 0:
|
if not result or len(result) == 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Ensure the result contains created_at field
|
|
||||||
if "created_at" not in result[0]:
|
|
||||||
result[0]["created_at"] = None
|
|
||||||
|
|
||||||
return result[0]
|
return result[0]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error retrieving vector data for ID {id}: {e}")
|
logger.error(f"Error retrieving vector data for ID {id}: {e}")
|
||||||
|
|
@ -264,6 +856,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Include all meta_fields (created_at is now always included) plus id
|
||||||
|
output_fields = list(self.meta_fields) + ["id"]
|
||||||
|
|
||||||
# Prepare the ID filter expression
|
# Prepare the ID filter expression
|
||||||
id_list = '", "'.join(ids)
|
id_list = '", "'.join(ids)
|
||||||
filter_expr = f'id in ["{id_list}"]'
|
filter_expr = f'id in ["{id_list}"]'
|
||||||
|
|
@ -272,14 +867,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
result = self._client.query(
|
result = self._client.query(
|
||||||
collection_name=self.namespace,
|
collection_name=self.namespace,
|
||||||
filter=filter_expr,
|
filter=filter_expr,
|
||||||
output_fields=list(self.meta_fields) + ["id", "created_at"],
|
output_fields=output_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure each result contains created_at field
|
|
||||||
for item in result:
|
|
||||||
if "created_at" not in item:
|
|
||||||
item["created_at"] = None
|
|
||||||
|
|
||||||
return result or []
|
return result or []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
||||||
|
|
@ -301,11 +891,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
self._client.drop_collection(self.namespace)
|
self._client.drop_collection(self.namespace)
|
||||||
|
|
||||||
# Recreate the collection
|
# Recreate the collection
|
||||||
MilvusVectorDBStorage.create_collection_if_not_exist(
|
self._create_collection_if_not_exist()
|
||||||
self._client,
|
|
||||||
self.namespace,
|
|
||||||
dimension=self.embedding_func.embedding_dim,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Process {os.getpid()} drop Milvus collection {self.namespace}"
|
f"Process {os.getpid()} drop Milvus collection {self.namespace}"
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue