Refac: Add workspace information to all logger output for all storage types

yangdx 2025-08-12 01:19:09 +08:00
parent 44204abef7
commit 095e0cbfa2
11 changed files with 734 additions and 417 deletions

View file

@ -252,6 +252,7 @@ POSTGRES_IVFFLAT_LISTS=100
NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
NEO4J_USERNAME=neo4j
NEO4J_PASSWORD='your_password'
# NEO4J_DATABASE=chunk_entity_relation
NEO4J_MAX_CONNECTION_POOL_SIZE=100
NEO4J_CONNECTION_TIMEOUT=30
NEO4J_CONNECTION_ACQUISITION_TIMEOUT=30
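For reference, a minimal sketch of picking up these settings on the Python side; the key names match the variables above, but the defaults and the reading code are illustrative, not the storage module's exact implementation.

import os

# Illustrative only: read the Neo4j connection settings listed above.
uri = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME", "neo4j")
password = os.environ.get("NEO4J_PASSWORD")
database = os.environ.get("NEO4J_DATABASE")  # optional; unset falls back to the driver default
max_pool_size = int(os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", "100"))
connection_timeout = float(os.environ.get("NEO4J_CONNECTION_TIMEOUT", "30"))
acquisition_timeout = float(os.environ.get("NEO4J_CONNECTION_ACQUISITION_TIMEOUT", "30"))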

View file

@ -33,19 +33,18 @@ class JsonDocStatusStorage(DocStatusStorage):
if self.workspace:
# Include workspace in the file path for data isolation
workspace_dir = os.path.join(working_dir, self.workspace)
os.makedirs(workspace_dir, exist_ok=True)
self._file_name = os.path.join(
workspace_dir, f"kv_store_{self.namespace}.json"
)
self.final_namespace = f"{self.workspace}_{self.namespace}"
else:
# Default behavior when workspace is empty
self._file_name = os.path.join(
working_dir, f"kv_store_{self.namespace}.json"
)
self.final_namespace = self.namespace
self.workspace = "_"
workspace_dir = working_dir
os.makedirs(workspace_dir, exist_ok=True)
self._file_name = os.path.join(workspace_dir, f"kv_store_{self.namespace}.json")
self._data = None
self._storage_lock = None
self.storage_updated = None
self.final_namespace = f"{self.workspace}_{self.namespace}"
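The layout rule this hunk applies to the JSON-backed storages, pulled out into a standalone sketch (the function name and return shape are mine; the attribute handling follows the diff): a non-empty workspace gets its own subdirectory and a workspace-prefixed namespace, while an empty workspace keeps the file in the working directory and uses "_" only as a placeholder label.

import os

def resolve_storage_paths(working_dir: str, workspace: str | None, namespace: str):
    """Sketch of the workspace-aware layout; returns (workspace_label, final_namespace, file_name)."""
    if workspace:
        # Include workspace in the file path for data isolation
        workspace_dir = os.path.join(working_dir, workspace)
    else:
        # Default behavior when workspace is empty: "_" is only a log placeholder
        workspace = "_"
        workspace_dir = working_dir
    os.makedirs(workspace_dir, exist_ok=True)
    file_name = os.path.join(workspace_dir, f"kv_store_{namespace}.json")
    final_namespace = f"{workspace}_{namespace}"
    return workspace, final_namespace, file_name

# resolve_storage_paths("./rag_storage", "tenant_a", "doc_status")
# -> ("tenant_a", "tenant_a_doc_status", "./rag_storage/tenant_a/kv_store_doc_status.json")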
async def initialize(self):
"""Initialize storage data"""
@ -60,7 +59,7 @@ class JsonDocStatusStorage(DocStatusStorage):
async with self._storage_lock:
self._data.update(loaded_data)
logger.info(
f"Process {os.getpid()} doc status load {self.final_namespace} with {len(loaded_data)} records"
f"[{self.workspace}] Process {os.getpid()} doc status load {self.namespace} with {len(loaded_data)} records"
)
async def filter_keys(self, keys: set[str]) -> set[str]:
@ -108,7 +107,9 @@ class JsonDocStatusStorage(DocStatusStorage):
data["error_msg"] = None
result[k] = DocProcessingStatus(**data)
except KeyError as e:
logger.error(f"Missing required field for document {k}: {e}")
logger.error(
f"[{self.workspace}] Missing required field for document {k}: {e}"
)
continue
return result
@ -135,7 +136,9 @@ class JsonDocStatusStorage(DocStatusStorage):
data["error_msg"] = None
result[k] = DocProcessingStatus(**data)
except KeyError as e:
logger.error(f"Missing required field for document {k}: {e}")
logger.error(
f"[{self.workspace}] Missing required field for document {k}: {e}"
)
continue
return result
@ -146,7 +149,7 @@ class JsonDocStatusStorage(DocStatusStorage):
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
)
logger.debug(
f"Process {os.getpid()} doc status writting {len(data_dict)} records to {self.final_namespace}"
f"[{self.workspace}] Process {os.getpid()} doc status writting {len(data_dict)} records to {self.namespace}"
)
write_json(data_dict, self._file_name)
await clear_all_update_flags(self.final_namespace)
@ -159,7 +162,9 @@ class JsonDocStatusStorage(DocStatusStorage):
"""
if not data:
return
logger.debug(f"Inserting {len(data)} records to {self.final_namespace}")
logger.debug(
f"[{self.workspace}] Inserting {len(data)} records to {self.namespace}"
)
async with self._storage_lock:
# Ensure chunks_list field exists for new documents
for doc_id, doc_data in data.items():
@ -242,7 +247,9 @@ class JsonDocStatusStorage(DocStatusStorage):
all_docs.append((doc_id, doc_status))
except KeyError as e:
logger.error(f"Error processing document {doc_id}: {e}")
logger.error(
f"[{self.workspace}] Error processing document {doc_id}: {e}"
)
continue
# Sort documents
@ -321,8 +328,10 @@ class JsonDocStatusStorage(DocStatusStorage):
await set_all_update_flags(self.final_namespace)
await self.index_done_callback()
logger.info(f"Process {os.getpid()} drop {self.final_namespace}")
logger.info(
f"[{self.workspace}] Process {os.getpid()} drop {self.namespace}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping {self.final_namespace}: {e}")
logger.error(f"[{self.workspace}] Error dropping {self.namespace}: {e}")
return {"status": "error", "message": str(e)}
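The change repeated through the rest of the commit is the one visible above: each log call gains a leading [workspace] tag while the message body refers to the plain namespace rather than final_namespace. The commit inlines the prefix in every f-string; the helper below is only a sketch to make the convention explicit (the "lightrag" logger name is an assumption).

import logging
import os

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("lightrag")  # assumed logger name, for illustration

def ws_log(workspace: str, message: str) -> str:
    """Prefix a storage log message with its workspace tag."""
    return f"[{workspace}] {message}"

# Mirrors the load message above; the namespace and count are illustrative values.
workspace, namespace, count = "tenant_a", "doc_status", 42
logger.info(ws_log(workspace, f"Process {os.getpid()} doc status load {namespace} with {count} records"))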

View file

@ -29,19 +29,19 @@ class JsonKVStorage(BaseKVStorage):
if self.workspace:
# Include workspace in the file path for data isolation
workspace_dir = os.path.join(working_dir, self.workspace)
os.makedirs(workspace_dir, exist_ok=True)
self._file_name = os.path.join(
workspace_dir, f"kv_store_{self.namespace}.json"
)
self.final_namespace = f"{self.workspace}_{self.namespace}"
else:
# Default behavior when workspace is empty
self._file_name = os.path.join(
working_dir, f"kv_store_{self.namespace}.json"
)
workspace_dir = working_dir
self.final_namespace = self.namespace
self.workspace = "_"
os.makedirs(workspace_dir, exist_ok=True)
self._file_name = os.path.join(workspace_dir, f"kv_store_{self.namespace}.json")
self._data = None
self._storage_lock = None
self.storage_updated = None
self.final_namespace = f"{self.workspace}_{self.namespace}"
async def initialize(self):
"""Initialize storage data"""
@ -64,7 +64,7 @@ class JsonKVStorage(BaseKVStorage):
data_count = len(loaded_data)
logger.info(
f"Process {os.getpid()} KV load {self.final_namespace} with {data_count} records"
f"[{self.workspace}] Process {os.getpid()} KV load {self.namespace} with {data_count} records"
)
async def index_done_callback(self) -> None:
@ -78,7 +78,7 @@ class JsonKVStorage(BaseKVStorage):
data_count = len(data_dict)
logger.debug(
f"Process {os.getpid()} KV writting {data_count} records to {self.final_namespace}"
f"[{self.workspace}] Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
)
write_json(data_dict, self._file_name)
await clear_all_update_flags(self.final_namespace)
@ -151,12 +151,14 @@ class JsonKVStorage(BaseKVStorage):
current_time = int(time.time()) # Get current Unix timestamp
logger.debug(f"Inserting {len(data)} records to {self.final_namespace}")
logger.debug(
f"[{self.workspace}] Inserting {len(data)} records to {self.namespace}"
)
async with self._storage_lock:
# Add timestamps to data based on whether key exists
for k, v in data.items():
# For text_chunks namespace, ensure llm_cache_list field exists
if "text_chunks" in self.namespace:
if self.namespace.endswith("text_chunks"):
if "llm_cache_list" not in v:
v["llm_cache_list"] = []
@ -215,10 +217,12 @@ class JsonKVStorage(BaseKVStorage):
await set_all_update_flags(self.final_namespace)
await self.index_done_callback()
logger.info(f"Process {os.getpid()} drop {self.final_namespace}")
logger.info(
f"[{self.workspace}] Process {os.getpid()} drop {self.namespace}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping {self.final_namespace}: {e}")
logger.error(f"[{self.workspace}] Error dropping {self.namespace}: {e}")
return {"status": "error", "message": str(e)}
async def _migrate_legacy_cache_structure(self, data: dict) -> dict:
@ -263,7 +267,7 @@ class JsonKVStorage(BaseKVStorage):
if migration_count > 0:
logger.info(
f"Migrated {migration_count} legacy cache entries to flattened structure"
f"[{self.workspace}] Migrated {migration_count} legacy cache entries to flattened structure"
)
# Persist migrated data immediately
write_json(migrated_data, self._file_name)
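The type checks in this file (and in the Milvus storage further down) switch from substring matching to suffix matching on the namespace. A minimal sketch of the difference, with made-up namespace strings; the suffix form stays precise even when a name merely contains the word.

plain = "text_chunks"
prefixed = "tenant_a_text_chunks"
lookalike = "text_chunks_backup"  # hypothetical name a substring check would also match

for ns in (plain, prefixed, lookalike):
    old_check = "text_chunks" in ns          # previous behavior
    new_check = ns.endswith("text_chunks")   # behavior after this commit
    print(f"{ns!r}: in -> {old_check}, endswith -> {new_check}")

# The suffix check rejects "text_chunks_backup" while still accepting the plain
# and workspace-prefixed forms.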

View file

@ -34,21 +34,25 @@ config.read("config.ini", "utf-8")
@dataclass
class MemgraphStorage(BaseGraphStorage):
def __init__(self, namespace, global_config, embedding_func, workspace=None):
# Priority: 1) MEMGRAPH_WORKSPACE env 2) user arg 3) default 'base'
memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE")
if memgraph_workspace and memgraph_workspace.strip():
workspace = memgraph_workspace
if not workspace or not str(workspace).strip():
workspace = "base"
super().__init__(
namespace=namespace,
workspace=workspace or "",
workspace=workspace,
global_config=global_config,
embedding_func=embedding_func,
)
self._driver = None
def _get_workspace_label(self) -> str:
"""Get workspace label, return 'base' for compatibility when workspace is empty"""
workspace = getattr(self, "workspace", None)
return workspace if workspace else "base"
"""Return workspace label (guaranteed non-empty during initialization)"""
return self.workspace
async def initialize(self):
URI = os.environ.get(
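The workspace resolution order added at the top of this class, extracted into a standalone sketch (the function name is mine; the priority comes from the comment in the diff):

import os

def resolve_memgraph_workspace(workspace: str | None) -> str:
    """Priority: 1) MEMGRAPH_WORKSPACE env 2) caller-supplied argument 3) default 'base'."""
    env_workspace = os.environ.get("MEMGRAPH_WORKSPACE")
    if env_workspace and env_workspace.strip():
        return env_workspace
    if workspace and str(workspace).strip():
        return workspace
    return "base"

# resolve_memgraph_workspace(None)     -> "base"  (also what _get_workspace_label now returns)
# resolve_memgraph_workspace("  ")     -> "base"
# resolve_memgraph_workspace("tenant") -> "tenant", unless MEMGRAPH_WORKSPACE overrides it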
@ -79,17 +83,19 @@ class MemgraphStorage(BaseGraphStorage):
f"""CREATE INDEX ON :{workspace_label}(entity_id)"""
)
logger.info(
f"Created index on :{workspace_label}(entity_id) in Memgraph."
f"[{self.workspace}] Created index on :{workspace_label}(entity_id) in Memgraph."
)
except Exception as e:
# Index may already exist, which is not an error
logger.warning(
f"Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}"
f"[{self.workspace}] Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}"
)
await session.run("RETURN 1")
logger.info(f"Connected to Memgraph at {URI}")
logger.info(f"[{self.workspace}] Connected to Memgraph at {URI}")
except Exception as e:
logger.error(f"Failed to connect to Memgraph at {URI}: {e}")
logger.error(
f"[{self.workspace}] Failed to connect to Memgraph at {URI}: {e}"
)
raise
async def finalize(self):
@ -134,7 +140,9 @@ class MemgraphStorage(BaseGraphStorage):
single_result["node_exists"] if single_result is not None else False
)
except Exception as e:
logger.error(f"Error checking node existence for {node_id}: {str(e)}")
logger.error(
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
)
await result.consume() # Ensure the result is consumed even on error
raise
@ -177,7 +185,7 @@ class MemgraphStorage(BaseGraphStorage):
)
except Exception as e:
logger.error(
f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
)
await result.consume() # Ensure the result is consumed even on error
raise
@ -215,7 +223,7 @@ class MemgraphStorage(BaseGraphStorage):
if len(records) > 1:
logger.warning(
f"Multiple nodes found with label '{node_id}'. Using first node."
f"[{self.workspace}] Multiple nodes found with label '{node_id}'. Using first node."
)
if records:
node = records[0]["n"]
@ -232,7 +240,9 @@ class MemgraphStorage(BaseGraphStorage):
finally:
await result.consume() # Ensure result is fully consumed
except Exception as e:
logger.error(f"Error getting node for {node_id}: {str(e)}")
logger.error(
f"[{self.workspace}] Error getting node for {node_id}: {str(e)}"
)
raise
async def node_degree(self, node_id: str) -> int:
@ -268,7 +278,9 @@ class MemgraphStorage(BaseGraphStorage):
record = await result.single()
if not record:
logger.warning(f"No node found with label '{node_id}'")
logger.warning(
f"[{self.workspace}] No node found with label '{node_id}'"
)
return 0
degree = record["degree"]
@ -276,7 +288,9 @@ class MemgraphStorage(BaseGraphStorage):
finally:
await result.consume() # Ensure result is fully consumed
except Exception as e:
logger.error(f"Error getting node degree for {node_id}: {str(e)}")
logger.error(
f"[{self.workspace}] Error getting node degree for {node_id}: {str(e)}"
)
raise
async def get_all_labels(self) -> list[str]:
@ -310,7 +324,7 @@ class MemgraphStorage(BaseGraphStorage):
await result.consume()
return labels
except Exception as e:
logger.error(f"Error getting all labels: {str(e)}")
logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}")
await result.consume() # Ensure the result is consumed even on error
raise
@ -370,12 +384,14 @@ class MemgraphStorage(BaseGraphStorage):
return edges
except Exception as e:
logger.error(
f"Error getting edges for node {source_node_id}: {str(e)}"
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
)
await results.consume() # Ensure results are consumed even on error
raise
except Exception as e:
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
logger.error(
f"[{self.workspace}] Error in get_node_edges for {source_node_id}: {str(e)}"
)
raise
async def get_edge(
@ -424,13 +440,13 @@ class MemgraphStorage(BaseGraphStorage):
if key not in edge_result:
edge_result[key] = default_value
logger.warning(
f"Edge between {source_node_id} and {target_node_id} is missing property: {key}. Using default value: {default_value}"
f"[{self.workspace}] Edge between {source_node_id} and {target_node_id} is missing property: {key}. Using default value: {default_value}"
)
return edge_result
return None
except Exception as e:
logger.error(
f"Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
f"[{self.workspace}] Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
)
await result.consume() # Ensure the result is consumed even on error
raise
@ -463,7 +479,7 @@ class MemgraphStorage(BaseGraphStorage):
for attempt in range(max_retries):
try:
logger.debug(
f"Attempting node upsert, attempt {attempt + 1}/{max_retries}"
f"[{self.workspace}] Attempting node upsert, attempt {attempt + 1}/{max_retries}"
)
async with self._driver.session(database=self._DATABASE) as session:
workspace_label = self._get_workspace_label()
@ -504,20 +520,24 @@ class MemgraphStorage(BaseGraphStorage):
initial_wait_time * (backoff_factor**attempt) + jitter
)
logger.warning(
f"Node upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
f"[{self.workspace}] Node upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
)
await asyncio.sleep(wait_time)
else:
logger.error(
f"Memgraph transient error during node upsert after {max_retries} retries: {str(e)}"
f"[{self.workspace}] Memgraph transient error during node upsert after {max_retries} retries: {str(e)}"
)
raise
else:
# Non-transient error, don't retry
logger.error(f"Non-transient error during node upsert: {str(e)}")
logger.error(
f"[{self.workspace}] Non-transient error during node upsert: {str(e)}"
)
raise
except Exception as e:
logger.error(f"Unexpected error during node upsert: {str(e)}")
logger.error(
f"[{self.workspace}] Unexpected error during node upsert: {str(e)}"
)
raise
async def upsert_edge(
@ -552,7 +572,7 @@ class MemgraphStorage(BaseGraphStorage):
for attempt in range(max_retries):
try:
logger.debug(
f"Attempting edge upsert, attempt {attempt + 1}/{max_retries}"
f"[{self.workspace}] Attempting edge upsert, attempt {attempt + 1}/{max_retries}"
)
async with self._driver.session(database=self._DATABASE) as session:
@ -602,20 +622,24 @@ class MemgraphStorage(BaseGraphStorage):
initial_wait_time * (backoff_factor**attempt) + jitter
)
logger.warning(
f"Edge upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
f"[{self.workspace}] Edge upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
)
await asyncio.sleep(wait_time)
else:
logger.error(
f"Memgraph transient error during edge upsert after {max_retries} retries: {str(e)}"
f"[{self.workspace}] Memgraph transient error during edge upsert after {max_retries} retries: {str(e)}"
)
raise
else:
# Non-transient error, don't retry
logger.error(f"Non-transient error during edge upsert: {str(e)}")
logger.error(
f"[{self.workspace}] Non-transient error during edge upsert: {str(e)}"
)
raise
except Exception as e:
logger.error(f"Unexpected error during edge upsert: {str(e)}")
logger.error(
f"[{self.workspace}] Unexpected error during edge upsert: {str(e)}"
)
raise
async def delete_node(self, node_id: str) -> None:
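Both upsert paths above retry transient Memgraph errors with exponential backoff plus jitter before giving up. Stripped of the session handling, the schedule works roughly as below; the constants and the jitter calculation are illustrative, since only the wait-time formula appears in the diff.

import asyncio
import random

async def retry_with_backoff(op, max_retries: int = 3, initial_wait_time: float = 0.2,
                             backoff_factor: float = 2.0, jitter_ratio: float = 0.1):
    """Sketch of the retry loop wrapped around the node/edge upserts."""
    for attempt in range(max_retries):
        try:
            return await op()
        except Exception as exc:  # the real code narrows this to transient driver errors
            if attempt == max_retries - 1:
                raise
            jitter = random.uniform(0, jitter_ratio * initial_wait_time)
            wait_time = initial_wait_time * (backoff_factor ** attempt) + jitter
            print(f"Attempt #{attempt + 1} failed ({exc}); retrying in {wait_time:.3f} seconds")
            await asyncio.sleep(wait_time)

# Usage: asyncio.run(retry_with_backoff(some_async_upsert)), where some_async_upsert is any coroutine function.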
@ -639,14 +663,14 @@ class MemgraphStorage(BaseGraphStorage):
DETACH DELETE n
"""
result = await tx.run(query, entity_id=node_id)
logger.debug(f"Deleted node with label {node_id}")
logger.debug(f"[{self.workspace}] Deleted node with label {node_id}")
await result.consume()
try:
async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_delete)
except Exception as e:
logger.error(f"Error during node deletion: {str(e)}")
logger.error(f"[{self.workspace}] Error during node deletion: {str(e)}")
raise
async def remove_nodes(self, nodes: list[str]):
@ -686,14 +710,16 @@ class MemgraphStorage(BaseGraphStorage):
result = await tx.run(
query, source_entity_id=source, target_entity_id=target
)
logger.debug(f"Deleted edge from '{source}' to '{target}'")
logger.debug(
f"[{self.workspace}] Deleted edge from '{source}' to '{target}'"
)
await result.consume() # Ensure result is fully consumed
try:
async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_delete_edge)
except Exception as e:
logger.error(f"Error during edge deletion: {str(e)}")
logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}")
raise
async def drop(self) -> dict[str, str]:
@ -720,12 +746,12 @@ class MemgraphStorage(BaseGraphStorage):
result = await session.run(query)
await result.consume()
logger.info(
f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
)
return {"status": "success", "message": "workspace data dropped"}
except Exception as e:
logger.error(
f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
)
return {"status": "error", "message": str(e)}
@ -945,14 +971,16 @@ class MemgraphStorage(BaseGraphStorage):
# If no record found, return empty KnowledgeGraph
if not record:
logger.debug(f"No nodes found for entity_id: {node_label}")
logger.debug(
f"[{self.workspace}] No nodes found for entity_id: {node_label}"
)
return result
# Check if the result was truncated
if record.get("is_truncated"):
result.is_truncated = True
logger.info(
f"Graph truncated: breadth-first search limited to {max_nodes} nodes"
f"[{self.workspace}] Graph truncated: breadth-first search limited to {max_nodes} nodes"
)
finally:
@ -990,11 +1018,13 @@ class MemgraphStorage(BaseGraphStorage):
seen_edges.add(edge_id)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
f"[{self.workspace}] Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
except Exception as e:
logger.warning(f"Memgraph error during subgraph query: {str(e)}")
logger.warning(
f"[{self.workspace}] Memgraph error during subgraph query: {str(e)}"
)
return result

View file

@ -37,7 +37,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
]
# Determine specific fields based on namespace
if "entities" in self.namespace.lower():
if self.namespace.endswith("entities"):
specific_fields = [
FieldSchema(
name="entity_name",
@ -54,7 +54,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
]
description = "LightRAG entities vector storage"
elif "relationships" in self.namespace.lower():
elif self.namespace.endswith("relationships"):
specific_fields = [
FieldSchema(
name="src_id", dtype=DataType.VARCHAR, max_length=512, nullable=True
@ -71,7 +71,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
]
description = "LightRAG relationships vector storage"
elif "chunks" in self.namespace.lower():
elif self.namespace.endswith("chunks"):
specific_fields = [
FieldSchema(
name="full_doc_id",
@ -147,7 +147,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
"""Fallback method to create vector index using direct API"""
try:
self._client.create_index(
collection_name=self.namespace,
collection_name=self.final_namespace,
field_name="vector",
index_params={
"index_type": "HNSW",
@ -155,29 +155,35 @@ class MilvusVectorDBStorage(BaseVectorStorage):
"params": {"M": 16, "efConstruction": 256},
},
)
logger.debug("Created vector index using fallback method")
logger.debug(
f"[{self.workspace}] Created vector index using fallback method"
)
except Exception as e:
logger.warning(f"Failed to create vector index using fallback method: {e}")
logger.warning(
f"[{self.workspace}] Failed to create vector index using fallback method: {e}"
)
def _create_scalar_index_fallback(self, field_name: str, index_type: str):
"""Fallback method to create scalar index using direct API"""
# Skip unsupported index types
if index_type == "SORTED":
logger.info(
f"Skipping SORTED index for {field_name} (not supported in this Milvus version)"
f"[{self.workspace}] Skipping SORTED index for {field_name} (not supported in this Milvus version)"
)
return
try:
self._client.create_index(
collection_name=self.namespace,
collection_name=self.final_namespace,
field_name=field_name,
index_params={"index_type": index_type},
)
logger.debug(f"Created {field_name} index using fallback method")
logger.debug(
f"[{self.workspace}] Created {field_name} index using fallback method"
)
except Exception as e:
logger.info(
f"Could not create {field_name} index using fallback method: {e}"
f"[{self.workspace}] Could not create {field_name} index using fallback method: {e}"
)
def _create_indexes_after_collection(self):
@ -198,15 +204,19 @@ class MilvusVectorDBStorage(BaseVectorStorage):
params={"M": 16, "efConstruction": 256},
)
self._client.create_index(
collection_name=self.namespace, index_params=vector_index
collection_name=self.final_namespace, index_params=vector_index
)
logger.debug(
f"[{self.workspace}] Created vector index using IndexParams"
)
logger.debug("Created vector index using IndexParams")
except Exception as e:
logger.debug(f"IndexParams method failed for vector index: {e}")
logger.debug(
f"[{self.workspace}] IndexParams method failed for vector index: {e}"
)
self._create_vector_index_fallback()
# Create scalar indexes based on namespace
if "entities" in self.namespace.lower():
if self.namespace.endswith("entities"):
# Create indexes for entity fields
try:
entity_name_index = self._get_index_params()
@ -214,14 +224,16 @@ class MilvusVectorDBStorage(BaseVectorStorage):
field_name="entity_name", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.namespace,
collection_name=self.final_namespace,
index_params=entity_name_index,
)
except Exception as e:
logger.debug(f"IndexParams method failed for entity_name: {e}")
logger.debug(
f"[{self.workspace}] IndexParams method failed for entity_name: {e}"
)
self._create_scalar_index_fallback("entity_name", "INVERTED")
elif "relationships" in self.namespace.lower():
elif self.namespace.endswith("relationships"):
# Create indexes for relationship fields
try:
src_id_index = self._get_index_params()
@ -229,10 +241,13 @@ class MilvusVectorDBStorage(BaseVectorStorage):
field_name="src_id", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.namespace, index_params=src_id_index
collection_name=self.final_namespace,
index_params=src_id_index,
)
except Exception as e:
logger.debug(f"IndexParams method failed for src_id: {e}")
logger.debug(
f"[{self.workspace}] IndexParams method failed for src_id: {e}"
)
self._create_scalar_index_fallback("src_id", "INVERTED")
try:
@ -241,13 +256,16 @@ class MilvusVectorDBStorage(BaseVectorStorage):
field_name="tgt_id", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.namespace, index_params=tgt_id_index
collection_name=self.final_namespace,
index_params=tgt_id_index,
)
except Exception as e:
logger.debug(f"IndexParams method failed for tgt_id: {e}")
logger.debug(
f"[{self.workspace}] IndexParams method failed for tgt_id: {e}"
)
self._create_scalar_index_fallback("tgt_id", "INVERTED")
elif "chunks" in self.namespace.lower():
elif self.namespace.endswith("chunks"):
# Create indexes for chunk fields
try:
doc_id_index = self._get_index_params()
@ -255,10 +273,13 @@ class MilvusVectorDBStorage(BaseVectorStorage):
field_name="full_doc_id", index_type="INVERTED"
)
self._client.create_index(
collection_name=self.namespace, index_params=doc_id_index
collection_name=self.final_namespace,
index_params=doc_id_index,
)
except Exception as e:
logger.debug(f"IndexParams method failed for full_doc_id: {e}")
logger.debug(
f"[{self.workspace}] IndexParams method failed for full_doc_id: {e}"
)
self._create_scalar_index_fallback("full_doc_id", "INVERTED")
# No common indexes needed
@ -266,25 +287,29 @@ class MilvusVectorDBStorage(BaseVectorStorage):
else:
# Fallback to direct API calls if IndexParams is not available
logger.info(
f"IndexParams not available, using fallback methods for {self.namespace}"
f"[{self.workspace}] IndexParams not available, using fallback methods for {self.namespace}"
)
# Create vector index using fallback
self._create_vector_index_fallback()
# Create scalar indexes using fallback
if "entities" in self.namespace.lower():
if self.namespace.endswith("entities"):
self._create_scalar_index_fallback("entity_name", "INVERTED")
elif "relationships" in self.namespace.lower():
elif self.namespace.endswith("relationships"):
self._create_scalar_index_fallback("src_id", "INVERTED")
self._create_scalar_index_fallback("tgt_id", "INVERTED")
elif "chunks" in self.namespace.lower():
elif self.namespace.endswith("chunks"):
self._create_scalar_index_fallback("full_doc_id", "INVERTED")
logger.info(f"Created indexes for collection: {self.namespace}")
logger.info(
f"[{self.workspace}] Created indexes for collection: {self.namespace}"
)
except Exception as e:
logger.warning(f"Failed to create some indexes for {self.namespace}: {e}")
logger.warning(
f"[{self.workspace}] Failed to create some indexes for {self.namespace}: {e}"
)
def _get_required_fields_for_namespace(self) -> dict:
"""Get required core field definitions for current namespace"""
@ -297,18 +322,18 @@ class MilvusVectorDBStorage(BaseVectorStorage):
}
# Add specific fields based on namespace
if "entities" in self.namespace.lower():
if self.namespace.endswith("entities"):
specific_fields = {
"entity_name": {"type": "VarChar"},
"file_path": {"type": "VarChar"},
}
elif "relationships" in self.namespace.lower():
elif self.namespace.endswith("relationships"):
specific_fields = {
"src_id": {"type": "VarChar"},
"tgt_id": {"type": "VarChar"},
"file_path": {"type": "VarChar"},
}
elif "chunks" in self.namespace.lower():
elif self.namespace.endswith("chunks"):
specific_fields = {
"full_doc_id": {"type": "VarChar"},
"file_path": {"type": "VarChar"},
@ -327,7 +352,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
expected_type = expected_config.get("type")
logger.debug(
f"Checking field '{field_name}': existing_type={existing_type} (type={type(existing_type)}), expected_type={expected_type}"
f"[{self.workspace}] Checking field '{field_name}': existing_type={existing_type} (type={type(existing_type)}), expected_type={expected_type}"
)
# Convert DataType enum values to string names if needed
@ -335,7 +360,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
if hasattr(existing_type, "name"):
existing_type = existing_type.name
logger.debug(
f"Converted enum to name: {original_existing_type} -> {existing_type}"
f"[{self.workspace}] Converted enum to name: {original_existing_type} -> {existing_type}"
)
elif isinstance(existing_type, int):
# Map common Milvus internal type codes to type names for backward compatibility
@ -346,7 +371,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
9: "Double",
}
mapped_type = type_mapping.get(existing_type, str(existing_type))
logger.debug(f"Mapped numeric type: {existing_type} -> {mapped_type}")
logger.debug(
f"[{self.workspace}] Mapped numeric type: {existing_type} -> {mapped_type}"
)
existing_type = mapped_type
# Normalize type names for comparison
@ -367,18 +394,18 @@ class MilvusVectorDBStorage(BaseVectorStorage):
if original_existing != existing_type or original_expected != expected_type:
logger.debug(
f"Applied aliases: {original_existing} -> {existing_type}, {original_expected} -> {expected_type}"
f"[{self.workspace}] Applied aliases: {original_existing} -> {existing_type}, {original_expected} -> {expected_type}"
)
# Basic type compatibility check
type_compatible = existing_type == expected_type
logger.debug(
f"Type compatibility for '{field_name}': {existing_type} == {expected_type} -> {type_compatible}"
f"[{self.workspace}] Type compatibility for '{field_name}': {existing_type} == {expected_type} -> {type_compatible}"
)
if not type_compatible:
logger.warning(
f"Type mismatch for field '{field_name}': expected {expected_type}, got {existing_type}"
f"[{self.workspace}] Type mismatch for field '{field_name}': expected {expected_type}, got {existing_type}"
)
return False
@ -391,23 +418,25 @@ class MilvusVectorDBStorage(BaseVectorStorage):
or existing_field.get("primary_key", False)
)
logger.debug(
f"Primary key check for '{field_name}': expected=True, actual={is_primary}"
f"[{self.workspace}] Primary key check for '{field_name}': expected=True, actual={is_primary}"
)
logger.debug(
f"[{self.workspace}] Raw field data for '{field_name}': {existing_field}"
)
logger.debug(f"Raw field data for '{field_name}': {existing_field}")
# For ID field, be more lenient - if it's the ID field, assume it should be primary
if field_name == "id" and not is_primary:
logger.info(
f"ID field '{field_name}' not marked as primary in existing collection, but treating as compatible"
f"[{self.workspace}] ID field '{field_name}' not marked as primary in existing collection, but treating as compatible"
)
# Don't fail for ID field primary key mismatch
elif not is_primary:
logger.warning(
f"Primary key mismatch for field '{field_name}': expected primary key, but field is not primary"
f"[{self.workspace}] Primary key mismatch for field '{field_name}': expected primary key, but field is not primary"
)
return False
logger.debug(f"Field '{field_name}' is compatible")
logger.debug(f"[{self.workspace}] Field '{field_name}' is compatible")
return True
def _check_vector_dimension(self, collection_info: dict):
@ -434,18 +463,22 @@ class MilvusVectorDBStorage(BaseVectorStorage):
if existing_dimension != current_dimension:
raise ValueError(
f"Vector dimension mismatch for collection '{self.namespace}': "
f"Vector dimension mismatch for collection '{self.final_namespace}': "
f"existing={existing_dimension}, current={current_dimension}"
)
logger.debug(f"Vector dimension check passed: {current_dimension}")
logger.debug(
f"[{self.workspace}] Vector dimension check passed: {current_dimension}"
)
return
# If no vector field found, this might be an old collection created with simple schema
logger.warning(
f"Vector field not found in collection '{self.namespace}'. This might be an old collection created with simple schema."
f"[{self.workspace}] Vector field not found in collection '{self.namespace}'. This might be an old collection created with simple schema."
)
logger.warning(
f"[{self.workspace}] Consider recreating the collection for optimal performance."
)
logger.warning("Consider recreating the collection for optimal performance.")
return
def _check_schema_compatibility(self, collection_info: dict):
@ -461,12 +494,14 @@ class MilvusVectorDBStorage(BaseVectorStorage):
if not has_vector_field:
logger.warning(
f"Collection {self.namespace} appears to be created with old simple schema (no vector field)"
f"[{self.workspace}] Collection {self.namespace} appears to be created with old simple schema (no vector field)"
)
logger.warning(
"This collection will work but may have suboptimal performance"
f"[{self.workspace}] This collection will work but may have suboptimal performance"
)
logger.warning(
f"[{self.workspace}] Consider recreating the collection for optimal performance"
)
logger.warning("Consider recreating the collection for optimal performance")
return
# For collections with vector field, check basic compatibility
@ -486,7 +521,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
if incompatible_fields:
raise ValueError(
f"Critical schema incompatibility in collection '{self.namespace}': {incompatible_fields}"
f"Critical schema incompatibility in collection '{self.final_namespace}': {incompatible_fields}"
)
# Get all expected fields for informational purposes
@ -497,18 +532,20 @@ class MilvusVectorDBStorage(BaseVectorStorage):
if missing_fields:
logger.info(
f"Collection {self.namespace} missing optional fields: {missing_fields}"
f"[{self.workspace}] Collection {self.namespace} missing optional fields: {missing_fields}"
)
logger.info(
"These fields would be available in a newly created collection for better performance"
)
logger.debug(f"Schema compatibility check passed for {self.namespace}")
logger.debug(
f"[{self.workspace}] Schema compatibility check passed for {self.namespace}"
)
def _validate_collection_compatibility(self):
"""Validate existing collection's dimension and schema compatibility"""
try:
collection_info = self._client.describe_collection(self.namespace)
collection_info = self._client.describe_collection(self.final_namespace)
# 1. Check vector dimension
self._check_vector_dimension(collection_info)
@ -517,12 +554,12 @@ class MilvusVectorDBStorage(BaseVectorStorage):
self._check_schema_compatibility(collection_info)
logger.info(
f"VectorDB Collection '{self.namespace}' compatibility validation passed"
f"[{self.workspace}] VectorDB Collection '{self.namespace}' compatibility validation passed"
)
except Exception as e:
logger.error(
f"Collection compatibility validation failed for {self.namespace}: {e}"
f"[{self.workspace}] Collection compatibility validation failed for {self.namespace}: {e}"
)
raise
@ -530,17 +567,21 @@ class MilvusVectorDBStorage(BaseVectorStorage):
"""Ensure the collection is loaded into memory for search operations"""
try:
# Check if collection exists first
if not self._client.has_collection(self.namespace):
logger.error(f"Collection {self.namespace} does not exist")
raise ValueError(f"Collection {self.namespace} does not exist")
if not self._client.has_collection(self.final_namespace):
logger.error(
f"[{self.workspace}] Collection {self.namespace} does not exist"
)
raise ValueError(f"Collection {self.final_namespace} does not exist")
# Load the collection if it's not already loaded
# In Milvus, collections need to be loaded before they can be searched
self._client.load_collection(self.namespace)
# logger.debug(f"Collection {self.namespace} loaded successfully")
self._client.load_collection(self.final_namespace)
# logger.debug(f"[{self.workspace}] Collection {self.namespace} loaded successfully")
except Exception as e:
logger.error(f"Failed to load collection {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Failed to load collection {self.namespace}: {e}"
)
raise
def _create_collection_if_not_exist(self):
@ -550,41 +591,45 @@ class MilvusVectorDBStorage(BaseVectorStorage):
# First, list all collections to see what actually exists
try:
all_collections = self._client.list_collections()
logger.debug(f"All collections in database: {all_collections}")
logger.debug(
f"[{self.workspace}] All collections in database: {all_collections}"
)
except Exception as list_error:
logger.warning(f"Could not list collections: {list_error}")
logger.warning(
f"[{self.workspace}] Could not list collections: {list_error}"
)
all_collections = []
# Check if our specific collection exists
collection_exists = self._client.has_collection(self.namespace)
collection_exists = self._client.has_collection(self.final_namespace)
logger.info(
f"VectorDB collection '{self.namespace}' exists check: {collection_exists}"
f"[{self.workspace}] VectorDB collection '{self.namespace}' exists check: {collection_exists}"
)
if collection_exists:
# Double-check by trying to describe the collection
try:
self._client.describe_collection(self.namespace)
self._client.describe_collection(self.final_namespace)
self._validate_collection_compatibility()
# Ensure the collection is loaded after validation
self._ensure_collection_loaded()
return
except Exception as describe_error:
logger.warning(
f"Collection '{self.namespace}' exists but cannot be described: {describe_error}"
f"[{self.workspace}] Collection '{self.namespace}' exists but cannot be described: {describe_error}"
)
logger.info(
"Treating as if collection doesn't exist and creating new one..."
f"[{self.workspace}] Treating as if collection doesn't exist and creating new one..."
)
# Fall through to creation logic
# Collection doesn't exist, create new collection
logger.info(f"Creating new collection: {self.namespace}")
logger.info(f"[{self.workspace}] Creating new collection: {self.namespace}")
schema = self._create_schema_for_namespace()
# Create collection with schema only first
self._client.create_collection(
collection_name=self.namespace, schema=schema
collection_name=self.final_namespace, schema=schema
)
# Then create indexes
@ -593,43 +638,49 @@ class MilvusVectorDBStorage(BaseVectorStorage):
# Load the newly created collection
self._ensure_collection_loaded()
logger.info(f"Successfully created Milvus collection: {self.namespace}")
logger.info(
f"[{self.workspace}] Successfully created Milvus collection: {self.namespace}"
)
except Exception as e:
logger.error(
f"Error in _create_collection_if_not_exist for {self.namespace}: {e}"
f"[{self.workspace}] Error in _create_collection_if_not_exist for {self.namespace}: {e}"
)
# If there's any error, try to force create the collection
logger.info(f"Attempting to force create collection {self.namespace}...")
logger.info(
f"[{self.workspace}] Attempting to force create collection {self.namespace}..."
)
try:
# Try to drop the collection first if it exists in a bad state
try:
if self._client.has_collection(self.namespace):
if self._client.has_collection(self.final_namespace):
logger.info(
f"Dropping potentially corrupted collection {self.namespace}"
f"[{self.workspace}] Dropping potentially corrupted collection {self.namespace}"
)
self._client.drop_collection(self.namespace)
self._client.drop_collection(self.final_namespace)
except Exception as drop_error:
logger.warning(
f"Could not drop collection {self.namespace}: {drop_error}"
f"[{self.workspace}] Could not drop collection {self.namespace}: {drop_error}"
)
# Create fresh collection
schema = self._create_schema_for_namespace()
self._client.create_collection(
collection_name=self.namespace, schema=schema
collection_name=self.final_namespace, schema=schema
)
self._create_indexes_after_collection()
# Load the newly created collection
self._ensure_collection_loaded()
logger.info(f"Successfully force-created collection {self.namespace}")
logger.info(
f"[{self.workspace}] Successfully force-created collection {self.namespace}"
)
except Exception as create_error:
logger.error(
f"Failed to force-create collection {self.namespace}: {create_error}"
f"[{self.workspace}] Failed to force-create collection {self.namespace}: {create_error}"
)
raise
@ -651,11 +702,18 @@ class MilvusVectorDBStorage(BaseVectorStorage):
f"Using passed workspace parameter: '{effective_workspace}'"
)
# Build namespace with workspace prefix for data isolation
# Build final_namespace with workspace prefix for data isolation
# Keep original namespace unchanged for type detection logic
if effective_workspace:
self.namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'")
# When workspace is empty, keep the original namespace unchanged
self.final_namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(
f"Final namespace with workspace prefix: '{self.final_namespace}'"
)
else:
# When workspace is empty, final_namespace equals original namespace
self.final_namespace = self.namespace
logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
self.workspace = "_"
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")
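The __init__ pattern in this hunk, repeated in the Mongo storages below: namespace keeps its original value for type detection, final_namespace carries the workspace prefix and names the actual collection, and workspace falls back to "_" purely as a log tag. As a standalone sketch (the function name is mine):

def resolve_final_namespace(namespace: str, effective_workspace: str | None):
    """Returns (workspace_tag, final_namespace); the original namespace is left untouched."""
    if effective_workspace:
        return effective_workspace, f"{effective_workspace}_{namespace}"
    # Empty workspace: the collection name stays as-is, "_" is only used in log prefixes
    return "_", namespace

# resolve_final_namespace("chunks", "tenant_a") -> ("tenant_a", "tenant_a_chunks")
# resolve_final_namespace("chunks", None)       -> ("_", "chunks")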
@ -699,7 +757,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
self._create_collection_if_not_exist()
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
if not data:
return
@ -730,7 +788,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["vector"] = embeddings[i]
results = self._client.upsert(collection_name=self.namespace, data=list_data)
results = self._client.upsert(
collection_name=self.final_namespace, data=list_data
)
return results
async def query(
@ -747,7 +807,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
output_fields = list(self.meta_fields)
results = self._client.search(
collection_name=self.namespace,
collection_name=self.final_namespace,
data=embedding,
limit=top_k,
output_fields=output_fields,
@ -780,21 +840,25 @@ class MilvusVectorDBStorage(BaseVectorStorage):
# Compute entity ID from name
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}"
f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {entity_id}"
)
# Delete the entity from Milvus collection
result = self._client.delete(
collection_name=self.namespace, pks=[entity_id]
collection_name=self.final_namespace, pks=[entity_id]
)
if result and result.get("delete_count", 0) > 0:
logger.debug(f"Successfully deleted entity {entity_name}")
logger.debug(
f"[{self.workspace}] Successfully deleted entity {entity_name}"
)
else:
logger.debug(f"Entity {entity_name} not found in storage")
logger.debug(
f"[{self.workspace}] Entity {entity_name} not found in storage"
)
except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}")
logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations associated with an entity
@ -811,31 +875,35 @@ class MilvusVectorDBStorage(BaseVectorStorage):
# Find all relations involving this entity
results = self._client.query(
collection_name=self.namespace, filter=expr, output_fields=["id"]
collection_name=self.final_namespace, filter=expr, output_fields=["id"]
)
if not results or len(results) == 0:
logger.debug(f"No relations found for entity {entity_name}")
logger.debug(
f"[{self.workspace}] No relations found for entity {entity_name}"
)
return
# Extract IDs of relations to delete
relation_ids = [item["id"] for item in results]
logger.debug(
f"Found {len(relation_ids)} relations for entity {entity_name}"
f"[{self.workspace}] Found {len(relation_ids)} relations for entity {entity_name}"
)
# Delete the relations
if relation_ids:
delete_result = self._client.delete(
collection_name=self.namespace, pks=relation_ids
collection_name=self.final_namespace, pks=relation_ids
)
logger.debug(
f"Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}"
f"[{self.workspace}] Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}"
)
except Exception as e:
logger.error(f"Error deleting relations for {entity_name}: {e}")
logger.error(
f"[{self.workspace}] Error deleting relations for {entity_name}: {e}"
)
async def delete(self, ids: list[str]) -> None:
"""Delete vectors with specified IDs
@ -848,17 +916,21 @@ class MilvusVectorDBStorage(BaseVectorStorage):
self._ensure_collection_loaded()
# Delete vectors by IDs
result = self._client.delete(collection_name=self.namespace, pks=ids)
result = self._client.delete(collection_name=self.final_namespace, pks=ids)
if result and result.get("delete_count", 0) > 0:
logger.debug(
f"Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}"
f"[{self.workspace}] Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}"
)
else:
logger.debug(f"No vectors were deleted from {self.namespace}")
logger.debug(
f"[{self.workspace}] No vectors were deleted from {self.namespace}"
)
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Error while deleting vectors from {self.namespace}: {e}"
)
async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get vector data by its ID
@ -878,7 +950,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
# Query Milvus for a specific ID
result = self._client.query(
collection_name=self.namespace,
collection_name=self.final_namespace,
filter=f'id == "{id}"',
output_fields=output_fields,
)
@ -888,7 +960,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
return result[0]
except Exception as e:
logger.error(f"Error retrieving vector data for ID {id}: {e}")
logger.error(
f"[{self.workspace}] Error retrieving vector data for ID {id}: {e}"
)
return None
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@ -916,14 +990,16 @@ class MilvusVectorDBStorage(BaseVectorStorage):
# Query Milvus with the filter
result = self._client.query(
collection_name=self.namespace,
collection_name=self.final_namespace,
filter=filter_expr,
output_fields=output_fields,
)
return result or []
except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
logger.error(
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
)
return []
async def drop(self) -> dict[str, str]:
@ -938,16 +1014,18 @@ class MilvusVectorDBStorage(BaseVectorStorage):
"""
try:
# Drop the collection and recreate it
if self._client.has_collection(self.namespace):
self._client.drop_collection(self.namespace)
if self._client.has_collection(self.final_namespace):
self._client.drop_collection(self.final_namespace)
# Recreate the collection
self._create_collection_if_not_exist()
logger.info(
f"Process {os.getpid()} drop Milvus collection {self.namespace}"
f"[{self.workspace}] Process {os.getpid()} drop Milvus collection {self.namespace}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping Milvus collection {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Error dropping Milvus collection {self.namespace}: {e}"
)
return {"status": "error", "message": str(e)}
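For orientation, the endswith checks scattered through MilvusVectorDBStorage all reduce to the same suffix-to-fields mapping, condensed here as a sketch (field lists abbreviated to names, not the full FieldSchema definitions):

def specific_fields_for(namespace: str) -> list[str]:
    """Condensed view of the per-namespace scalar fields selected by the suffix checks."""
    if namespace.endswith("entities"):
        return ["entity_name", "file_path"]
    if namespace.endswith("relationships"):
        return ["src_id", "tgt_id", "file_path"]
    if namespace.endswith("chunks"):
        return ["full_doc_id", "file_path"]
    return []

# specific_fields_for("tenant_a_relationships") -> ["src_id", "tgt_id", "file_path"]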

View file

@ -107,19 +107,30 @@ class MongoKVStorage(BaseKVStorage):
f"Using passed workspace parameter: '{effective_workspace}'"
)
# Build namespace with workspace prefix for data isolation
# Build final_namespace with workspace prefix for data isolation
# Keep original namespace unchanged for type detection logic
if effective_workspace:
self.namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'")
# When workspace is empty, keep the original namespace unchanged
self.final_namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(
f"Final namespace with workspace prefix: '{self.final_namespace}'"
)
else:
# When workspace is empty, final_namespace equals original namespace
self.final_namespace = self.namespace
self.workspace = "_"
logger.debug(
f"[{self.workspace}] Final namespace (no workspace): '{self.namespace}'"
)
self._collection_name = self.namespace
self._collection_name = self.final_namespace
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name)
logger.debug(f"Use MongoDB as KV {self._collection_name}")
logger.debug(
f"[{self.workspace}] Use MongoDB as KV {self._collection_name}"
)
async def finalize(self):
if self.db is not None:
@ -167,7 +178,7 @@ class MongoKVStorage(BaseKVStorage):
return result
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
if not data:
return
@ -227,10 +238,12 @@ class MongoKVStorage(BaseKVStorage):
try:
result = await self._data.delete_many({"_id": {"$in": ids}})
logger.info(
f"Deleted {result.deleted_count} documents from {self.namespace}"
f"[{self.workspace}] Deleted {result.deleted_count} documents from {self.namespace}"
)
except PyMongoError as e:
logger.error(f"Error deleting documents from {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Error deleting documents from {self.namespace}: {e}"
)
async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all documents in the collection.
@ -243,14 +256,16 @@ class MongoKVStorage(BaseKVStorage):
deleted_count = result.deleted_count
logger.info(
f"Dropped {deleted_count} documents from doc status {self._collection_name}"
f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}"
)
return {
"status": "success",
"message": f"{deleted_count} documents dropped",
}
except PyMongoError as e:
logger.error(f"Error dropping doc status {self._collection_name}: {e}")
logger.error(
f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}"
)
return {"status": "error", "message": str(e)}
@ -287,13 +302,20 @@ class MongoDocStatusStorage(DocStatusStorage):
f"Using passed workspace parameter: '{effective_workspace}'"
)
# Build namespace with workspace prefix for data isolation
# Build final_namespace with workspace prefix for data isolation
# Keep original namespace unchanged for type detection logic
if effective_workspace:
self.namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'")
# When workspace is empty, keep the original namespace unchanged
self.final_namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(
f"Final namespace with workspace prefix: '{self.final_namespace}'"
)
else:
# When workspace is empty, final_namespace equals original namespace
self.final_namespace = self.namespace
self.workspace = "_"
logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
self._collection_name = self.namespace
self._collection_name = self.final_namespace
async def initialize(self):
if self.db is None:
@ -306,7 +328,9 @@ class MongoDocStatusStorage(DocStatusStorage):
# Create pagination indexes for better query performance
await self.create_pagination_indexes_if_not_exists()
logger.debug(f"Use MongoDB as DocStatus {self._collection_name}")
logger.debug(
f"[{self.workspace}] Use MongoDB as DocStatus {self._collection_name}"
)
async def finalize(self):
if self.db is not None:
@ -327,7 +351,7 @@ class MongoDocStatusStorage(DocStatusStorage):
return data - existing_ids
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
if not data:
return
update_tasks: list[Any] = []
@ -376,7 +400,9 @@ class MongoDocStatusStorage(DocStatusStorage):
data["error_msg"] = None
processed_result[doc["_id"]] = DocProcessingStatus(**data)
except KeyError as e:
logger.error(f"Missing required field for document {doc['_id']}: {e}")
logger.error(
f"[{self.workspace}] Missing required field for document {doc['_id']}: {e}"
)
continue
return processed_result
@ -405,7 +431,9 @@ class MongoDocStatusStorage(DocStatusStorage):
data["error_msg"] = None
processed_result[doc["_id"]] = DocProcessingStatus(**data)
except KeyError as e:
logger.error(f"Missing required field for document {doc['_id']}: {e}")
logger.error(
f"[{self.workspace}] Missing required field for document {doc['_id']}: {e}"
)
continue
return processed_result
@ -424,14 +452,16 @@ class MongoDocStatusStorage(DocStatusStorage):
deleted_count = result.deleted_count
logger.info(
f"Dropped {deleted_count} documents from doc status {self._collection_name}"
f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}"
)
return {
"status": "success",
"message": f"{deleted_count} documents dropped",
}
except PyMongoError as e:
logger.error(f"Error dropping doc status {self._collection_name}: {e}")
logger.error(
f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}"
)
return {"status": "error", "message": str(e)}
async def delete(self, ids: list[str]) -> None:
@ -450,16 +480,16 @@ class MongoDocStatusStorage(DocStatusStorage):
if not track_id_index_exists:
await self._data.create_index("track_id")
logger.info(
f"Created track_id index for collection {self._collection_name}"
f"[{self.workspace}] Created track_id index for collection {self._collection_name}"
)
else:
logger.debug(
f"track_id index already exists for collection {self._collection_name}"
f"[{self.workspace}] track_id index already exists for collection {self._collection_name}"
)
except PyMongoError as e:
logger.error(
f"Error creating track_id index for {self._collection_name}: {e}"
f"[{self.workspace}] Error creating track_id index for {self._collection_name}: {e}"
)
async def create_pagination_indexes_if_not_exists(self):
@ -492,16 +522,16 @@ class MongoDocStatusStorage(DocStatusStorage):
if index_name not in existing_index_names:
await self._data.create_index(index_info["keys"], name=index_name)
logger.info(
f"Created pagination index '{index_name}' for collection {self._collection_name}"
f"[{self.workspace}] Created pagination index '{index_name}' for collection {self._collection_name}"
)
else:
logger.debug(
f"Pagination index '{index_name}' already exists for collection {self._collection_name}"
f"[{self.workspace}] Pagination index '{index_name}' already exists for collection {self._collection_name}"
)
except PyMongoError as e:
logger.error(
f"Error creating pagination indexes for {self._collection_name}: {e}"
f"[{self.workspace}] Error creating pagination indexes for {self._collection_name}: {e}"
)
async def get_docs_paginated(
@ -586,7 +616,9 @@ class MongoDocStatusStorage(DocStatusStorage):
doc_status = DocProcessingStatus(**data)
documents.append((doc_id, doc_status))
except KeyError as e:
logger.error(f"Missing required field for document {doc['_id']}: {e}")
logger.error(
f"[{self.workspace}] Missing required field for document {doc['_id']}: {e}"
)
continue
return documents, total_count
@ -650,13 +682,20 @@ class MongoGraphStorage(BaseGraphStorage):
f"Using passed workspace parameter: '{effective_workspace}'"
)
# Build namespace with workspace prefix for data isolation
# Build final_namespace with workspace prefix for data isolation
# Keep original namespace unchanged for type detection logic
if effective_workspace:
self.namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'")
# When workspace is empty, keep the original namespace unchanged
self.final_namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(
f"Final namespace with workspace prefix: '{self.final_namespace}'"
)
else:
# When workspace is empty, final_namespace equals original namespace
self.final_namespace = self.namespace
self.workspace = "_"
logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
self._collection_name = self.namespace
self._collection_name = self.final_namespace
self._edge_collection_name = f"{self._collection_name}_edges"
async def initialize(self):
@ -668,7 +707,9 @@ class MongoGraphStorage(BaseGraphStorage):
self.edge_collection = await get_or_create_collection(
self.db, self._edge_collection_name
)
logger.debug(f"Use MongoDB as KG {self._collection_name}")
logger.debug(
f"[{self.workspace}] Use MongoDB as KG {self._collection_name}"
)
async def finalize(self):
if self.db is not None:
@ -1248,7 +1289,9 @@ class MongoGraphStorage(BaseGraphStorage):
# Verify if starting node exists
start_node = await self.collection.find_one({"_id": node_label})
if not start_node:
logger.warning(f"Starting node with label {node_label} does not exist!")
logger.warning(
f"[{self.workspace}] Starting node with label {node_label} does not exist!"
)
return result
seen_nodes.add(node_label)
@ -1407,14 +1450,14 @@ class MongoGraphStorage(BaseGraphStorage):
duration = time.perf_counter() - start
logger.info(
f"Subgraph query successful in {duration:.4f} seconds | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
f"[{self.workspace}] Subgraph query successful in {duration:.4f} seconds | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
)
except PyMongoError as e:
# Handle memory limit errors specifically
if "memory limit" in str(e).lower() or "sort exceeded" in str(e).lower():
logger.warning(
f"MongoDB memory limit exceeded, falling back to simple query: {str(e)}"
f"[{self.workspace}] MongoDB memory limit exceeded, falling back to simple query: {str(e)}"
)
# Fallback to a simple query without complex aggregation
try:
@ -1425,12 +1468,14 @@ class MongoGraphStorage(BaseGraphStorage):
)
result.is_truncated = True
logger.info(
f"Fallback query completed | Node count: {len(result.nodes)}"
f"[{self.workspace}] Fallback query completed | Node count: {len(result.nodes)}"
)
except PyMongoError as fallback_error:
logger.error(f"Fallback query also failed: {str(fallback_error)}")
logger.error(
f"[{self.workspace}] Fallback query also failed: {str(fallback_error)}"
)
else:
logger.error(f"MongoDB query failed: {str(e)}")
logger.error(f"[{self.workspace}] MongoDB query failed: {str(e)}")
return result
@ -1444,7 +1489,7 @@ class MongoGraphStorage(BaseGraphStorage):
Args:
nodes: List of node IDs to be deleted
"""
logger.info(f"Deleting {len(nodes)} nodes")
logger.info(f"[{self.workspace}] Deleting {len(nodes)} nodes")
if not nodes:
return
@ -1461,7 +1506,7 @@ class MongoGraphStorage(BaseGraphStorage):
# 2. Delete the node documents
await self.collection.delete_many({"_id": {"$in": nodes}})
logger.debug(f"Successfully deleted nodes: {nodes}")
logger.debug(f"[{self.workspace}] Successfully deleted nodes: {nodes}")
async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
"""Delete multiple edges
@ -1469,7 +1514,7 @@ class MongoGraphStorage(BaseGraphStorage):
Args:
edges: List of edges to be deleted, each edge is a (source, target) tuple
"""
logger.info(f"Deleting {len(edges)} edges")
logger.info(f"[{self.workspace}] Deleting {len(edges)} edges")
if not edges:
return
@ -1484,7 +1529,7 @@ class MongoGraphStorage(BaseGraphStorage):
await self.edge_collection.delete_many({"$or": all_edge_pairs})
logger.debug(f"Successfully deleted edges: {edges}")
logger.debug(f"[{self.workspace}] Successfully deleted edges: {edges}")
async def get_all_nodes(self) -> list[dict]:
"""Get all nodes in the graph.
@ -1527,13 +1572,13 @@ class MongoGraphStorage(BaseGraphStorage):
deleted_count = result.deleted_count
logger.info(
f"Dropped {deleted_count} documents from graph {self._collection_name}"
f"[{self.workspace}] Dropped {deleted_count} documents from graph {self._collection_name}"
)
result = await self.edge_collection.delete_many({})
edge_count = result.deleted_count
logger.info(
f"Dropped {edge_count} edges from graph {self._edge_collection_name}"
f"[{self.workspace}] Dropped {edge_count} edges from graph {self._edge_collection_name}"
)
return {
@ -1541,7 +1586,9 @@ class MongoGraphStorage(BaseGraphStorage):
"message": f"{deleted_count} documents and {edge_count} edges dropped",
}
except PyMongoError as e:
logger.error(f"Error dropping graph {self._collection_name}: {e}")
logger.error(
f"[{self.workspace}] Error dropping graph {self._collection_name}: {e}"
)
return {"status": "error", "message": str(e)}
@ -1582,16 +1629,23 @@ class MongoVectorDBStorage(BaseVectorStorage):
f"Using passed workspace parameter: '{effective_workspace}'"
)
# Build namespace with workspace prefix for data isolation
# Build final_namespace with workspace prefix for data isolation
# Keep original namespace unchanged for type detection logic
if effective_workspace:
self.namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'")
# When workspace is empty, keep the original namespace unchanged
self.final_namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(
f"Final namespace with workspace prefix: '{self.final_namespace}'"
)
else:
# When workspace is empty, final_namespace equals original namespace
self.final_namespace = self.namespace
self.workspace = "_"
logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
# Set index name based on workspace for backward compatibility
if effective_workspace:
# Use collection-specific index name for workspaced collections to avoid conflicts
self._index_name = f"vector_knn_index_{self.namespace}"
self._index_name = f"vector_knn_index_{self.final_namespace}"
else:
# Keep original index name for backward compatibility with existing deployments
self._index_name = "vector_knn_index"
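Editor's note: the index name follows the (workspaced) collection name only when a workspace is set; deployments without a workspace keep the original "vector_knn_index" so existing indexes stay valid. A one-line sketch of that rule, as a hypothetical helper:

def vector_index_name(final_namespace: str, workspaced: bool) -> str:
    # Workspaced collections need distinct index names to avoid conflicts across collections.
    return f"vector_knn_index_{final_namespace}" if workspaced else "vector_knn_index"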
@ -1603,7 +1657,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
self._collection_name = self.namespace
self._collection_name = self.final_namespace
self._max_batch_size = self.global_config["embedding_batch_num"]
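Editor's note: both the cosine threshold and the embedding batch size are read from global_config; the threshold is mandatory and its absence raises an error above. A sketch of the expected configuration shape — only the key names come from the surrounding code, the concrete values are illustrative assumptions:

global_config = {
    "embedding_batch_num": 32,
    "vector_db_storage_cls_kwargs": {
        # Query results scoring below this cosine similarity are filtered out.
        "cosine_better_than_threshold": 0.2,
    },
}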
async def initialize(self):
@ -1614,7 +1668,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
# Ensure vector index exists
await self.create_vector_index_if_not_exists()
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
logger.debug(
f"[{self.workspace}] Use MongoDB as VDB {self._collection_name}"
)
async def finalize(self):
if self.db is not None:
@ -1629,7 +1685,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
indexes = await indexes_cursor.to_list(length=None)
for index in indexes:
if index["name"] == self._index_name:
logger.info(f"vector index {self._index_name} already exist")
logger.info(
f"[{self.workspace}] vector index {self._index_name} already exist"
)
return
search_index_model = SearchIndexModel(
@ -1648,17 +1706,19 @@ class MongoVectorDBStorage(BaseVectorStorage):
)
await self._data.create_search_index(search_index_model)
logger.info(f"Vector index {self._index_name} created successfully.")
logger.info(
f"[{self.workspace}] Vector index {self._index_name} created successfully."
)
except PyMongoError as e:
error_msg = f"Error creating vector index {self._index_name}: {e}"
error_msg = f"[{self.workspace}] Error creating vector index {self._index_name}: {e}"
logger.error(error_msg)
raise SystemExit(
f"Failed to create MongoDB vector index. Program cannot continue. {error_msg}"
)
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
if not data:
return
@ -1747,7 +1807,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
Args:
ids: List of vector IDs to be deleted
"""
logger.debug(f"Deleting {len(ids)} vectors from {self.namespace}")
logger.debug(
f"[{self.workspace}] Deleting {len(ids)} vectors from {self.namespace}"
)
if not ids:
return
@ -1758,11 +1820,11 @@ class MongoVectorDBStorage(BaseVectorStorage):
try:
result = await self._data.delete_many({"_id": {"$in": ids}})
logger.debug(
f"Successfully deleted {result.deleted_count} vectors from {self.namespace}"
f"[{self.workspace}] Successfully deleted {result.deleted_count} vectors from {self.namespace}"
)
except PyMongoError as e:
logger.error(
f"Error while deleting vectors from {self.namespace}: {str(e)}"
f"[{self.workspace}] Error while deleting vectors from {self.namespace}: {str(e)}"
)
async def delete_entity(self, entity_name: str) -> None:
@ -1774,16 +1836,22 @@ class MongoVectorDBStorage(BaseVectorStorage):
try:
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}"
f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {entity_id}"
)
result = await self._data.delete_one({"_id": entity_id})
if result.deleted_count > 0:
logger.debug(f"Successfully deleted entity {entity_name}")
logger.debug(
f"[{self.workspace}] Successfully deleted entity {entity_name}"
)
else:
logger.debug(f"Entity {entity_name} not found in storage")
logger.debug(
f"[{self.workspace}] Entity {entity_name} not found in storage"
)
except PyMongoError as e:
logger.error(f"Error deleting entity {entity_name}: {str(e)}")
logger.error(
f"[{self.workspace}] Error deleting entity {entity_name}: {str(e)}"
)
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations associated with an entity
@ -1799,23 +1867,31 @@ class MongoVectorDBStorage(BaseVectorStorage):
relations = await relations_cursor.to_list(length=None)
if not relations:
logger.debug(f"No relations found for entity {entity_name}")
logger.debug(
f"[{self.workspace}] No relations found for entity {entity_name}"
)
return
# Extract IDs of relations to delete
relation_ids = [relation["_id"] for relation in relations]
logger.debug(
f"Found {len(relation_ids)} relations for entity {entity_name}"
f"[{self.workspace}] Found {len(relation_ids)} relations for entity {entity_name}"
)
# Delete the relations
result = await self._data.delete_many({"_id": {"$in": relation_ids}})
logger.debug(f"Deleted {result.deleted_count} relations for {entity_name}")
logger.debug(
f"[{self.workspace}] Deleted {result.deleted_count} relations for {entity_name}"
)
except PyMongoError as e:
logger.error(f"Error deleting relations for {entity_name}: {str(e)}")
logger.error(
f"[{self.workspace}] Error deleting relations for {entity_name}: {str(e)}"
)
except PyMongoError as e:
logger.error(f"Error searching by prefix in {self.namespace}: {str(e)}")
logger.error(
f"[{self.workspace}] Error searching by prefix in {self.namespace}: {str(e)}"
)
return []
async def get_by_id(self, id: str) -> dict[str, Any] | None:
@ -1838,7 +1914,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
return result_dict
return None
except Exception as e:
logger.error(f"Error retrieving vector data for ID {id}: {e}")
logger.error(
f"[{self.workspace}] Error retrieving vector data for ID {id}: {e}"
)
return None
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@ -1868,7 +1946,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
return formatted_results
except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
logger.error(
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
)
return []
async def drop(self) -> dict[str, str]:
@ -1886,14 +1966,16 @@ class MongoVectorDBStorage(BaseVectorStorage):
await self.create_vector_index_if_not_exists()
logger.info(
f"Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index"
f"[{self.workspace}] Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index"
)
return {
"status": "success",
"message": f"{deleted_count} documents dropped and vector index recreated",
}
except PyMongoError as e:
logger.error(f"Error dropping vector storage {self._collection_name}: {e}")
logger.error(
f"[{self.workspace}] Error dropping vector storage {self._collection_name}: {e}"
)
return {"status": "error", "message": str(e)}


@ -48,23 +48,26 @@ logging.getLogger("neo4j").setLevel(logging.ERROR)
@dataclass
class Neo4JStorage(BaseGraphStorage):
def __init__(self, namespace, global_config, embedding_func, workspace=None):
# Check NEO4J_WORKSPACE environment variable and override workspace if set
# Read env and override the arg if present
neo4j_workspace = os.environ.get("NEO4J_WORKSPACE")
if neo4j_workspace and neo4j_workspace.strip():
workspace = neo4j_workspace
# Default to 'base' when both arg and env are empty
if not workspace or not str(workspace).strip():
workspace = "base"
super().__init__(
namespace=namespace,
workspace=workspace or "",
workspace=workspace,
global_config=global_config,
embedding_func=embedding_func,
)
self._driver = None
def _get_workspace_label(self) -> str:
"""Get workspace label, return 'base' for compatibility when workspace is empty"""
workspace = getattr(self, "workspace", None)
return workspace if workspace else "base"
"""Return workspace label (guaranteed non-empty during initialization)"""
return self.workspace
async def initialize(self):
URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
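Editor's note: workspace resolution for Neo4j now has a fixed precedence — the NEO4J_WORKSPACE environment variable wins over the constructor argument, and anything empty falls back to the "base" label that also serves as the node label. A minimal sketch of that precedence (hypothetical helper mirroring the __init__ logic above):

import os

def resolve_neo4j_workspace(workspace: str | None) -> str:
    env_ws = os.environ.get("NEO4J_WORKSPACE")
    if env_ws and env_ws.strip():
        return env_ws                      # env var overrides the argument
    if workspace and str(workspace).strip():
        return workspace                   # explicit argument
    return "base"                          # default label for empty workspace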
@ -117,6 +120,7 @@ class Neo4JStorage(BaseGraphStorage):
DATABASE = os.environ.get(
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
)
"""The default value approach for the DATABASE is only intended to maintain compatibility with legacy practices."""
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
URI,
@ -140,20 +144,26 @@ class Neo4JStorage(BaseGraphStorage):
try:
result = await session.run("MATCH (n) RETURN n LIMIT 0")
await result.consume() # Ensure result is consumed
logger.info(f"Connected to {database} at {URI}")
logger.info(
f"[{self.workspace}] Connected to {database} at {URI}"
)
connected = True
except neo4jExceptions.ServiceUnavailable as e:
logger.error(
f"{database} at {URI} is not available".capitalize()
f"[{self.workspace}] "
+ f"{database} at {URI} is not available".capitalize()
)
raise e
except neo4jExceptions.AuthError as e:
logger.error(f"Authentication failed for {database} at {URI}")
logger.error(
f"[{self.workspace}] Authentication failed for {database} at {URI}"
)
raise e
except neo4jExceptions.ClientError as e:
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
logger.info(
f"{database} at {URI} not found. Try to create specified database.".capitalize()
f"[{self.workspace}] "
+ f"{database} at {URI} not found. Try to create specified database.".capitalize()
)
try:
async with self._driver.session() as session:
@ -161,7 +171,10 @@ class Neo4JStorage(BaseGraphStorage):
f"CREATE DATABASE `{database}` IF NOT EXISTS"
)
await result.consume() # Ensure result is consumed
logger.info(f"{database} at {URI} created".capitalize())
logger.info(
f"[{self.workspace}] "
+ f"{database} at {URI} created".capitalize()
)
connected = True
except (
neo4jExceptions.ClientError,
@ -173,10 +186,12 @@ class Neo4JStorage(BaseGraphStorage):
) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
if database is not None:
logger.warning(
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
f"[{self.workspace}] This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
)
if database is None:
logger.error(f"Failed to create {database} at {URI}")
logger.error(
f"[{self.workspace}] Failed to create {database} at {URI}"
)
raise e
if connected:
@ -204,7 +219,7 @@ class Neo4JStorage(BaseGraphStorage):
)
await result.consume()
logger.info(
f"Created index for {workspace_label} nodes on entity_id in {database}"
f"[{self.workspace}] Created index for {workspace_label} nodes on entity_id in {database}"
)
except Exception:
# Fallback if db.indexes() is not supported in this Neo4j version
@ -213,7 +228,9 @@ class Neo4JStorage(BaseGraphStorage):
)
await result.consume()
except Exception as e:
logger.warning(f"Failed to create index: {str(e)}")
logger.warning(
f"[{self.workspace}] Failed to create index: {str(e)}"
)
break
async def finalize(self):
@ -255,7 +272,9 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure result is fully consumed
return single_result["node_exists"]
except Exception as e:
logger.error(f"Error checking node existence for {node_id}: {str(e)}")
logger.error(
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
)
await result.consume() # Ensure results are consumed even on error
raise
@ -293,7 +312,7 @@ class Neo4JStorage(BaseGraphStorage):
return single_result["edgeExists"]
except Exception as e:
logger.error(
f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
)
await result.consume() # Ensure results are consumed even on error
raise
@ -328,7 +347,7 @@ class Neo4JStorage(BaseGraphStorage):
if len(records) > 1:
logger.warning(
f"Multiple nodes found with label '{node_id}'. Using first node."
f"[{self.workspace}] Multiple nodes found with label '{node_id}'. Using first node."
)
if records:
node = records[0]["n"]
@ -346,7 +365,9 @@ class Neo4JStorage(BaseGraphStorage):
finally:
await result.consume() # Ensure result is fully consumed
except Exception as e:
logger.error(f"Error getting node for {node_id}: {str(e)}")
logger.error(
f"[{self.workspace}] Error getting node for {node_id}: {str(e)}"
)
raise
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
@ -415,18 +436,22 @@ class Neo4JStorage(BaseGraphStorage):
record = await result.single()
if not record:
logger.warning(f"No node found with label '{node_id}'")
logger.warning(
f"[{self.workspace}] No node found with label '{node_id}'"
)
return 0
degree = record["degree"]
# logger.debug(
# f"Neo4j query node degree for {node_id} return: {degree}"
# f"[{self.workspace}] Neo4j query node degree for {node_id} return: {degree}"
# )
return degree
finally:
await result.consume() # Ensure result is fully consumed
except Exception as e:
logger.error(f"Error getting node degree for {node_id}: {str(e)}")
logger.error(
f"[{self.workspace}] Error getting node degree for {node_id}: {str(e)}"
)
raise
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
@ -459,10 +484,12 @@ class Neo4JStorage(BaseGraphStorage):
# For any node_id that did not return a record, set degree to 0.
for nid in node_ids:
if nid not in degrees:
logger.warning(f"No node found with label '{nid}'")
logger.warning(
f"[{self.workspace}] No node found with label '{nid}'"
)
degrees[nid] = 0
# logger.debug(f"Neo4j batch node degree query returned: {degrees}")
# logger.debug(f"[{self.workspace}] Neo4j batch node degree query returned: {degrees}")
return degrees
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
@ -546,7 +573,7 @@ class Neo4JStorage(BaseGraphStorage):
if len(records) > 1:
logger.warning(
f"Multiple edges found between '{source_node_id}' and '{target_node_id}'. Using first edge."
f"[{self.workspace}] Multiple edges found between '{source_node_id}' and '{target_node_id}'. Using first edge."
)
if records:
try:
@ -563,7 +590,7 @@ class Neo4JStorage(BaseGraphStorage):
if key not in edge_result:
edge_result[key] = default_value
logger.warning(
f"Edge between {source_node_id} and {target_node_id} "
f"[{self.workspace}] Edge between {source_node_id} and {target_node_id} "
f"missing {key}, using default: {default_value}"
)
@ -573,7 +600,7 @@ class Neo4JStorage(BaseGraphStorage):
return edge_result
except (KeyError, TypeError, ValueError) as e:
logger.error(
f"Error processing edge properties between {source_node_id} "
f"[{self.workspace}] Error processing edge properties between {source_node_id} "
f"and {target_node_id}: {str(e)}"
)
# Return default edge properties on error
@ -594,7 +621,7 @@ class Neo4JStorage(BaseGraphStorage):
except Exception as e:
logger.error(
f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
f"[{self.workspace}] Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
)
raise
@ -701,12 +728,14 @@ class Neo4JStorage(BaseGraphStorage):
return edges
except Exception as e:
logger.error(
f"Error getting edges for node {source_node_id}: {str(e)}"
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
)
await results.consume() # Ensure results are consumed even on error
raise
except Exception as e:
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
logger.error(
f"[{self.workspace}] Error in get_node_edges for {source_node_id}: {str(e)}"
)
raise
async def get_nodes_edges_batch(
@ -856,7 +885,7 @@ class Neo4JStorage(BaseGraphStorage):
await session.execute_write(execute_upsert)
except Exception as e:
logger.error(f"Error during upsert: {str(e)}")
logger.error(f"[{self.workspace}] Error during upsert: {str(e)}")
raise
@retry(
@ -917,7 +946,7 @@ class Neo4JStorage(BaseGraphStorage):
await session.execute_write(execute_upsert)
except Exception as e:
logger.error(f"Error during edge upsert: {str(e)}")
logger.error(f"[{self.workspace}] Error during edge upsert: {str(e)}")
raise
async def get_knowledge_graph(
@ -967,7 +996,7 @@ class Neo4JStorage(BaseGraphStorage):
if count_record and count_record["total"] > max_nodes:
result.is_truncated = True
logger.info(
f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
f"[{self.workspace}] Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
)
finally:
if count_result:
@ -1034,7 +1063,9 @@ class Neo4JStorage(BaseGraphStorage):
# If no record found, return empty KnowledgeGraph
if not full_record:
logger.debug(f"No nodes found for entity_id: {node_label}")
logger.debug(
f"[{self.workspace}] No nodes found for entity_id: {node_label}"
)
return result
# If record found, check node count
@ -1043,14 +1074,14 @@ class Neo4JStorage(BaseGraphStorage):
if total_nodes <= max_nodes:
# If node count is within limit, use full result directly
logger.debug(
f"Using full result with {total_nodes} nodes (no truncation needed)"
f"[{self.workspace}] Using full result with {total_nodes} nodes (no truncation needed)"
)
record = full_record
else:
# If node count exceeds limit, set truncated flag and run limited query
result.is_truncated = True
logger.info(
f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
f"[{self.workspace}] Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
)
# Run limited query
@ -1122,19 +1153,19 @@ class Neo4JStorage(BaseGraphStorage):
seen_edges.add(edge_id)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
f"[{self.workspace}] Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
except neo4jExceptions.ClientError as e:
logger.warning(f"APOC plugin error: {str(e)}")
logger.warning(f"[{self.workspace}] APOC plugin error: {str(e)}")
if node_label != "*":
logger.warning(
"Neo4j: falling back to basic Cypher recursive search..."
f"[{self.workspace}] Neo4j: falling back to basic Cypher recursive search..."
)
return await self._robust_fallback(node_label, max_depth, max_nodes)
else:
logger.warning(
"Neo4j: APOC plugin error with wildcard query, returning empty result"
f"[{self.workspace}] Neo4j: APOC plugin error with wildcard query, returning empty result"
)
return result
@ -1193,7 +1224,7 @@ class Neo4JStorage(BaseGraphStorage):
if current_depth > max_depth:
logger.debug(
f"Skipping node at depth {current_depth} (max_depth: {max_depth})"
f"[{self.workspace}] Skipping node at depth {current_depth} (max_depth: {max_depth})"
)
continue
@ -1210,7 +1241,7 @@ class Neo4JStorage(BaseGraphStorage):
if len(visited_nodes) >= max_nodes:
result.is_truncated = True
logger.info(
f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
f"[{self.workspace}] Graph truncated: breadth-first search limited to: {max_nodes} nodes"
)
break
@ -1281,20 +1312,20 @@ class Neo4JStorage(BaseGraphStorage):
# At max depth, we've already added the edge but we don't add the node
# This prevents adding nodes beyond max_depth to the result
logger.debug(
f"Node {target_id} beyond max depth {max_depth}, edge added but node not included"
f"[{self.workspace}] Node {target_id} beyond max depth {max_depth}, edge added but node not included"
)
else:
# If target node already exists in result, we don't need to add it again
logger.debug(
f"Node {target_id} already visited, edge added but node not queued"
f"[{self.workspace}] Node {target_id} already visited, edge added but node not queued"
)
else:
logger.warning(
f"Skipping edge {edge_id} due to missing entity_id on target node"
f"[{self.workspace}] Skipping edge {edge_id} due to missing entity_id on target node"
)
logger.info(
f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
f"[{self.workspace}] BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
return result
@ -1358,14 +1389,14 @@ class Neo4JStorage(BaseGraphStorage):
DETACH DELETE n
"""
result = await tx.run(query, entity_id=node_id)
logger.debug(f"Deleted node with label '{node_id}'")
logger.debug(f"[{self.workspace}] Deleted node with label '{node_id}'")
await result.consume() # Ensure result is fully consumed
try:
async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_delete)
except Exception as e:
logger.error(f"Error during node deletion: {str(e)}")
logger.error(f"[{self.workspace}] Error during node deletion: {str(e)}")
raise
@retry(
@ -1424,14 +1455,16 @@ class Neo4JStorage(BaseGraphStorage):
result = await tx.run(
query, source_entity_id=source, target_entity_id=target
)
logger.debug(f"Deleted edge from '{source}' to '{target}'")
logger.debug(
f"[{self.workspace}] Deleted edge from '{source}' to '{target}'"
)
await result.consume() # Ensure result is fully consumed
try:
async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_delete_edge)
except Exception as e:
logger.error(f"Error during edge deletion: {str(e)}")
logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}")
raise
async def get_all_nodes(self) -> list[dict]:
@ -1501,15 +1534,15 @@ class Neo4JStorage(BaseGraphStorage):
result = await session.run(query)
await result.consume() # Ensure result is fully consumed
logger.info(
f"Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}"
)
# logger.debug(
# f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}"
# )
return {
"status": "success",
"message": f"workspace '{workspace_label}' data dropped",
}
except Exception as e:
logger.error(
f"Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}"
f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}"
)
return {"status": "error", "message": str(e)}


@ -88,11 +88,18 @@ class QdrantVectorDBStorage(BaseVectorStorage):
f"Using passed workspace parameter: '{effective_workspace}'"
)
# Build namespace with workspace prefix for data isolation
# Build final_namespace with workspace prefix for data isolation
# Keep original namespace unchanged for type detection logic
if effective_workspace:
self.namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'")
# When workspace is empty, keep the original namespace unchanged
self.final_namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(
f"Final namespace with workspace prefix: '{self.final_namespace}'"
)
else:
# When workspace is empty, final_namespace equals original namespace
self.final_namespace = self.namespace
self.workspace = "_"
logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")
@ -113,14 +120,14 @@ class QdrantVectorDBStorage(BaseVectorStorage):
self._max_batch_size = self.global_config["embedding_batch_num"]
QdrantVectorDBStorage.create_collection_if_not_exist(
self._client,
self.namespace,
self.final_namespace,
vectors_config=models.VectorParams(
size=self.embedding_func.embedding_dim, distance=models.Distance.COSINE
),
)
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
if not data:
return
@ -158,7 +165,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
)
results = self._client.upsert(
collection_name=self.namespace, points=list_points, wait=True
collection_name=self.final_namespace, points=list_points, wait=True
)
return results
@ -169,14 +176,14 @@ class QdrantVectorDBStorage(BaseVectorStorage):
[query], _priority=5
) # higher priority for query
results = self._client.search(
collection_name=self.namespace,
collection_name=self.final_namespace,
query_vector=embedding[0],
limit=top_k,
with_payload=True,
score_threshold=self.cosine_better_than_threshold,
)
logger.debug(f"query result: {results}")
logger.debug(f"[{self.workspace}] query result: {results}")
return [
{
@ -202,17 +209,19 @@ class QdrantVectorDBStorage(BaseVectorStorage):
qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]
# Delete points from the collection
self._client.delete(
collection_name=self.namespace,
collection_name=self.final_namespace,
points_selector=models.PointIdsList(
points=qdrant_ids,
),
wait=True,
)
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
f"[{self.workspace}] Successfully deleted {len(ids)} vectors from {self.namespace}"
)
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Error while deleting vectors from {self.namespace}: {e}"
)
async def delete_entity(self, entity_name: str) -> None:
"""Delete an entity by name
@ -224,20 +233,22 @@ class QdrantVectorDBStorage(BaseVectorStorage):
# Generate the entity ID
entity_id = compute_mdhash_id_for_qdrant(entity_name, prefix="ent-")
logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}"
f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {entity_id}"
)
# Delete the entity point from the collection
self._client.delete(
collection_name=self.namespace,
collection_name=self.final_namespace,
points_selector=models.PointIdsList(
points=[entity_id],
),
wait=True,
)
logger.debug(f"Successfully deleted entity {entity_name}")
logger.debug(
f"[{self.workspace}] Successfully deleted entity {entity_name}"
)
except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}")
logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations associated with an entity
@ -248,7 +259,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
try:
# Find relations where the entity is either source or target
results = self._client.scroll(
collection_name=self.namespace,
collection_name=self.final_namespace,
scroll_filter=models.Filter(
should=[
models.FieldCondition(
@ -270,19 +281,23 @@ class QdrantVectorDBStorage(BaseVectorStorage):
if ids_to_delete:
# Delete the relations
self._client.delete(
collection_name=self.namespace,
collection_name=self.final_namespace,
points_selector=models.PointIdsList(
points=ids_to_delete,
),
wait=True,
)
logger.debug(
f"Deleted {len(ids_to_delete)} relations for {entity_name}"
f"[{self.workspace}] Deleted {len(ids_to_delete)} relations for {entity_name}"
)
else:
logger.debug(f"No relations found for entity {entity_name}")
logger.debug(
f"[{self.workspace}] No relations found for entity {entity_name}"
)
except Exception as e:
logger.error(f"Error deleting relations for {entity_name}: {e}")
logger.error(
f"[{self.workspace}] Error deleting relations for {entity_name}: {e}"
)
async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get vector data by its ID
@ -299,7 +314,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
# Retrieve the point by ID
result = self._client.retrieve(
collection_name=self.namespace,
collection_name=self.final_namespace,
ids=[qdrant_id],
with_payload=True,
)
@ -314,7 +329,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
return payload
except Exception as e:
logger.error(f"Error retrieving vector data for ID {id}: {e}")
logger.error(
f"[{self.workspace}] Error retrieving vector data for ID {id}: {e}"
)
return None
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@ -335,7 +352,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
# Retrieve the points by IDs
results = self._client.retrieve(
collection_name=self.namespace,
collection_name=self.final_namespace,
ids=qdrant_ids,
with_payload=True,
)
@ -350,7 +367,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
return payloads
except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
logger.error(
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
)
return []
async def drop(self) -> dict[str, str]:
@ -365,13 +384,13 @@ class QdrantVectorDBStorage(BaseVectorStorage):
"""
try:
# Delete the collection and recreate it
if self._client.collection_exists(self.namespace):
self._client.delete_collection(self.namespace)
if self._client.collection_exists(self.final_namespace):
self._client.delete_collection(self.final_namespace)
# Recreate the collection
QdrantVectorDBStorage.create_collection_if_not_exist(
self._client,
self.namespace,
self.final_namespace,
vectors_config=models.VectorParams(
size=self.embedding_func.embedding_dim,
distance=models.Distance.COSINE,
@ -379,9 +398,11 @@ class QdrantVectorDBStorage(BaseVectorStorage):
)
logger.info(
f"Process {os.getpid()} drop Qdrant collection {self.namespace}"
f"[{self.workspace}] Process {os.getpid()} drop Qdrant collection {self.namespace}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping Qdrant collection {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Error dropping Qdrant collection {self.namespace}: {e}"
)
return {"status": "error", "message": str(e)}


@ -141,11 +141,18 @@ class RedisKVStorage(BaseKVStorage):
f"Using passed workspace parameter: '{effective_workspace}'"
)
# Build namespace with workspace prefix for data isolation
# Build final_namespace with workspace prefix for data isolation
# Keep original namespace unchanged for type detection logic
if effective_workspace:
self.namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'")
# When workspace is empty, keep the original namespace unchanged
self.final_namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(
f"Final namespace with workspace prefix: '{self.final_namespace}'"
)
else:
# When workspace is empty, final_namespace equals original namespace
self.final_namespace = self.namespace
self.workspace = "_"
logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
self._redis_url = os.environ.get(
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
@ -159,13 +166,15 @@ class RedisKVStorage(BaseKVStorage):
self._pool = RedisConnectionManager.get_pool(self._redis_url)
self._redis = Redis(connection_pool=self._pool)
logger.info(
f"Initialized Redis KV storage for {self.namespace} using shared connection pool"
f"[{self.workspace}] Initialized Redis KV storage for {self.namespace} using shared connection pool"
)
except Exception as e:
# Clean up on initialization failure
if self._redis_url:
RedisConnectionManager.release_pool(self._redis_url)
logger.error(f"Failed to initialize Redis KV storage: {e}")
logger.error(
f"[{self.workspace}] Failed to initialize Redis KV storage: {e}"
)
raise
async def initialize(self):
@ -177,10 +186,12 @@ class RedisKVStorage(BaseKVStorage):
try:
async with self._get_redis_connection() as redis:
await redis.ping()
logger.info(f"Connected to Redis for namespace {self.namespace}")
logger.info(
f"[{self.workspace}] Connected to Redis for namespace {self.namespace}"
)
self._initialized = True
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
logger.error(f"[{self.workspace}] Failed to connect to Redis: {e}")
# Clean up on connection failure
await self.close()
raise
@ -190,7 +201,9 @@ class RedisKVStorage(BaseKVStorage):
try:
await self._migrate_legacy_cache_structure()
except Exception as e:
logger.error(f"Failed to migrate legacy cache structure: {e}")
logger.error(
f"[{self.workspace}] Failed to migrate legacy cache structure: {e}"
)
# Don't fail initialization for migration errors, just log them
@asynccontextmanager
@ -203,14 +216,18 @@ class RedisKVStorage(BaseKVStorage):
# Use the existing Redis instance with shared pool
yield self._redis
except ConnectionError as e:
logger.error(f"Redis connection error in {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Redis connection error in {self.namespace}: {e}"
)
raise
except RedisError as e:
logger.error(f"Redis operation error in {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Redis operation error in {self.namespace}: {e}"
)
raise
except Exception as e:
logger.error(
f"Unexpected error in Redis operation for {self.namespace}: {e}"
f"[{self.workspace}] Unexpected error in Redis operation for {self.namespace}: {e}"
)
raise
@ -219,9 +236,11 @@ class RedisKVStorage(BaseKVStorage):
if hasattr(self, "_redis") and self._redis:
try:
await self._redis.close()
logger.debug(f"Closed Redis connection for {self.namespace}")
logger.debug(
f"[{self.workspace}] Closed Redis connection for {self.namespace}"
)
except Exception as e:
logger.error(f"Error closing Redis connection: {e}")
logger.error(f"[{self.workspace}] Error closing Redis connection: {e}")
finally:
self._redis = None
@ -230,7 +249,7 @@ class RedisKVStorage(BaseKVStorage):
RedisConnectionManager.release_pool(self._redis_url)
self._pool = None
logger.debug(
f"Released Redis connection pool reference for {self.namespace}"
f"[{self.workspace}] Released Redis connection pool reference for {self.namespace}"
)
async def __aenter__(self):
@ -245,7 +264,7 @@ class RedisKVStorage(BaseKVStorage):
async def get_by_id(self, id: str) -> dict[str, Any] | None:
async with self._get_redis_connection() as redis:
try:
data = await redis.get(f"{self.namespace}:{id}")
data = await redis.get(f"{self.final_namespace}:{id}")
if data:
result = json.loads(data)
# Ensure time fields are present, provide default values for old data
@ -254,7 +273,7 @@ class RedisKVStorage(BaseKVStorage):
return result
return None
except json.JSONDecodeError as e:
logger.error(f"JSON decode error for id {id}: {e}")
logger.error(f"[{self.workspace}] JSON decode error for id {id}: {e}")
return None
@redis_retry
@ -263,7 +282,7 @@ class RedisKVStorage(BaseKVStorage):
try:
pipe = redis.pipeline()
for id in ids:
pipe.get(f"{self.namespace}:{id}")
pipe.get(f"{self.final_namespace}:{id}")
results = await pipe.execute()
processed_results = []
@ -279,7 +298,7 @@ class RedisKVStorage(BaseKVStorage):
return processed_results
except json.JSONDecodeError as e:
logger.error(f"JSON decode error in batch get: {e}")
logger.error(f"[{self.workspace}] JSON decode error in batch get: {e}")
return [None] * len(ids)
async def get_all(self) -> dict[str, Any]:
@ -291,7 +310,7 @@ class RedisKVStorage(BaseKVStorage):
async with self._get_redis_connection() as redis:
try:
# Get all keys for this namespace
keys = await redis.keys(f"{self.namespace}:*")
keys = await redis.keys(f"{self.final_namespace}:*")
if not keys:
return {}
@ -315,12 +334,16 @@ class RedisKVStorage(BaseKVStorage):
data.setdefault("update_time", 0)
result[key_id] = data
except json.JSONDecodeError as e:
logger.error(f"JSON decode error for key {key}: {e}")
logger.error(
f"[{self.workspace}] JSON decode error for key {key}: {e}"
)
continue
return result
except Exception as e:
logger.error(f"Error getting all data from Redis: {e}")
logger.error(
f"[{self.workspace}] Error getting all data from Redis: {e}"
)
return {}
async def filter_keys(self, keys: set[str]) -> set[str]:
@ -328,7 +351,7 @@ class RedisKVStorage(BaseKVStorage):
pipe = redis.pipeline()
keys_list = list(keys) # Convert set to list for indexing
for key in keys_list:
pipe.exists(f"{self.namespace}:{key}")
pipe.exists(f"{self.final_namespace}:{key}")
results = await pipe.execute()
existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
@ -348,13 +371,13 @@ class RedisKVStorage(BaseKVStorage):
# Check which keys already exist to determine create vs update
pipe = redis.pipeline()
for k in data.keys():
pipe.exists(f"{self.namespace}:{k}")
pipe.exists(f"{self.final_namespace}:{k}")
exists_results = await pipe.execute()
# Add timestamps to data
for i, (k, v) in enumerate(data.items()):
# For text_chunks namespace, ensure llm_cache_list field exists
if "text_chunks" in self.namespace:
if self.namespace.endswith("text_chunks"):
if "llm_cache_list" not in v:
v["llm_cache_list"] = []
@ -370,11 +393,11 @@ class RedisKVStorage(BaseKVStorage):
# Store the data
pipe = redis.pipeline()
for k, v in data.items():
pipe.set(f"{self.namespace}:{k}", json.dumps(v))
pipe.set(f"{self.final_namespace}:{k}", json.dumps(v))
await pipe.execute()
except json.JSONDecodeError as e:
logger.error(f"JSON decode error during upsert: {e}")
logger.error(f"[{self.workspace}] JSON decode error during upsert: {e}")
raise
async def index_done_callback(self) -> None:
@ -389,12 +412,12 @@ class RedisKVStorage(BaseKVStorage):
async with self._get_redis_connection() as redis:
pipe = redis.pipeline()
for id in ids:
pipe.delete(f"{self.namespace}:{id}")
pipe.delete(f"{self.final_namespace}:{id}")
results = await pipe.execute()
deleted_count = sum(results)
logger.info(
f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
f"[{self.workspace}] Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
)
async def drop(self) -> dict[str, str]:
@ -406,7 +429,7 @@ class RedisKVStorage(BaseKVStorage):
async with self._get_redis_connection() as redis:
try:
# Use SCAN to find all keys with the namespace prefix
pattern = f"{self.namespace}:*"
pattern = f"{self.final_namespace}:*"
cursor = 0
deleted_count = 0
@ -423,14 +446,18 @@ class RedisKVStorage(BaseKVStorage):
if cursor == 0:
break
logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
logger.info(
f"[{self.workspace}] Dropped {deleted_count} keys from {self.namespace}"
)
return {
"status": "success",
"message": f"{deleted_count} keys dropped",
}
except Exception as e:
logger.error(f"Error dropping keys from {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Error dropping keys from {self.namespace}: {e}"
)
return {"status": "error", "message": str(e)}
async def _migrate_legacy_cache_structure(self):
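Editor's note: drop() walks the keyspace with SCAN and deletes only keys under this storage's prefix, so one workspace can be cleared without touching another. A condensed sketch of that loop, assuming the redis.asyncio client used elsewhere in the file; the real method also logs progress and wraps errors:

async def drop_namespace(redis, final_namespace: str) -> int:
    cursor, deleted = 0, 0
    while True:
        cursor, keys = await redis.scan(cursor, match=f"{final_namespace}:*", count=1000)
        if keys:
            deleted += await redis.delete(*keys)   # delete this batch of namespaced keys
        if cursor == 0:
            return deleted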
@ -445,7 +472,7 @@ class RedisKVStorage(BaseKVStorage):
async with self._get_redis_connection() as redis:
# Get all keys for this namespace
keys = await redis.keys(f"{self.namespace}:*")
keys = await redis.keys(f"{self.final_namespace}:*")
if not keys:
return
@ -480,7 +507,7 @@ class RedisKVStorage(BaseKVStorage):
# If we found any flattened keys, assume migration is already done
if has_flattened_keys:
logger.debug(
f"Found flattened cache keys in {self.namespace}, skipping migration"
f"[{self.workspace}] Found flattened cache keys in {self.namespace}, skipping migration"
)
return
@ -499,7 +526,7 @@ class RedisKVStorage(BaseKVStorage):
for cache_hash, cache_entry in nested_data.items():
cache_type = cache_entry.get("cache_type", "extract")
flattened_key = generate_cache_key(mode, cache_type, cache_hash)
full_key = f"{self.namespace}:{flattened_key}"
full_key = f"{self.final_namespace}:{flattened_key}"
pipe.set(full_key, json.dumps(cache_entry))
migration_count += 1
@ -507,7 +534,7 @@ class RedisKVStorage(BaseKVStorage):
if migration_count > 0:
logger.info(
f"Migrated {migration_count} legacy cache entries to flattened structure in Redis"
f"[{self.workspace}] Migrated {migration_count} legacy cache entries to flattened structure in Redis"
)
@ -534,11 +561,20 @@ class RedisDocStatusStorage(DocStatusStorage):
f"Using passed workspace parameter: '{effective_workspace}'"
)
# Build namespace with workspace prefix for data isolation
# Build final_namespace with workspace prefix for data isolation
# Keep original namespace unchanged for type detection logic
if effective_workspace:
self.namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'")
# When workspace is empty, keep the original namespace unchanged
self.final_namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(
f"[{self.workspace}] Final namespace with workspace prefix: '{self.namespace}'"
)
else:
# When workspace is empty, final_namespace equals original namespace
self.final_namespace = self.namespace
self.workspace = "_"
logger.debug(
f"[{self.workspace}] Final namespace (no workspace): '{self.namespace}'"
)
self._redis_url = os.environ.get(
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
@ -552,13 +588,15 @@ class RedisDocStatusStorage(DocStatusStorage):
self._pool = RedisConnectionManager.get_pool(self._redis_url)
self._redis = Redis(connection_pool=self._pool)
logger.info(
f"Initialized Redis doc status storage for {self.namespace} using shared connection pool"
f"[{self.workspace}] Initialized Redis doc status storage for {self.namespace} using shared connection pool"
)
except Exception as e:
# Clean up on initialization failure
if self._redis_url:
RedisConnectionManager.release_pool(self._redis_url)
logger.error(f"Failed to initialize Redis doc status storage: {e}")
logger.error(
f"[{self.workspace}] Failed to initialize Redis doc status storage: {e}"
)
raise
async def initialize(self):
@ -570,11 +608,13 @@ class RedisDocStatusStorage(DocStatusStorage):
async with self._get_redis_connection() as redis:
await redis.ping()
logger.info(
f"Connected to Redis for doc status namespace {self.namespace}"
f"[{self.workspace}] Connected to Redis for doc status namespace {self.namespace}"
)
self._initialized = True
except Exception as e:
logger.error(f"Failed to connect to Redis for doc status: {e}")
logger.error(
f"[{self.workspace}] Failed to connect to Redis for doc status: {e}"
)
# Clean up on connection failure
await self.close()
raise
@ -589,14 +629,18 @@ class RedisDocStatusStorage(DocStatusStorage):
# Use the existing Redis instance with shared pool
yield self._redis
except ConnectionError as e:
logger.error(f"Redis connection error in doc status {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Redis connection error in doc status {self.namespace}: {e}"
)
raise
except RedisError as e:
logger.error(f"Redis operation error in doc status {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Redis operation error in doc status {self.namespace}: {e}"
)
raise
except Exception as e:
logger.error(
f"Unexpected error in Redis doc status operation for {self.namespace}: {e}"
f"[{self.workspace}] Unexpected error in Redis doc status operation for {self.namespace}: {e}"
)
raise
@ -605,9 +649,11 @@ class RedisDocStatusStorage(DocStatusStorage):
if hasattr(self, "_redis") and self._redis:
try:
await self._redis.close()
logger.debug(f"Closed Redis connection for doc status {self.namespace}")
logger.debug(
f"[{self.workspace}] Closed Redis connection for doc status {self.namespace}"
)
except Exception as e:
logger.error(f"Error closing Redis connection: {e}")
logger.error(f"[{self.workspace}] Error closing Redis connection: {e}")
finally:
self._redis = None
@ -616,7 +662,7 @@ class RedisDocStatusStorage(DocStatusStorage):
RedisConnectionManager.release_pool(self._redis_url)
self._pool = None
logger.debug(
f"Released Redis connection pool reference for doc status {self.namespace}"
f"[{self.workspace}] Released Redis connection pool reference for doc status {self.namespace}"
)
async def __aenter__(self):
@ -633,7 +679,7 @@ class RedisDocStatusStorage(DocStatusStorage):
pipe = redis.pipeline()
keys_list = list(keys)
for key in keys_list:
pipe.exists(f"{self.namespace}:{key}")
pipe.exists(f"{self.final_namespace}:{key}")
results = await pipe.execute()
existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
@ -645,7 +691,7 @@ class RedisDocStatusStorage(DocStatusStorage):
try:
pipe = redis.pipeline()
for id in ids:
pipe.get(f"{self.namespace}:{id}")
pipe.get(f"{self.final_namespace}:{id}")
results = await pipe.execute()
for result_data in results:
@ -653,10 +699,12 @@ class RedisDocStatusStorage(DocStatusStorage):
try:
result.append(json.loads(result_data))
except json.JSONDecodeError as e:
logger.error(f"JSON decode error in get_by_ids: {e}")
logger.error(
f"[{self.workspace}] JSON decode error in get_by_ids: {e}"
)
continue
except Exception as e:
logger.error(f"Error in get_by_ids: {e}")
logger.error(f"[{self.workspace}] Error in get_by_ids: {e}")
return result
async def get_status_counts(self) -> dict[str, int]:
@ -668,7 +716,7 @@ class RedisDocStatusStorage(DocStatusStorage):
cursor = 0
while True:
cursor, keys = await redis.scan(
cursor, match=f"{self.namespace}:*", count=1000
cursor, match=f"{self.final_namespace}:*", count=1000
)
if keys:
# Get all values in batch
@ -691,7 +739,7 @@ class RedisDocStatusStorage(DocStatusStorage):
if cursor == 0:
break
except Exception as e:
logger.error(f"Error getting status counts: {e}")
logger.error(f"[{self.workspace}] Error getting status counts: {e}")
return counts
@ -706,7 +754,7 @@ class RedisDocStatusStorage(DocStatusStorage):
cursor = 0
while True:
cursor, keys = await redis.scan(
cursor, match=f"{self.namespace}:*", count=1000
cursor, match=f"{self.final_namespace}:*", count=1000
)
if keys:
# Get all values in batch
@ -740,14 +788,14 @@ class RedisDocStatusStorage(DocStatusStorage):
result[doc_id] = DocProcessingStatus(**data)
except (json.JSONDecodeError, KeyError) as e:
logger.error(
f"Error processing document {key}: {e}"
f"[{self.workspace}] Error processing document {key}: {e}"
)
continue
if cursor == 0:
break
except Exception as e:
logger.error(f"Error getting docs by status: {e}")
logger.error(f"[{self.workspace}] Error getting docs by status: {e}")
return result
@ -762,7 +810,7 @@ class RedisDocStatusStorage(DocStatusStorage):
cursor = 0
while True:
cursor, keys = await redis.scan(
cursor, match=f"{self.namespace}:*", count=1000
cursor, match=f"{self.final_namespace}:*", count=1000
)
if keys:
# Get all values in batch
@ -796,14 +844,14 @@ class RedisDocStatusStorage(DocStatusStorage):
result[doc_id] = DocProcessingStatus(**data)
except (json.JSONDecodeError, KeyError) as e:
logger.error(
f"Error processing document {key}: {e}"
f"[{self.workspace}] Error processing document {key}: {e}"
)
continue
if cursor == 0:
break
except Exception as e:
logger.error(f"Error getting docs by track_id: {e}")
logger.error(f"[{self.workspace}] Error getting docs by track_id: {e}")
return result
@ -817,7 +865,9 @@ class RedisDocStatusStorage(DocStatusStorage):
if not data:
return
logger.debug(f"Inserting {len(data)} records to {self.namespace}")
logger.debug(
f"[{self.workspace}] Inserting {len(data)} records to {self.namespace}"
)
async with self._get_redis_connection() as redis:
try:
# Ensure chunks_list field exists for new documents
@ -827,20 +877,20 @@ class RedisDocStatusStorage(DocStatusStorage):
pipe = redis.pipeline()
for k, v in data.items():
pipe.set(f"{self.namespace}:{k}", json.dumps(v))
pipe.set(f"{self.final_namespace}:{k}", json.dumps(v))
await pipe.execute()
except json.JSONDecodeError as e:
logger.error(f"JSON decode error during upsert: {e}")
logger.error(f"[{self.workspace}] JSON decode error during upsert: {e}")
raise
@redis_retry
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
async with self._get_redis_connection() as redis:
try:
data = await redis.get(f"{self.namespace}:{id}")
data = await redis.get(f"{self.final_namespace}:{id}")
return json.loads(data) if data else None
except json.JSONDecodeError as e:
logger.error(f"JSON decode error for id {id}: {e}")
logger.error(f"[{self.workspace}] JSON decode error for id {id}: {e}")
return None
async def delete(self, doc_ids: list[str]) -> None:
@ -851,12 +901,12 @@ class RedisDocStatusStorage(DocStatusStorage):
async with self._get_redis_connection() as redis:
pipe = redis.pipeline()
for doc_id in doc_ids:
pipe.delete(f"{self.namespace}:{doc_id}")
pipe.delete(f"{self.final_namespace}:{doc_id}")
results = await pipe.execute()
deleted_count = sum(results)
logger.info(
f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}"
f"[{self.workspace}] Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}"
)
async def get_docs_paginated(
@ -903,7 +953,7 @@ class RedisDocStatusStorage(DocStatusStorage):
cursor = 0
while True:
cursor, keys = await redis.scan(
cursor, match=f"{self.namespace}:*", count=1000
cursor, match=f"{self.final_namespace}:*", count=1000
)
if keys:
# Get all values in batch
@ -950,7 +1000,7 @@ class RedisDocStatusStorage(DocStatusStorage):
except (json.JSONDecodeError, KeyError) as e:
logger.error(
f"Error processing document {key}: {e}"
f"[{self.workspace}] Error processing document {key}: {e}"
)
continue
@ -958,7 +1008,7 @@ class RedisDocStatusStorage(DocStatusStorage):
break
except Exception as e:
logger.error(f"Error getting paginated docs: {e}")
logger.error(f"[{self.workspace}] Error getting paginated docs: {e}")
return [], 0
# Sort documents using the separate sort key
@ -996,7 +1046,7 @@ class RedisDocStatusStorage(DocStatusStorage):
try:
async with self._get_redis_connection() as redis:
# Use SCAN to find all keys with the namespace prefix
pattern = f"{self.namespace}:*"
pattern = f"{self.final_namespace}:*"
cursor = 0
deleted_count = 0
@ -1014,9 +1064,11 @@ class RedisDocStatusStorage(DocStatusStorage):
break
logger.info(
f"Dropped {deleted_count} doc status keys from {self.namespace}"
f"[{self.workspace}] Dropped {deleted_count} doc status keys from {self.namespace}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping doc status {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Error dropping doc status {self.namespace}: {e}"
)
return {"status": "error", "message": str(e)}


@ -3,6 +3,7 @@ from __future__ import annotations
from typing import Iterable
# Namespace values must not be changed
class NameSpace:
KV_STORE_FULL_DOCS = "full_docs"
KV_STORE_TEXT_CHUNKS = "text_chunks"


@ -1248,7 +1248,7 @@ async def merge_nodes_and_edges(
semaphore = asyncio.Semaphore(graph_max_async)
# ===== Phase 1: Process all entities concurrently =====
log_message = f"Phase 1: Processing {total_entities_count} entities (async: {graph_max_async})"
log_message = f"Phase 1: Processing {total_entities_count} entities from {doc_id} (async: {graph_max_async})"
logger.info(log_message)
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
@ -1312,7 +1312,7 @@ async def merge_nodes_and_edges(
processed_entities = [task.result() for task in entity_tasks]
# ===== Phase 2: Process all relationships concurrently =====
log_message = f"Phase 2: Processing {total_relations_count} relations (async: {graph_max_async})"
log_message = f"Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})"
logger.info(log_message)
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
@ -1422,6 +1422,12 @@ async def merge_nodes_and_edges(
relation_pair = tuple(sorted([src_id, tgt_id]))
final_relation_pairs.add(relation_pair)
log_message = f"Phase 3: Updating final {len(final_entity_names)}({len(processed_entities)}+{len(all_added_entities)}) entities and {len(final_relation_pairs)} relations from {doc_id}"
logger.info(log_message)
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
# Update storage
if final_entity_names:
await full_entities_storage.upsert(