Refactor: Add workspace information to all logger output for all storage types

This commit is contained in:
yangdx 2025-08-12 01:19:09 +08:00
parent 44204abef7
commit 095e0cbfa2
11 changed files with 734 additions and 417 deletions

View file

@ -252,6 +252,7 @@ POSTGRES_IVFFLAT_LISTS=100
NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
NEO4J_USERNAME=neo4j NEO4J_USERNAME=neo4j
NEO4J_PASSWORD='your_password' NEO4J_PASSWORD='your_password'
# NEO4J_DATABASE=chunk_entity_relation
NEO4J_MAX_CONNECTION_POOL_SIZE=100 NEO4J_MAX_CONNECTION_POOL_SIZE=100
NEO4J_CONNECTION_TIMEOUT=30 NEO4J_CONNECTION_TIMEOUT=30
NEO4J_CONNECTION_ACQUISITION_TIMEOUT=30 NEO4J_CONNECTION_ACQUISITION_TIMEOUT=30

View file

@ -33,19 +33,18 @@ class JsonDocStatusStorage(DocStatusStorage):
if self.workspace: if self.workspace:
# Include workspace in the file path for data isolation # Include workspace in the file path for data isolation
workspace_dir = os.path.join(working_dir, self.workspace) workspace_dir = os.path.join(working_dir, self.workspace)
os.makedirs(workspace_dir, exist_ok=True) self.final_namespace = f"{self.workspace}_{self.namespace}"
self._file_name = os.path.join(
workspace_dir, f"kv_store_{self.namespace}.json"
)
else: else:
# Default behavior when workspace is empty # Default behavior when workspace is empty
self._file_name = os.path.join( self.final_namespace = self.namespace
working_dir, f"kv_store_{self.namespace}.json" self.workspace = "_"
) workspace_dir = working_dir
os.makedirs(workspace_dir, exist_ok=True)
self._file_name = os.path.join(workspace_dir, f"kv_store_{self.namespace}.json")
self._data = None self._data = None
self._storage_lock = None self._storage_lock = None
self.storage_updated = None self.storage_updated = None
self.final_namespace = f"{self.workspace}_{self.namespace}"
async def initialize(self): async def initialize(self):
"""Initialize storage data""" """Initialize storage data"""
@ -60,7 +59,7 @@ class JsonDocStatusStorage(DocStatusStorage):
async with self._storage_lock: async with self._storage_lock:
self._data.update(loaded_data) self._data.update(loaded_data)
logger.info( logger.info(
f"Process {os.getpid()} doc status load {self.final_namespace} with {len(loaded_data)} records" f"[{self.workspace}] Process {os.getpid()} doc status load {self.namespace} with {len(loaded_data)} records"
) )
async def filter_keys(self, keys: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
@ -108,7 +107,9 @@ class JsonDocStatusStorage(DocStatusStorage):
data["error_msg"] = None data["error_msg"] = None
result[k] = DocProcessingStatus(**data) result[k] = DocProcessingStatus(**data)
except KeyError as e: except KeyError as e:
logger.error(f"Missing required field for document {k}: {e}") logger.error(
f"[{self.workspace}] Missing required field for document {k}: {e}"
)
continue continue
return result return result
@ -135,7 +136,9 @@ class JsonDocStatusStorage(DocStatusStorage):
data["error_msg"] = None data["error_msg"] = None
result[k] = DocProcessingStatus(**data) result[k] = DocProcessingStatus(**data)
except KeyError as e: except KeyError as e:
logger.error(f"Missing required field for document {k}: {e}") logger.error(
f"[{self.workspace}] Missing required field for document {k}: {e}"
)
continue continue
return result return result
@ -146,7 +149,7 @@ class JsonDocStatusStorage(DocStatusStorage):
dict(self._data) if hasattr(self._data, "_getvalue") else self._data dict(self._data) if hasattr(self._data, "_getvalue") else self._data
) )
logger.debug( logger.debug(
f"Process {os.getpid()} doc status writting {len(data_dict)} records to {self.final_namespace}" f"[{self.workspace}] Process {os.getpid()} doc status writting {len(data_dict)} records to {self.namespace}"
) )
write_json(data_dict, self._file_name) write_json(data_dict, self._file_name)
await clear_all_update_flags(self.final_namespace) await clear_all_update_flags(self.final_namespace)
@ -159,7 +162,9 @@ class JsonDocStatusStorage(DocStatusStorage):
""" """
if not data: if not data:
return return
logger.debug(f"Inserting {len(data)} records to {self.final_namespace}") logger.debug(
f"[{self.workspace}] Inserting {len(data)} records to {self.namespace}"
)
async with self._storage_lock: async with self._storage_lock:
# Ensure chunks_list field exists for new documents # Ensure chunks_list field exists for new documents
for doc_id, doc_data in data.items(): for doc_id, doc_data in data.items():
@ -242,7 +247,9 @@ class JsonDocStatusStorage(DocStatusStorage):
all_docs.append((doc_id, doc_status)) all_docs.append((doc_id, doc_status))
except KeyError as e: except KeyError as e:
logger.error(f"Error processing document {doc_id}: {e}") logger.error(
f"[{self.workspace}] Error processing document {doc_id}: {e}"
)
continue continue
# Sort documents # Sort documents
@ -321,8 +328,10 @@ class JsonDocStatusStorage(DocStatusStorage):
await set_all_update_flags(self.final_namespace) await set_all_update_flags(self.final_namespace)
await self.index_done_callback() await self.index_done_callback()
logger.info(f"Process {os.getpid()} drop {self.final_namespace}") logger.info(
f"[{self.workspace}] Process {os.getpid()} drop {self.namespace}"
)
return {"status": "success", "message": "data dropped"} return {"status": "success", "message": "data dropped"}
except Exception as e: except Exception as e:
logger.error(f"Error dropping {self.final_namespace}: {e}") logger.error(f"[{self.workspace}] Error dropping {self.namespace}: {e}")
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}

View file

@ -29,19 +29,19 @@ class JsonKVStorage(BaseKVStorage):
if self.workspace: if self.workspace:
# Include workspace in the file path for data isolation # Include workspace in the file path for data isolation
workspace_dir = os.path.join(working_dir, self.workspace) workspace_dir = os.path.join(working_dir, self.workspace)
os.makedirs(workspace_dir, exist_ok=True) self.final_namespace = f"{self.workspace}_{self.namespace}"
self._file_name = os.path.join(
workspace_dir, f"kv_store_{self.namespace}.json"
)
else: else:
# Default behavior when workspace is empty # Default behavior when workspace is empty
self._file_name = os.path.join( workspace_dir = working_dir
working_dir, f"kv_store_{self.namespace}.json" self.final_namespace = self.namespace
) self.workspace = "_"
os.makedirs(workspace_dir, exist_ok=True)
self._file_name = os.path.join(workspace_dir, f"kv_store_{self.namespace}.json")
self._data = None self._data = None
self._storage_lock = None self._storage_lock = None
self.storage_updated = None self.storage_updated = None
self.final_namespace = f"{self.workspace}_{self.namespace}"
async def initialize(self): async def initialize(self):
"""Initialize storage data""" """Initialize storage data"""
@ -64,7 +64,7 @@ class JsonKVStorage(BaseKVStorage):
data_count = len(loaded_data) data_count = len(loaded_data)
logger.info( logger.info(
f"Process {os.getpid()} KV load {self.final_namespace} with {data_count} records" f"[{self.workspace}] Process {os.getpid()} KV load {self.namespace} with {data_count} records"
) )
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
@ -78,7 +78,7 @@ class JsonKVStorage(BaseKVStorage):
data_count = len(data_dict) data_count = len(data_dict)
logger.debug( logger.debug(
f"Process {os.getpid()} KV writting {data_count} records to {self.final_namespace}" f"[{self.workspace}] Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
) )
write_json(data_dict, self._file_name) write_json(data_dict, self._file_name)
await clear_all_update_flags(self.final_namespace) await clear_all_update_flags(self.final_namespace)
@ -151,12 +151,14 @@ class JsonKVStorage(BaseKVStorage):
current_time = int(time.time()) # Get current Unix timestamp current_time = int(time.time()) # Get current Unix timestamp
logger.debug(f"Inserting {len(data)} records to {self.final_namespace}") logger.debug(
f"[{self.workspace}] Inserting {len(data)} records to {self.namespace}"
)
async with self._storage_lock: async with self._storage_lock:
# Add timestamps to data based on whether key exists # Add timestamps to data based on whether key exists
for k, v in data.items(): for k, v in data.items():
# For text_chunks namespace, ensure llm_cache_list field exists # For text_chunks namespace, ensure llm_cache_list field exists
if "text_chunks" in self.namespace: if self.namespace.endswith("text_chunks"):
if "llm_cache_list" not in v: if "llm_cache_list" not in v:
v["llm_cache_list"] = [] v["llm_cache_list"] = []
@ -215,10 +217,12 @@ class JsonKVStorage(BaseKVStorage):
await set_all_update_flags(self.final_namespace) await set_all_update_flags(self.final_namespace)
await self.index_done_callback() await self.index_done_callback()
logger.info(f"Process {os.getpid()} drop {self.final_namespace}") logger.info(
f"[{self.workspace}] Process {os.getpid()} drop {self.namespace}"
)
return {"status": "success", "message": "data dropped"} return {"status": "success", "message": "data dropped"}
except Exception as e: except Exception as e:
logger.error(f"Error dropping {self.final_namespace}: {e}") logger.error(f"[{self.workspace}] Error dropping {self.namespace}: {e}")
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
async def _migrate_legacy_cache_structure(self, data: dict) -> dict: async def _migrate_legacy_cache_structure(self, data: dict) -> dict:
@ -263,7 +267,7 @@ class JsonKVStorage(BaseKVStorage):
if migration_count > 0: if migration_count > 0:
logger.info( logger.info(
f"Migrated {migration_count} legacy cache entries to flattened structure" f"[{self.workspace}] Migrated {migration_count} legacy cache entries to flattened structure"
) )
# Persist migrated data immediately # Persist migrated data immediately
write_json(migrated_data, self._file_name) write_json(migrated_data, self._file_name)

View file

@ -34,21 +34,25 @@ config.read("config.ini", "utf-8")
@dataclass @dataclass
class MemgraphStorage(BaseGraphStorage): class MemgraphStorage(BaseGraphStorage):
def __init__(self, namespace, global_config, embedding_func, workspace=None): def __init__(self, namespace, global_config, embedding_func, workspace=None):
# Priority: 1) MEMGRAPH_WORKSPACE env 2) user arg 3) default 'base'
memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE") memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE")
if memgraph_workspace and memgraph_workspace.strip(): if memgraph_workspace and memgraph_workspace.strip():
workspace = memgraph_workspace workspace = memgraph_workspace
if not workspace or not str(workspace).strip():
workspace = "base"
super().__init__( super().__init__(
namespace=namespace, namespace=namespace,
workspace=workspace or "", workspace=workspace,
global_config=global_config, global_config=global_config,
embedding_func=embedding_func, embedding_func=embedding_func,
) )
self._driver = None self._driver = None
def _get_workspace_label(self) -> str: def _get_workspace_label(self) -> str:
"""Get workspace label, return 'base' for compatibility when workspace is empty""" """Return workspace label (guaranteed non-empty during initialization)"""
workspace = getattr(self, "workspace", None) return self.workspace
return workspace if workspace else "base"
async def initialize(self): async def initialize(self):
URI = os.environ.get( URI = os.environ.get(
@ -79,17 +83,19 @@ class MemgraphStorage(BaseGraphStorage):
f"""CREATE INDEX ON :{workspace_label}(entity_id)""" f"""CREATE INDEX ON :{workspace_label}(entity_id)"""
) )
logger.info( logger.info(
f"Created index on :{workspace_label}(entity_id) in Memgraph." f"[{self.workspace}] Created index on :{workspace_label}(entity_id) in Memgraph."
) )
except Exception as e: except Exception as e:
# Index may already exist, which is not an error # Index may already exist, which is not an error
logger.warning( logger.warning(
f"Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}" f"[{self.workspace}] Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}"
) )
await session.run("RETURN 1") await session.run("RETURN 1")
logger.info(f"Connected to Memgraph at {URI}") logger.info(f"[{self.workspace}] Connected to Memgraph at {URI}")
except Exception as e: except Exception as e:
logger.error(f"Failed to connect to Memgraph at {URI}: {e}") logger.error(
f"[{self.workspace}] Failed to connect to Memgraph at {URI}: {e}"
)
raise raise
async def finalize(self): async def finalize(self):
@ -134,7 +140,9 @@ class MemgraphStorage(BaseGraphStorage):
single_result["node_exists"] if single_result is not None else False single_result["node_exists"] if single_result is not None else False
) )
except Exception as e: except Exception as e:
logger.error(f"Error checking node existence for {node_id}: {str(e)}") logger.error(
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
)
await result.consume() # Ensure the result is consumed even on error await result.consume() # Ensure the result is consumed even on error
raise raise
@ -177,7 +185,7 @@ class MemgraphStorage(BaseGraphStorage):
) )
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
) )
await result.consume() # Ensure the result is consumed even on error await result.consume() # Ensure the result is consumed even on error
raise raise
@ -215,7 +223,7 @@ class MemgraphStorage(BaseGraphStorage):
if len(records) > 1: if len(records) > 1:
logger.warning( logger.warning(
f"Multiple nodes found with label '{node_id}'. Using first node." f"[{self.workspace}] Multiple nodes found with label '{node_id}'. Using first node."
) )
if records: if records:
node = records[0]["n"] node = records[0]["n"]
@ -232,7 +240,9 @@ class MemgraphStorage(BaseGraphStorage):
finally: finally:
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
except Exception as e: except Exception as e:
logger.error(f"Error getting node for {node_id}: {str(e)}") logger.error(
f"[{self.workspace}] Error getting node for {node_id}: {str(e)}"
)
raise raise
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
@ -268,7 +278,9 @@ class MemgraphStorage(BaseGraphStorage):
record = await result.single() record = await result.single()
if not record: if not record:
logger.warning(f"No node found with label '{node_id}'") logger.warning(
f"[{self.workspace}] No node found with label '{node_id}'"
)
return 0 return 0
degree = record["degree"] degree = record["degree"]
@ -276,7 +288,9 @@ class MemgraphStorage(BaseGraphStorage):
finally: finally:
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
except Exception as e: except Exception as e:
logger.error(f"Error getting node degree for {node_id}: {str(e)}") logger.error(
f"[{self.workspace}] Error getting node degree for {node_id}: {str(e)}"
)
raise raise
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]:
@ -310,7 +324,7 @@ class MemgraphStorage(BaseGraphStorage):
await result.consume() await result.consume()
return labels return labels
except Exception as e: except Exception as e:
logger.error(f"Error getting all labels: {str(e)}") logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}")
await result.consume() # Ensure the result is consumed even on error await result.consume() # Ensure the result is consumed even on error
raise raise
@ -370,12 +384,14 @@ class MemgraphStorage(BaseGraphStorage):
return edges return edges
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error getting edges for node {source_node_id}: {str(e)}" f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
) )
await results.consume() # Ensure results are consumed even on error await results.consume() # Ensure results are consumed even on error
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}") logger.error(
f"[{self.workspace}] Error in get_node_edges for {source_node_id}: {str(e)}"
)
raise raise
async def get_edge( async def get_edge(
@ -424,13 +440,13 @@ class MemgraphStorage(BaseGraphStorage):
if key not in edge_result: if key not in edge_result:
edge_result[key] = default_value edge_result[key] = default_value
logger.warning( logger.warning(
f"Edge between {source_node_id} and {target_node_id} is missing property: {key}. Using default value: {default_value}" f"[{self.workspace}] Edge between {source_node_id} and {target_node_id} is missing property: {key}. Using default value: {default_value}"
) )
return edge_result return edge_result
return None return None
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error getting edge between {source_node_id} and {target_node_id}: {str(e)}" f"[{self.workspace}] Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
) )
await result.consume() # Ensure the result is consumed even on error await result.consume() # Ensure the result is consumed even on error
raise raise
@ -463,7 +479,7 @@ class MemgraphStorage(BaseGraphStorage):
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
logger.debug( logger.debug(
f"Attempting node upsert, attempt {attempt + 1}/{max_retries}" f"[{self.workspace}] Attempting node upsert, attempt {attempt + 1}/{max_retries}"
) )
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()
@ -504,20 +520,24 @@ class MemgraphStorage(BaseGraphStorage):
initial_wait_time * (backoff_factor**attempt) + jitter initial_wait_time * (backoff_factor**attempt) + jitter
) )
logger.warning( logger.warning(
f"Node upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}" f"[{self.workspace}] Node upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
) )
await asyncio.sleep(wait_time) await asyncio.sleep(wait_time)
else: else:
logger.error( logger.error(
f"Memgraph transient error during node upsert after {max_retries} retries: {str(e)}" f"[{self.workspace}] Memgraph transient error during node upsert after {max_retries} retries: {str(e)}"
) )
raise raise
else: else:
# Non-transient error, don't retry # Non-transient error, don't retry
logger.error(f"Non-transient error during node upsert: {str(e)}") logger.error(
f"[{self.workspace}] Non-transient error during node upsert: {str(e)}"
)
raise raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error during node upsert: {str(e)}") logger.error(
f"[{self.workspace}] Unexpected error during node upsert: {str(e)}"
)
raise raise
async def upsert_edge( async def upsert_edge(
@ -552,7 +572,7 @@ class MemgraphStorage(BaseGraphStorage):
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
logger.debug( logger.debug(
f"Attempting edge upsert, attempt {attempt + 1}/{max_retries}" f"[{self.workspace}] Attempting edge upsert, attempt {attempt + 1}/{max_retries}"
) )
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
@ -602,20 +622,24 @@ class MemgraphStorage(BaseGraphStorage):
initial_wait_time * (backoff_factor**attempt) + jitter initial_wait_time * (backoff_factor**attempt) + jitter
) )
logger.warning( logger.warning(
f"Edge upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}" f"[{self.workspace}] Edge upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
) )
await asyncio.sleep(wait_time) await asyncio.sleep(wait_time)
else: else:
logger.error( logger.error(
f"Memgraph transient error during edge upsert after {max_retries} retries: {str(e)}" f"[{self.workspace}] Memgraph transient error during edge upsert after {max_retries} retries: {str(e)}"
) )
raise raise
else: else:
# Non-transient error, don't retry # Non-transient error, don't retry
logger.error(f"Non-transient error during edge upsert: {str(e)}") logger.error(
f"[{self.workspace}] Non-transient error during edge upsert: {str(e)}"
)
raise raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error during edge upsert: {str(e)}") logger.error(
f"[{self.workspace}] Unexpected error during edge upsert: {str(e)}"
)
raise raise
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
@ -639,14 +663,14 @@ class MemgraphStorage(BaseGraphStorage):
DETACH DELETE n DETACH DELETE n
""" """
result = await tx.run(query, entity_id=node_id) result = await tx.run(query, entity_id=node_id)
logger.debug(f"Deleted node with label {node_id}") logger.debug(f"[{self.workspace}] Deleted node with label {node_id}")
await result.consume() await result.consume()
try: try:
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_delete) await session.execute_write(_do_delete)
except Exception as e: except Exception as e:
logger.error(f"Error during node deletion: {str(e)}") logger.error(f"[{self.workspace}] Error during node deletion: {str(e)}")
raise raise
async def remove_nodes(self, nodes: list[str]): async def remove_nodes(self, nodes: list[str]):
@ -686,14 +710,16 @@ class MemgraphStorage(BaseGraphStorage):
result = await tx.run( result = await tx.run(
query, source_entity_id=source, target_entity_id=target query, source_entity_id=source, target_entity_id=target
) )
logger.debug(f"Deleted edge from '{source}' to '{target}'") logger.debug(
f"[{self.workspace}] Deleted edge from '{source}' to '{target}'"
)
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
try: try:
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_delete_edge) await session.execute_write(_do_delete_edge)
except Exception as e: except Exception as e:
logger.error(f"Error during edge deletion: {str(e)}") logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}")
raise raise
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
@ -720,12 +746,12 @@ class MemgraphStorage(BaseGraphStorage):
result = await session.run(query) result = await session.run(query)
await result.consume() await result.consume()
logger.info( logger.info(
f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}" f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
) )
return {"status": "success", "message": "workspace data dropped"} return {"status": "success", "message": "workspace data dropped"}
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}" f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
@ -945,14 +971,16 @@ class MemgraphStorage(BaseGraphStorage):
# If no record found, return empty KnowledgeGraph # If no record found, return empty KnowledgeGraph
if not record: if not record:
logger.debug(f"No nodes found for entity_id: {node_label}") logger.debug(
f"[{self.workspace}] No nodes found for entity_id: {node_label}"
)
return result return result
# Check if the result was truncated # Check if the result was truncated
if record.get("is_truncated"): if record.get("is_truncated"):
result.is_truncated = True result.is_truncated = True
logger.info( logger.info(
f"Graph truncated: breadth-first search limited to {max_nodes} nodes" f"[{self.workspace}] Graph truncated: breadth-first search limited to {max_nodes} nodes"
) )
finally: finally:
@ -990,11 +1018,13 @@ class MemgraphStorage(BaseGraphStorage):
seen_edges.add(edge_id) seen_edges.add(edge_id)
logger.info( logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" f"[{self.workspace}] Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
) )
except Exception as e: except Exception as e:
logger.warning(f"Memgraph error during subgraph query: {str(e)}") logger.warning(
f"[{self.workspace}] Memgraph error during subgraph query: {str(e)}"
)
return result return result

View file

@ -37,7 +37,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
] ]
# Determine specific fields based on namespace # Determine specific fields based on namespace
if "entities" in self.namespace.lower(): if self.namespace.endswith("entities"):
specific_fields = [ specific_fields = [
FieldSchema( FieldSchema(
name="entity_name", name="entity_name",
@ -54,7 +54,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
] ]
description = "LightRAG entities vector storage" description = "LightRAG entities vector storage"
elif "relationships" in self.namespace.lower(): elif self.namespace.endswith("relationships"):
specific_fields = [ specific_fields = [
FieldSchema( FieldSchema(
name="src_id", dtype=DataType.VARCHAR, max_length=512, nullable=True name="src_id", dtype=DataType.VARCHAR, max_length=512, nullable=True
@ -71,7 +71,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
] ]
description = "LightRAG relationships vector storage" description = "LightRAG relationships vector storage"
elif "chunks" in self.namespace.lower(): elif self.namespace.endswith("chunks"):
specific_fields = [ specific_fields = [
FieldSchema( FieldSchema(
name="full_doc_id", name="full_doc_id",
@ -147,7 +147,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
"""Fallback method to create vector index using direct API""" """Fallback method to create vector index using direct API"""
try: try:
self._client.create_index( self._client.create_index(
collection_name=self.namespace, collection_name=self.final_namespace,
field_name="vector", field_name="vector",
index_params={ index_params={
"index_type": "HNSW", "index_type": "HNSW",
@ -155,29 +155,35 @@ class MilvusVectorDBStorage(BaseVectorStorage):
"params": {"M": 16, "efConstruction": 256}, "params": {"M": 16, "efConstruction": 256},
}, },
) )
logger.debug("Created vector index using fallback method") logger.debug(
f"[{self.workspace}] Created vector index using fallback method"
)
except Exception as e: except Exception as e:
logger.warning(f"Failed to create vector index using fallback method: {e}") logger.warning(
f"[{self.workspace}] Failed to create vector index using fallback method: {e}"
)
def _create_scalar_index_fallback(self, field_name: str, index_type: str): def _create_scalar_index_fallback(self, field_name: str, index_type: str):
"""Fallback method to create scalar index using direct API""" """Fallback method to create scalar index using direct API"""
# Skip unsupported index types # Skip unsupported index types
if index_type == "SORTED": if index_type == "SORTED":
logger.info( logger.info(
f"Skipping SORTED index for {field_name} (not supported in this Milvus version)" f"[{self.workspace}] Skipping SORTED index for {field_name} (not supported in this Milvus version)"
) )
return return
try: try:
self._client.create_index( self._client.create_index(
collection_name=self.namespace, collection_name=self.final_namespace,
field_name=field_name, field_name=field_name,
index_params={"index_type": index_type}, index_params={"index_type": index_type},
) )
logger.debug(f"Created {field_name} index using fallback method") logger.debug(
f"[{self.workspace}] Created {field_name} index using fallback method"
)
except Exception as e: except Exception as e:
logger.info( logger.info(
f"Could not create {field_name} index using fallback method: {e}" f"[{self.workspace}] Could not create {field_name} index using fallback method: {e}"
) )
def _create_indexes_after_collection(self): def _create_indexes_after_collection(self):
@ -198,15 +204,19 @@ class MilvusVectorDBStorage(BaseVectorStorage):
params={"M": 16, "efConstruction": 256}, params={"M": 16, "efConstruction": 256},
) )
self._client.create_index( self._client.create_index(
collection_name=self.namespace, index_params=vector_index collection_name=self.final_namespace, index_params=vector_index
)
logger.debug(
f"[{self.workspace}] Created vector index using IndexParams"
) )
logger.debug("Created vector index using IndexParams")
except Exception as e: except Exception as e:
logger.debug(f"IndexParams method failed for vector index: {e}") logger.debug(
f"[{self.workspace}] IndexParams method failed for vector index: {e}"
)
self._create_vector_index_fallback() self._create_vector_index_fallback()
# Create scalar indexes based on namespace # Create scalar indexes based on namespace
if "entities" in self.namespace.lower(): if self.namespace.endswith("entities"):
# Create indexes for entity fields # Create indexes for entity fields
try: try:
entity_name_index = self._get_index_params() entity_name_index = self._get_index_params()
@ -214,14 +224,16 @@ class MilvusVectorDBStorage(BaseVectorStorage):
field_name="entity_name", index_type="INVERTED" field_name="entity_name", index_type="INVERTED"
) )
self._client.create_index( self._client.create_index(
collection_name=self.namespace, collection_name=self.final_namespace,
index_params=entity_name_index, index_params=entity_name_index,
) )
except Exception as e: except Exception as e:
logger.debug(f"IndexParams method failed for entity_name: {e}") logger.debug(
f"[{self.workspace}] IndexParams method failed for entity_name: {e}"
)
self._create_scalar_index_fallback("entity_name", "INVERTED") self._create_scalar_index_fallback("entity_name", "INVERTED")
elif "relationships" in self.namespace.lower(): elif self.namespace.endswith("relationships"):
# Create indexes for relationship fields # Create indexes for relationship fields
try: try:
src_id_index = self._get_index_params() src_id_index = self._get_index_params()
@ -229,10 +241,13 @@ class MilvusVectorDBStorage(BaseVectorStorage):
field_name="src_id", index_type="INVERTED" field_name="src_id", index_type="INVERTED"
) )
self._client.create_index( self._client.create_index(
collection_name=self.namespace, index_params=src_id_index collection_name=self.final_namespace,
index_params=src_id_index,
) )
except Exception as e: except Exception as e:
logger.debug(f"IndexParams method failed for src_id: {e}") logger.debug(
f"[{self.workspace}] IndexParams method failed for src_id: {e}"
)
self._create_scalar_index_fallback("src_id", "INVERTED") self._create_scalar_index_fallback("src_id", "INVERTED")
try: try:
@ -241,13 +256,16 @@ class MilvusVectorDBStorage(BaseVectorStorage):
field_name="tgt_id", index_type="INVERTED" field_name="tgt_id", index_type="INVERTED"
) )
self._client.create_index( self._client.create_index(
collection_name=self.namespace, index_params=tgt_id_index collection_name=self.final_namespace,
index_params=tgt_id_index,
) )
except Exception as e: except Exception as e:
logger.debug(f"IndexParams method failed for tgt_id: {e}") logger.debug(
f"[{self.workspace}] IndexParams method failed for tgt_id: {e}"
)
self._create_scalar_index_fallback("tgt_id", "INVERTED") self._create_scalar_index_fallback("tgt_id", "INVERTED")
elif "chunks" in self.namespace.lower(): elif self.namespace.endswith("chunks"):
# Create indexes for chunk fields # Create indexes for chunk fields
try: try:
doc_id_index = self._get_index_params() doc_id_index = self._get_index_params()
@ -255,10 +273,13 @@ class MilvusVectorDBStorage(BaseVectorStorage):
field_name="full_doc_id", index_type="INVERTED" field_name="full_doc_id", index_type="INVERTED"
) )
self._client.create_index( self._client.create_index(
collection_name=self.namespace, index_params=doc_id_index collection_name=self.final_namespace,
index_params=doc_id_index,
) )
except Exception as e: except Exception as e:
logger.debug(f"IndexParams method failed for full_doc_id: {e}") logger.debug(
f"[{self.workspace}] IndexParams method failed for full_doc_id: {e}"
)
self._create_scalar_index_fallback("full_doc_id", "INVERTED") self._create_scalar_index_fallback("full_doc_id", "INVERTED")
# No common indexes needed # No common indexes needed
@ -266,25 +287,29 @@ class MilvusVectorDBStorage(BaseVectorStorage):
else: else:
# Fallback to direct API calls if IndexParams is not available # Fallback to direct API calls if IndexParams is not available
logger.info( logger.info(
f"IndexParams not available, using fallback methods for {self.namespace}" f"[{self.workspace}] IndexParams not available, using fallback methods for {self.namespace}"
) )
# Create vector index using fallback # Create vector index using fallback
self._create_vector_index_fallback() self._create_vector_index_fallback()
# Create scalar indexes using fallback # Create scalar indexes using fallback
if "entities" in self.namespace.lower(): if self.namespace.endswith("entities"):
self._create_scalar_index_fallback("entity_name", "INVERTED") self._create_scalar_index_fallback("entity_name", "INVERTED")
elif "relationships" in self.namespace.lower(): elif self.namespace.endswith("relationships"):
self._create_scalar_index_fallback("src_id", "INVERTED") self._create_scalar_index_fallback("src_id", "INVERTED")
self._create_scalar_index_fallback("tgt_id", "INVERTED") self._create_scalar_index_fallback("tgt_id", "INVERTED")
elif "chunks" in self.namespace.lower(): elif self.namespace.endswith("chunks"):
self._create_scalar_index_fallback("full_doc_id", "INVERTED") self._create_scalar_index_fallback("full_doc_id", "INVERTED")
logger.info(f"Created indexes for collection: {self.namespace}") logger.info(
f"[{self.workspace}] Created indexes for collection: {self.namespace}"
)
except Exception as e: except Exception as e:
logger.warning(f"Failed to create some indexes for {self.namespace}: {e}") logger.warning(
f"[{self.workspace}] Failed to create some indexes for {self.namespace}: {e}"
)
def _get_required_fields_for_namespace(self) -> dict: def _get_required_fields_for_namespace(self) -> dict:
"""Get required core field definitions for current namespace""" """Get required core field definitions for current namespace"""
@ -297,18 +322,18 @@ class MilvusVectorDBStorage(BaseVectorStorage):
} }
# Add specific fields based on namespace # Add specific fields based on namespace
if "entities" in self.namespace.lower(): if self.namespace.endswith("entities"):
specific_fields = { specific_fields = {
"entity_name": {"type": "VarChar"}, "entity_name": {"type": "VarChar"},
"file_path": {"type": "VarChar"}, "file_path": {"type": "VarChar"},
} }
elif "relationships" in self.namespace.lower(): elif self.namespace.endswith("relationships"):
specific_fields = { specific_fields = {
"src_id": {"type": "VarChar"}, "src_id": {"type": "VarChar"},
"tgt_id": {"type": "VarChar"}, "tgt_id": {"type": "VarChar"},
"file_path": {"type": "VarChar"}, "file_path": {"type": "VarChar"},
} }
elif "chunks" in self.namespace.lower(): elif self.namespace.endswith("chunks"):
specific_fields = { specific_fields = {
"full_doc_id": {"type": "VarChar"}, "full_doc_id": {"type": "VarChar"},
"file_path": {"type": "VarChar"}, "file_path": {"type": "VarChar"},
@ -327,7 +352,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
expected_type = expected_config.get("type") expected_type = expected_config.get("type")
logger.debug( logger.debug(
f"Checking field '{field_name}': existing_type={existing_type} (type={type(existing_type)}), expected_type={expected_type}" f"[{self.workspace}] Checking field '{field_name}': existing_type={existing_type} (type={type(existing_type)}), expected_type={expected_type}"
) )
# Convert DataType enum values to string names if needed # Convert DataType enum values to string names if needed
@ -335,7 +360,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
if hasattr(existing_type, "name"): if hasattr(existing_type, "name"):
existing_type = existing_type.name existing_type = existing_type.name
logger.debug( logger.debug(
f"Converted enum to name: {original_existing_type} -> {existing_type}" f"[{self.workspace}] Converted enum to name: {original_existing_type} -> {existing_type}"
) )
elif isinstance(existing_type, int): elif isinstance(existing_type, int):
# Map common Milvus internal type codes to type names for backward compatibility # Map common Milvus internal type codes to type names for backward compatibility
@ -346,7 +371,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
9: "Double", 9: "Double",
} }
mapped_type = type_mapping.get(existing_type, str(existing_type)) mapped_type = type_mapping.get(existing_type, str(existing_type))
logger.debug(f"Mapped numeric type: {existing_type} -> {mapped_type}") logger.debug(
f"[{self.workspace}] Mapped numeric type: {existing_type} -> {mapped_type}"
)
existing_type = mapped_type existing_type = mapped_type
# Normalize type names for comparison # Normalize type names for comparison
@ -367,18 +394,18 @@ class MilvusVectorDBStorage(BaseVectorStorage):
if original_existing != existing_type or original_expected != expected_type: if original_existing != existing_type or original_expected != expected_type:
logger.debug( logger.debug(
f"Applied aliases: {original_existing} -> {existing_type}, {original_expected} -> {expected_type}" f"[{self.workspace}] Applied aliases: {original_existing} -> {existing_type}, {original_expected} -> {expected_type}"
) )
# Basic type compatibility check # Basic type compatibility check
type_compatible = existing_type == expected_type type_compatible = existing_type == expected_type
logger.debug( logger.debug(
f"Type compatibility for '{field_name}': {existing_type} == {expected_type} -> {type_compatible}" f"[{self.workspace}] Type compatibility for '{field_name}': {existing_type} == {expected_type} -> {type_compatible}"
) )
if not type_compatible: if not type_compatible:
logger.warning( logger.warning(
f"Type mismatch for field '{field_name}': expected {expected_type}, got {existing_type}" f"[{self.workspace}] Type mismatch for field '{field_name}': expected {expected_type}, got {existing_type}"
) )
return False return False
@ -391,23 +418,25 @@ class MilvusVectorDBStorage(BaseVectorStorage):
or existing_field.get("primary_key", False) or existing_field.get("primary_key", False)
) )
logger.debug( logger.debug(
f"Primary key check for '{field_name}': expected=True, actual={is_primary}" f"[{self.workspace}] Primary key check for '{field_name}': expected=True, actual={is_primary}"
)
logger.debug(
f"[{self.workspace}] Raw field data for '{field_name}': {existing_field}"
) )
logger.debug(f"Raw field data for '{field_name}': {existing_field}")
# For ID field, be more lenient - if it's the ID field, assume it should be primary # For ID field, be more lenient - if it's the ID field, assume it should be primary
if field_name == "id" and not is_primary: if field_name == "id" and not is_primary:
logger.info( logger.info(
f"ID field '{field_name}' not marked as primary in existing collection, but treating as compatible" f"[{self.workspace}] ID field '{field_name}' not marked as primary in existing collection, but treating as compatible"
) )
# Don't fail for ID field primary key mismatch # Don't fail for ID field primary key mismatch
elif not is_primary: elif not is_primary:
logger.warning( logger.warning(
f"Primary key mismatch for field '{field_name}': expected primary key, but field is not primary" f"[{self.workspace}] Primary key mismatch for field '{field_name}': expected primary key, but field is not primary"
) )
return False return False
logger.debug(f"Field '{field_name}' is compatible") logger.debug(f"[{self.workspace}] Field '{field_name}' is compatible")
return True return True
def _check_vector_dimension(self, collection_info: dict): def _check_vector_dimension(self, collection_info: dict):
@ -434,18 +463,22 @@ class MilvusVectorDBStorage(BaseVectorStorage):
if existing_dimension != current_dimension: if existing_dimension != current_dimension:
raise ValueError( raise ValueError(
f"Vector dimension mismatch for collection '{self.namespace}': " f"Vector dimension mismatch for collection '{self.final_namespace}': "
f"existing={existing_dimension}, current={current_dimension}" f"existing={existing_dimension}, current={current_dimension}"
) )
logger.debug(f"Vector dimension check passed: {current_dimension}") logger.debug(
f"[{self.workspace}] Vector dimension check passed: {current_dimension}"
)
return return
# If no vector field found, this might be an old collection created with simple schema # If no vector field found, this might be an old collection created with simple schema
logger.warning( logger.warning(
f"Vector field not found in collection '{self.namespace}'. This might be an old collection created with simple schema." f"[{self.workspace}] Vector field not found in collection '{self.namespace}'. This might be an old collection created with simple schema."
)
logger.warning(
f"[{self.workspace}] Consider recreating the collection for optimal performance."
) )
logger.warning("Consider recreating the collection for optimal performance.")
return return
def _check_schema_compatibility(self, collection_info: dict): def _check_schema_compatibility(self, collection_info: dict):
@ -461,12 +494,14 @@ class MilvusVectorDBStorage(BaseVectorStorage):
if not has_vector_field: if not has_vector_field:
logger.warning( logger.warning(
f"Collection {self.namespace} appears to be created with old simple schema (no vector field)" f"[{self.workspace}] Collection {self.namespace} appears to be created with old simple schema (no vector field)"
) )
logger.warning( logger.warning(
"This collection will work but may have suboptimal performance" f"[{self.workspace}] This collection will work but may have suboptimal performance"
)
logger.warning(
f"[{self.workspace}] Consider recreating the collection for optimal performance"
) )
logger.warning("Consider recreating the collection for optimal performance")
return return
# For collections with vector field, check basic compatibility # For collections with vector field, check basic compatibility
@ -486,7 +521,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
if incompatible_fields: if incompatible_fields:
raise ValueError( raise ValueError(
f"Critical schema incompatibility in collection '{self.namespace}': {incompatible_fields}" f"Critical schema incompatibility in collection '{self.final_namespace}': {incompatible_fields}"
) )
# Get all expected fields for informational purposes # Get all expected fields for informational purposes
@ -497,18 +532,20 @@ class MilvusVectorDBStorage(BaseVectorStorage):
if missing_fields: if missing_fields:
logger.info( logger.info(
f"Collection {self.namespace} missing optional fields: {missing_fields}" f"[{self.workspace}] Collection {self.namespace} missing optional fields: {missing_fields}"
) )
logger.info( logger.info(
"These fields would be available in a newly created collection for better performance" "These fields would be available in a newly created collection for better performance"
) )
logger.debug(f"Schema compatibility check passed for {self.namespace}") logger.debug(
f"[{self.workspace}] Schema compatibility check passed for {self.namespace}"
)
def _validate_collection_compatibility(self): def _validate_collection_compatibility(self):
"""Validate existing collection's dimension and schema compatibility""" """Validate existing collection's dimension and schema compatibility"""
try: try:
collection_info = self._client.describe_collection(self.namespace) collection_info = self._client.describe_collection(self.final_namespace)
# 1. Check vector dimension # 1. Check vector dimension
self._check_vector_dimension(collection_info) self._check_vector_dimension(collection_info)
@ -517,12 +554,12 @@ class MilvusVectorDBStorage(BaseVectorStorage):
self._check_schema_compatibility(collection_info) self._check_schema_compatibility(collection_info)
logger.info( logger.info(
f"VectorDB Collection '{self.namespace}' compatibility validation passed" f"[{self.workspace}] VectorDB Collection '{self.namespace}' compatibility validation passed"
) )
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Collection compatibility validation failed for {self.namespace}: {e}" f"[{self.workspace}] Collection compatibility validation failed for {self.namespace}: {e}"
) )
raise raise
@ -530,17 +567,21 @@ class MilvusVectorDBStorage(BaseVectorStorage):
"""Ensure the collection is loaded into memory for search operations""" """Ensure the collection is loaded into memory for search operations"""
try: try:
# Check if collection exists first # Check if collection exists first
if not self._client.has_collection(self.namespace): if not self._client.has_collection(self.final_namespace):
logger.error(f"Collection {self.namespace} does not exist") logger.error(
raise ValueError(f"Collection {self.namespace} does not exist") f"[{self.workspace}] Collection {self.namespace} does not exist"
)
raise ValueError(f"Collection {self.final_namespace} does not exist")
# Load the collection if it's not already loaded # Load the collection if it's not already loaded
# In Milvus, collections need to be loaded before they can be searched # In Milvus, collections need to be loaded before they can be searched
self._client.load_collection(self.namespace) self._client.load_collection(self.final_namespace)
# logger.debug(f"Collection {self.namespace} loaded successfully") # logger.debug(f"[{self.workspace}] Collection {self.namespace} loaded successfully")
except Exception as e: except Exception as e:
logger.error(f"Failed to load collection {self.namespace}: {e}") logger.error(
f"[{self.workspace}] Failed to load collection {self.namespace}: {e}"
)
raise raise
def _create_collection_if_not_exist(self): def _create_collection_if_not_exist(self):
@ -550,41 +591,45 @@ class MilvusVectorDBStorage(BaseVectorStorage):
# First, list all collections to see what actually exists # First, list all collections to see what actually exists
try: try:
all_collections = self._client.list_collections() all_collections = self._client.list_collections()
logger.debug(f"All collections in database: {all_collections}") logger.debug(
f"[{self.workspace}] All collections in database: {all_collections}"
)
except Exception as list_error: except Exception as list_error:
logger.warning(f"Could not list collections: {list_error}") logger.warning(
f"[{self.workspace}] Could not list collections: {list_error}"
)
all_collections = [] all_collections = []
# Check if our specific collection exists # Check if our specific collection exists
collection_exists = self._client.has_collection(self.namespace) collection_exists = self._client.has_collection(self.final_namespace)
logger.info( logger.info(
f"VectorDB collection '{self.namespace}' exists check: {collection_exists}" f"[{self.workspace}] VectorDB collection '{self.namespace}' exists check: {collection_exists}"
) )
if collection_exists: if collection_exists:
# Double-check by trying to describe the collection # Double-check by trying to describe the collection
try: try:
self._client.describe_collection(self.namespace) self._client.describe_collection(self.final_namespace)
self._validate_collection_compatibility() self._validate_collection_compatibility()
# Ensure the collection is loaded after validation # Ensure the collection is loaded after validation
self._ensure_collection_loaded() self._ensure_collection_loaded()
return return
except Exception as describe_error: except Exception as describe_error:
logger.warning( logger.warning(
f"Collection '{self.namespace}' exists but cannot be described: {describe_error}" f"[{self.workspace}] Collection '{self.namespace}' exists but cannot be described: {describe_error}"
) )
logger.info( logger.info(
"Treating as if collection doesn't exist and creating new one..." f"[{self.workspace}] Treating as if collection doesn't exist and creating new one..."
) )
# Fall through to creation logic # Fall through to creation logic
# Collection doesn't exist, create new collection # Collection doesn't exist, create new collection
logger.info(f"Creating new collection: {self.namespace}") logger.info(f"[{self.workspace}] Creating new collection: {self.namespace}")
schema = self._create_schema_for_namespace() schema = self._create_schema_for_namespace()
# Create collection with schema only first # Create collection with schema only first
self._client.create_collection( self._client.create_collection(
collection_name=self.namespace, schema=schema collection_name=self.final_namespace, schema=schema
) )
# Then create indexes # Then create indexes
@ -593,43 +638,49 @@ class MilvusVectorDBStorage(BaseVectorStorage):
# Load the newly created collection # Load the newly created collection
self._ensure_collection_loaded() self._ensure_collection_loaded()
logger.info(f"Successfully created Milvus collection: {self.namespace}") logger.info(
f"[{self.workspace}] Successfully created Milvus collection: {self.namespace}"
)
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error in _create_collection_if_not_exist for {self.namespace}: {e}" f"[{self.workspace}] Error in _create_collection_if_not_exist for {self.namespace}: {e}"
) )
# If there's any error, try to force create the collection # If there's any error, try to force create the collection
logger.info(f"Attempting to force create collection {self.namespace}...") logger.info(
f"[{self.workspace}] Attempting to force create collection {self.namespace}..."
)
try: try:
# Try to drop the collection first if it exists in a bad state # Try to drop the collection first if it exists in a bad state
try: try:
if self._client.has_collection(self.namespace): if self._client.has_collection(self.final_namespace):
logger.info( logger.info(
f"Dropping potentially corrupted collection {self.namespace}" f"[{self.workspace}] Dropping potentially corrupted collection {self.namespace}"
) )
self._client.drop_collection(self.namespace) self._client.drop_collection(self.final_namespace)
except Exception as drop_error: except Exception as drop_error:
logger.warning( logger.warning(
f"Could not drop collection {self.namespace}: {drop_error}" f"[{self.workspace}] Could not drop collection {self.namespace}: {drop_error}"
) )
# Create fresh collection # Create fresh collection
schema = self._create_schema_for_namespace() schema = self._create_schema_for_namespace()
self._client.create_collection( self._client.create_collection(
collection_name=self.namespace, schema=schema collection_name=self.final_namespace, schema=schema
) )
self._create_indexes_after_collection() self._create_indexes_after_collection()
# Load the newly created collection # Load the newly created collection
self._ensure_collection_loaded() self._ensure_collection_loaded()
logger.info(f"Successfully force-created collection {self.namespace}") logger.info(
f"[{self.workspace}] Successfully force-created collection {self.namespace}"
)
except Exception as create_error: except Exception as create_error:
logger.error( logger.error(
f"Failed to force-create collection {self.namespace}: {create_error}" f"[{self.workspace}] Failed to force-create collection {self.namespace}: {create_error}"
) )
raise raise
@ -651,11 +702,18 @@ class MilvusVectorDBStorage(BaseVectorStorage):
f"Using passed workspace parameter: '{effective_workspace}'" f"Using passed workspace parameter: '{effective_workspace}'"
) )
# Build namespace with workspace prefix for data isolation # Build final_namespace with workspace prefix for data isolation
# Keep original namespace unchanged for type detection logic
if effective_workspace: if effective_workspace:
self.namespace = f"{effective_workspace}_{self.namespace}" self.final_namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") logger.debug(
# When workspace is empty, keep the original namespace unchanged f"Final namespace with workspace prefix: '{self.final_namespace}'"
)
else:
# When workspace is empty, final_namespace equals original namespace
self.final_namespace = self.namespace
logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
self.workspace = "_"
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold") cosine_threshold = kwargs.get("cosine_better_than_threshold")
@ -699,7 +757,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
self._create_collection_if_not_exist() self._create_collection_if_not_exist()
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"Inserting {len(data)} to {self.namespace}") logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
if not data: if not data:
return return
@ -730,7 +788,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
embeddings = np.concatenate(embeddings_list) embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data): for i, d in enumerate(list_data):
d["vector"] = embeddings[i] d["vector"] = embeddings[i]
results = self._client.upsert(collection_name=self.namespace, data=list_data) results = self._client.upsert(
collection_name=self.final_namespace, data=list_data
)
return results return results
async def query( async def query(
@ -747,7 +807,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
output_fields = list(self.meta_fields) output_fields = list(self.meta_fields)
results = self._client.search( results = self._client.search(
collection_name=self.namespace, collection_name=self.final_namespace,
data=embedding, data=embedding,
limit=top_k, limit=top_k,
output_fields=output_fields, output_fields=output_fields,
@ -780,21 +840,25 @@ class MilvusVectorDBStorage(BaseVectorStorage):
# Compute entity ID from name # Compute entity ID from name
entity_id = compute_mdhash_id(entity_name, prefix="ent-") entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug( logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}" f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {entity_id}"
) )
# Delete the entity from Milvus collection # Delete the entity from Milvus collection
result = self._client.delete( result = self._client.delete(
collection_name=self.namespace, pks=[entity_id] collection_name=self.final_namespace, pks=[entity_id]
) )
if result and result.get("delete_count", 0) > 0: if result and result.get("delete_count", 0) > 0:
logger.debug(f"Successfully deleted entity {entity_name}") logger.debug(
f"[{self.workspace}] Successfully deleted entity {entity_name}"
)
else: else:
logger.debug(f"Entity {entity_name} not found in storage") logger.debug(
f"[{self.workspace}] Entity {entity_name} not found in storage"
)
except Exception as e: except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}") logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None: async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations associated with an entity """Delete all relations associated with an entity
@ -811,31 +875,35 @@ class MilvusVectorDBStorage(BaseVectorStorage):
# Find all relations involving this entity # Find all relations involving this entity
results = self._client.query( results = self._client.query(
collection_name=self.namespace, filter=expr, output_fields=["id"] collection_name=self.final_namespace, filter=expr, output_fields=["id"]
) )
if not results or len(results) == 0: if not results or len(results) == 0:
logger.debug(f"No relations found for entity {entity_name}") logger.debug(
f"[{self.workspace}] No relations found for entity {entity_name}"
)
return return
# Extract IDs of relations to delete # Extract IDs of relations to delete
relation_ids = [item["id"] for item in results] relation_ids = [item["id"] for item in results]
logger.debug( logger.debug(
f"Found {len(relation_ids)} relations for entity {entity_name}" f"[{self.workspace}] Found {len(relation_ids)} relations for entity {entity_name}"
) )
# Delete the relations # Delete the relations
if relation_ids: if relation_ids:
delete_result = self._client.delete( delete_result = self._client.delete(
collection_name=self.namespace, pks=relation_ids collection_name=self.final_namespace, pks=relation_ids
) )
logger.debug( logger.debug(
f"Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}" f"[{self.workspace}] Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}"
) )
except Exception as e: except Exception as e:
logger.error(f"Error deleting relations for {entity_name}: {e}") logger.error(
f"[{self.workspace}] Error deleting relations for {entity_name}: {e}"
)
async def delete(self, ids: list[str]) -> None: async def delete(self, ids: list[str]) -> None:
"""Delete vectors with specified IDs """Delete vectors with specified IDs
@ -848,17 +916,21 @@ class MilvusVectorDBStorage(BaseVectorStorage):
self._ensure_collection_loaded() self._ensure_collection_loaded()
# Delete vectors by IDs # Delete vectors by IDs
result = self._client.delete(collection_name=self.namespace, pks=ids) result = self._client.delete(collection_name=self.final_namespace, pks=ids)
if result and result.get("delete_count", 0) > 0: if result and result.get("delete_count", 0) > 0:
logger.debug( logger.debug(
f"Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}" f"[{self.workspace}] Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}"
) )
else: else:
logger.debug(f"No vectors were deleted from {self.namespace}") logger.debug(
f"[{self.workspace}] No vectors were deleted from {self.namespace}"
)
except Exception as e: except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}") logger.error(
f"[{self.workspace}] Error while deleting vectors from {self.namespace}: {e}"
)
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get vector data by its ID """Get vector data by its ID
@ -878,7 +950,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
# Query Milvus for a specific ID # Query Milvus for a specific ID
result = self._client.query( result = self._client.query(
collection_name=self.namespace, collection_name=self.final_namespace,
filter=f'id == "{id}"', filter=f'id == "{id}"',
output_fields=output_fields, output_fields=output_fields,
) )
@ -888,7 +960,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
return result[0] return result[0]
except Exception as e: except Exception as e:
logger.error(f"Error retrieving vector data for ID {id}: {e}") logger.error(
f"[{self.workspace}] Error retrieving vector data for ID {id}: {e}"
)
return None return None
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@ -916,14 +990,16 @@ class MilvusVectorDBStorage(BaseVectorStorage):
# Query Milvus with the filter # Query Milvus with the filter
result = self._client.query( result = self._client.query(
collection_name=self.namespace, collection_name=self.final_namespace,
filter=filter_expr, filter=filter_expr,
output_fields=output_fields, output_fields=output_fields,
) )
return result or [] return result or []
except Exception as e: except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}") logger.error(
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
)
return [] return []
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
@ -938,16 +1014,18 @@ class MilvusVectorDBStorage(BaseVectorStorage):
""" """
try: try:
# Drop the collection and recreate it # Drop the collection and recreate it
if self._client.has_collection(self.namespace): if self._client.has_collection(self.final_namespace):
self._client.drop_collection(self.namespace) self._client.drop_collection(self.final_namespace)
# Recreate the collection # Recreate the collection
self._create_collection_if_not_exist() self._create_collection_if_not_exist()
logger.info( logger.info(
f"Process {os.getpid()} drop Milvus collection {self.namespace}" f"[{self.workspace}] Process {os.getpid()} drop Milvus collection {self.namespace}"
) )
return {"status": "success", "message": "data dropped"} return {"status": "success", "message": "data dropped"}
except Exception as e: except Exception as e:
logger.error(f"Error dropping Milvus collection {self.namespace}: {e}") logger.error(
f"[{self.workspace}] Error dropping Milvus collection {self.namespace}: {e}"
)
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}

View file

@ -107,19 +107,30 @@ class MongoKVStorage(BaseKVStorage):
f"Using passed workspace parameter: '{effective_workspace}'" f"Using passed workspace parameter: '{effective_workspace}'"
) )
# Build namespace with workspace prefix for data isolation # Build final_namespace with workspace prefix for data isolation
# Keep original namespace unchanged for type detection logic
if effective_workspace: if effective_workspace:
self.namespace = f"{effective_workspace}_{self.namespace}" self.final_namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") logger.debug(
# When workspace is empty, keep the original namespace unchanged f"Final namespace with workspace prefix: '{self.final_namespace}'"
)
else:
# When workspace is empty, final_namespace equals original namespace
self.final_namespace = self.namespace
self.workspace = "_"
logger.debug(
f"[{self.workspace}] Final namespace (no workspace): '{self.namespace}'"
)
self._collection_name = self.namespace self._collection_name = self.final_namespace
async def initialize(self): async def initialize(self):
if self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name) self._data = await get_or_create_collection(self.db, self._collection_name)
logger.debug(f"Use MongoDB as KV {self._collection_name}") logger.debug(
f"[{self.workspace}] Use MongoDB as KV {self._collection_name}"
)
async def finalize(self): async def finalize(self):
if self.db is not None: if self.db is not None:
@ -167,7 +178,7 @@ class MongoKVStorage(BaseKVStorage):
return result return result
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"Inserting {len(data)} to {self.namespace}") logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
if not data: if not data:
return return
@ -227,10 +238,12 @@ class MongoKVStorage(BaseKVStorage):
try: try:
result = await self._data.delete_many({"_id": {"$in": ids}}) result = await self._data.delete_many({"_id": {"$in": ids}})
logger.info( logger.info(
f"Deleted {result.deleted_count} documents from {self.namespace}" f"[{self.workspace}] Deleted {result.deleted_count} documents from {self.namespace}"
) )
except PyMongoError as e: except PyMongoError as e:
logger.error(f"Error deleting documents from {self.namespace}: {e}") logger.error(
f"[{self.workspace}] Error deleting documents from {self.namespace}: {e}"
)
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
"""Drop the storage by removing all documents in the collection. """Drop the storage by removing all documents in the collection.
@ -243,14 +256,16 @@ class MongoKVStorage(BaseKVStorage):
deleted_count = result.deleted_count deleted_count = result.deleted_count
logger.info( logger.info(
f"Dropped {deleted_count} documents from doc status {self._collection_name}" f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}"
) )
return { return {
"status": "success", "status": "success",
"message": f"{deleted_count} documents dropped", "message": f"{deleted_count} documents dropped",
} }
except PyMongoError as e: except PyMongoError as e:
logger.error(f"Error dropping doc status {self._collection_name}: {e}") logger.error(
f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}"
)
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
@ -287,13 +302,20 @@ class MongoDocStatusStorage(DocStatusStorage):
f"Using passed workspace parameter: '{effective_workspace}'" f"Using passed workspace parameter: '{effective_workspace}'"
) )
# Build namespace with workspace prefix for data isolation # Build final_namespace with workspace prefix for data isolation
# Keep original namespace unchanged for type detection logic
if effective_workspace: if effective_workspace:
self.namespace = f"{effective_workspace}_{self.namespace}" self.final_namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") logger.debug(
# When workspace is empty, keep the original namespace unchanged f"Final namespace with workspace prefix: '{self.final_namespace}'"
)
else:
# When workspace is empty, final_namespace equals original namespace
self.final_namespace = self.namespace
self.workspace = "_"
logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
self._collection_name = self.namespace self._collection_name = self.final_namespace
async def initialize(self): async def initialize(self):
if self.db is None: if self.db is None:
@ -306,7 +328,9 @@ class MongoDocStatusStorage(DocStatusStorage):
# Create pagination indexes for better query performance # Create pagination indexes for better query performance
await self.create_pagination_indexes_if_not_exists() await self.create_pagination_indexes_if_not_exists()
logger.debug(f"Use MongoDB as DocStatus {self._collection_name}") logger.debug(
f"[{self.workspace}] Use MongoDB as DocStatus {self._collection_name}"
)
async def finalize(self): async def finalize(self):
if self.db is not None: if self.db is not None:
@ -327,7 +351,7 @@ class MongoDocStatusStorage(DocStatusStorage):
return data - existing_ids return data - existing_ids
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"Inserting {len(data)} to {self.namespace}") logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
if not data: if not data:
return return
update_tasks: list[Any] = [] update_tasks: list[Any] = []
@ -376,7 +400,9 @@ class MongoDocStatusStorage(DocStatusStorage):
data["error_msg"] = None data["error_msg"] = None
processed_result[doc["_id"]] = DocProcessingStatus(**data) processed_result[doc["_id"]] = DocProcessingStatus(**data)
except KeyError as e: except KeyError as e:
logger.error(f"Missing required field for document {doc['_id']}: {e}") logger.error(
f"[{self.workspace}] Missing required field for document {doc['_id']}: {e}"
)
continue continue
return processed_result return processed_result
@ -405,7 +431,9 @@ class MongoDocStatusStorage(DocStatusStorage):
data["error_msg"] = None data["error_msg"] = None
processed_result[doc["_id"]] = DocProcessingStatus(**data) processed_result[doc["_id"]] = DocProcessingStatus(**data)
except KeyError as e: except KeyError as e:
logger.error(f"Missing required field for document {doc['_id']}: {e}") logger.error(
f"[{self.workspace}] Missing required field for document {doc['_id']}: {e}"
)
continue continue
return processed_result return processed_result
@ -424,14 +452,16 @@ class MongoDocStatusStorage(DocStatusStorage):
deleted_count = result.deleted_count deleted_count = result.deleted_count
logger.info( logger.info(
f"Dropped {deleted_count} documents from doc status {self._collection_name}" f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}"
) )
return { return {
"status": "success", "status": "success",
"message": f"{deleted_count} documents dropped", "message": f"{deleted_count} documents dropped",
} }
except PyMongoError as e: except PyMongoError as e:
logger.error(f"Error dropping doc status {self._collection_name}: {e}") logger.error(
f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}"
)
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
async def delete(self, ids: list[str]) -> None: async def delete(self, ids: list[str]) -> None:
@ -450,16 +480,16 @@ class MongoDocStatusStorage(DocStatusStorage):
if not track_id_index_exists: if not track_id_index_exists:
await self._data.create_index("track_id") await self._data.create_index("track_id")
logger.info( logger.info(
f"Created track_id index for collection {self._collection_name}" f"[{self.workspace}] Created track_id index for collection {self._collection_name}"
) )
else: else:
logger.debug( logger.debug(
f"track_id index already exists for collection {self._collection_name}" f"[{self.workspace}] track_id index already exists for collection {self._collection_name}"
) )
except PyMongoError as e: except PyMongoError as e:
logger.error( logger.error(
f"Error creating track_id index for {self._collection_name}: {e}" f"[{self.workspace}] Error creating track_id index for {self._collection_name}: {e}"
) )
async def create_pagination_indexes_if_not_exists(self): async def create_pagination_indexes_if_not_exists(self):
@ -492,16 +522,16 @@ class MongoDocStatusStorage(DocStatusStorage):
if index_name not in existing_index_names: if index_name not in existing_index_names:
await self._data.create_index(index_info["keys"], name=index_name) await self._data.create_index(index_info["keys"], name=index_name)
logger.info( logger.info(
f"Created pagination index '{index_name}' for collection {self._collection_name}" f"[{self.workspace}] Created pagination index '{index_name}' for collection {self._collection_name}"
) )
else: else:
logger.debug( logger.debug(
f"Pagination index '{index_name}' already exists for collection {self._collection_name}" f"[{self.workspace}] Pagination index '{index_name}' already exists for collection {self._collection_name}"
) )
except PyMongoError as e: except PyMongoError as e:
logger.error( logger.error(
f"Error creating pagination indexes for {self._collection_name}: {e}" f"[{self.workspace}] Error creating pagination indexes for {self._collection_name}: {e}"
) )
async def get_docs_paginated( async def get_docs_paginated(
@ -586,7 +616,9 @@ class MongoDocStatusStorage(DocStatusStorage):
doc_status = DocProcessingStatus(**data) doc_status = DocProcessingStatus(**data)
documents.append((doc_id, doc_status)) documents.append((doc_id, doc_status))
except KeyError as e: except KeyError as e:
logger.error(f"Missing required field for document {doc['_id']}: {e}") logger.error(
f"[{self.workspace}] Missing required field for document {doc['_id']}: {e}"
)
continue continue
return documents, total_count return documents, total_count
@ -650,13 +682,20 @@ class MongoGraphStorage(BaseGraphStorage):
f"Using passed workspace parameter: '{effective_workspace}'" f"Using passed workspace parameter: '{effective_workspace}'"
) )
# Build namespace with workspace prefix for data isolation # Build final_namespace with workspace prefix for data isolation
# Keep original namespace unchanged for type detection logic
if effective_workspace: if effective_workspace:
self.namespace = f"{effective_workspace}_{self.namespace}" self.final_namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") logger.debug(
# When workspace is empty, keep the original namespace unchanged f"Final namespace with workspace prefix: '{self.final_namespace}'"
)
else:
# When workspace is empty, final_namespace equals original namespace
self.final_namespace = self.namespace
self.workspace = "_"
logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
self._collection_name = self.namespace self._collection_name = self.final_namespace
self._edge_collection_name = f"{self._collection_name}_edges" self._edge_collection_name = f"{self._collection_name}_edges"
async def initialize(self): async def initialize(self):
@ -668,7 +707,9 @@ class MongoGraphStorage(BaseGraphStorage):
self.edge_collection = await get_or_create_collection( self.edge_collection = await get_or_create_collection(
self.db, self._edge_collection_name self.db, self._edge_collection_name
) )
logger.debug(f"Use MongoDB as KG {self._collection_name}") logger.debug(
f"[{self.workspace}] Use MongoDB as KG {self._collection_name}"
)
async def finalize(self): async def finalize(self):
if self.db is not None: if self.db is not None:
@ -1248,7 +1289,9 @@ class MongoGraphStorage(BaseGraphStorage):
# Verify if starting node exists # Verify if starting node exists
start_node = await self.collection.find_one({"_id": node_label}) start_node = await self.collection.find_one({"_id": node_label})
if not start_node: if not start_node:
logger.warning(f"Starting node with label {node_label} does not exist!") logger.warning(
f"[{self.workspace}] Starting node with label {node_label} does not exist!"
)
return result return result
seen_nodes.add(node_label) seen_nodes.add(node_label)
@ -1407,14 +1450,14 @@ class MongoGraphStorage(BaseGraphStorage):
duration = time.perf_counter() - start duration = time.perf_counter() - start
logger.info( logger.info(
f"Subgraph query successful in {duration:.4f} seconds | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}" f"[{self.workspace}] Subgraph query successful in {duration:.4f} seconds | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
) )
except PyMongoError as e: except PyMongoError as e:
# Handle memory limit errors specifically # Handle memory limit errors specifically
if "memory limit" in str(e).lower() or "sort exceeded" in str(e).lower(): if "memory limit" in str(e).lower() or "sort exceeded" in str(e).lower():
logger.warning( logger.warning(
f"MongoDB memory limit exceeded, falling back to simple query: {str(e)}" f"[{self.workspace}] MongoDB memory limit exceeded, falling back to simple query: {str(e)}"
) )
# Fallback to a simple query without complex aggregation # Fallback to a simple query without complex aggregation
try: try:
@ -1425,12 +1468,14 @@ class MongoGraphStorage(BaseGraphStorage):
) )
result.is_truncated = True result.is_truncated = True
logger.info( logger.info(
f"Fallback query completed | Node count: {len(result.nodes)}" f"[{self.workspace}] Fallback query completed | Node count: {len(result.nodes)}"
) )
except PyMongoError as fallback_error: except PyMongoError as fallback_error:
logger.error(f"Fallback query also failed: {str(fallback_error)}") logger.error(
f"[{self.workspace}] Fallback query also failed: {str(fallback_error)}"
)
else: else:
logger.error(f"MongoDB query failed: {str(e)}") logger.error(f"[{self.workspace}] MongoDB query failed: {str(e)}")
return result return result
@ -1444,7 +1489,7 @@ class MongoGraphStorage(BaseGraphStorage):
Args: Args:
nodes: List of node IDs to be deleted nodes: List of node IDs to be deleted
""" """
logger.info(f"Deleting {len(nodes)} nodes") logger.info(f"[{self.workspace}] Deleting {len(nodes)} nodes")
if not nodes: if not nodes:
return return
@ -1461,7 +1506,7 @@ class MongoGraphStorage(BaseGraphStorage):
# 2. Delete the node documents # 2. Delete the node documents
await self.collection.delete_many({"_id": {"$in": nodes}}) await self.collection.delete_many({"_id": {"$in": nodes}})
logger.debug(f"Successfully deleted nodes: {nodes}") logger.debug(f"[{self.workspace}] Successfully deleted nodes: {nodes}")
async def remove_edges(self, edges: list[tuple[str, str]]) -> None: async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
"""Delete multiple edges """Delete multiple edges
@ -1469,7 +1514,7 @@ class MongoGraphStorage(BaseGraphStorage):
Args: Args:
edges: List of edges to be deleted, each edge is a (source, target) tuple edges: List of edges to be deleted, each edge is a (source, target) tuple
""" """
logger.info(f"Deleting {len(edges)} edges") logger.info(f"[{self.workspace}] Deleting {len(edges)} edges")
if not edges: if not edges:
return return
@ -1484,7 +1529,7 @@ class MongoGraphStorage(BaseGraphStorage):
await self.edge_collection.delete_many({"$or": all_edge_pairs}) await self.edge_collection.delete_many({"$or": all_edge_pairs})
logger.debug(f"Successfully deleted edges: {edges}") logger.debug(f"[{self.workspace}] Successfully deleted edges: {edges}")
async def get_all_nodes(self) -> list[dict]: async def get_all_nodes(self) -> list[dict]:
"""Get all nodes in the graph. """Get all nodes in the graph.
@ -1527,13 +1572,13 @@ class MongoGraphStorage(BaseGraphStorage):
deleted_count = result.deleted_count deleted_count = result.deleted_count
logger.info( logger.info(
f"Dropped {deleted_count} documents from graph {self._collection_name}" f"[{self.workspace}] Dropped {deleted_count} documents from graph {self._collection_name}"
) )
result = await self.edge_collection.delete_many({}) result = await self.edge_collection.delete_many({})
edge_count = result.deleted_count edge_count = result.deleted_count
logger.info( logger.info(
f"Dropped {edge_count} edges from graph {self._edge_collection_name}" f"[{self.workspace}] Dropped {edge_count} edges from graph {self._edge_collection_name}"
) )
return { return {
@ -1541,7 +1586,9 @@ class MongoGraphStorage(BaseGraphStorage):
"message": f"{deleted_count} documents and {edge_count} edges dropped", "message": f"{deleted_count} documents and {edge_count} edges dropped",
} }
except PyMongoError as e: except PyMongoError as e:
logger.error(f"Error dropping graph {self._collection_name}: {e}") logger.error(
f"[{self.workspace}] Error dropping graph {self._collection_name}: {e}"
)
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
@ -1582,16 +1629,23 @@ class MongoVectorDBStorage(BaseVectorStorage):
f"Using passed workspace parameter: '{effective_workspace}'" f"Using passed workspace parameter: '{effective_workspace}'"
) )
# Build namespace with workspace prefix for data isolation # Build final_namespace with workspace prefix for data isolation
# Keep original namespace unchanged for type detection logic
if effective_workspace: if effective_workspace:
self.namespace = f"{effective_workspace}_{self.namespace}" self.final_namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") logger.debug(
# When workspace is empty, keep the original namespace unchanged f"Final namespace with workspace prefix: '{self.final_namespace}'"
)
else:
# When workspace is empty, final_namespace equals original namespace
self.final_namespace = self.namespace
self.workspace = "_"
logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
# Set index name based on workspace for backward compatibility # Set index name based on workspace for backward compatibility
if effective_workspace: if effective_workspace:
# Use collection-specific index name for workspaced collections to avoid conflicts # Use collection-specific index name for workspaced collections to avoid conflicts
self._index_name = f"vector_knn_index_{self.namespace}" self._index_name = f"vector_knn_index_{self.final_namespace}"
else: else:
# Keep original index name for backward compatibility with existing deployments # Keep original index name for backward compatibility with existing deployments
self._index_name = "vector_knn_index" self._index_name = "vector_knn_index"
@ -1603,7 +1657,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
) )
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
self._collection_name = self.namespace self._collection_name = self.final_namespace
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self): async def initialize(self):
@ -1614,7 +1668,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
# Ensure vector index exists # Ensure vector index exists
await self.create_vector_index_if_not_exists() await self.create_vector_index_if_not_exists()
logger.debug(f"Use MongoDB as VDB {self._collection_name}") logger.debug(
f"[{self.workspace}] Use MongoDB as VDB {self._collection_name}"
)
async def finalize(self): async def finalize(self):
if self.db is not None: if self.db is not None:
@ -1629,7 +1685,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
indexes = await indexes_cursor.to_list(length=None) indexes = await indexes_cursor.to_list(length=None)
for index in indexes: for index in indexes:
if index["name"] == self._index_name: if index["name"] == self._index_name:
logger.info(f"vector index {self._index_name} already exist") logger.info(
f"[{self.workspace}] vector index {self._index_name} already exist"
)
return return
search_index_model = SearchIndexModel( search_index_model = SearchIndexModel(
@ -1648,17 +1706,19 @@ class MongoVectorDBStorage(BaseVectorStorage):
) )
await self._data.create_search_index(search_index_model) await self._data.create_search_index(search_index_model)
logger.info(f"Vector index {self._index_name} created successfully.") logger.info(
f"[{self.workspace}] Vector index {self._index_name} created successfully."
)
except PyMongoError as e: except PyMongoError as e:
error_msg = f"Error creating vector index {self._index_name}: {e}" error_msg = f"[{self.workspace}] Error creating vector index {self._index_name}: {e}"
logger.error(error_msg) logger.error(error_msg)
raise SystemExit( raise SystemExit(
f"Failed to create MongoDB vector index. Program cannot continue. {error_msg}" f"Failed to create MongoDB vector index. Program cannot continue. {error_msg}"
) )
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"Inserting {len(data)} to {self.namespace}") logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
if not data: if not data:
return return
@ -1747,7 +1807,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
Args: Args:
ids: List of vector IDs to be deleted ids: List of vector IDs to be deleted
""" """
logger.debug(f"Deleting {len(ids)} vectors from {self.namespace}") logger.debug(
f"[{self.workspace}] Deleting {len(ids)} vectors from {self.namespace}"
)
if not ids: if not ids:
return return
@ -1758,11 +1820,11 @@ class MongoVectorDBStorage(BaseVectorStorage):
try: try:
result = await self._data.delete_many({"_id": {"$in": ids}}) result = await self._data.delete_many({"_id": {"$in": ids}})
logger.debug( logger.debug(
f"Successfully deleted {result.deleted_count} vectors from {self.namespace}" f"[{self.workspace}] Successfully deleted {result.deleted_count} vectors from {self.namespace}"
) )
except PyMongoError as e: except PyMongoError as e:
logger.error( logger.error(
f"Error while deleting vectors from {self.namespace}: {str(e)}" f"[{self.workspace}] Error while deleting vectors from {self.namespace}: {str(e)}"
) )
async def delete_entity(self, entity_name: str) -> None: async def delete_entity(self, entity_name: str) -> None:
@ -1774,16 +1836,22 @@ class MongoVectorDBStorage(BaseVectorStorage):
try: try:
entity_id = compute_mdhash_id(entity_name, prefix="ent-") entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug( logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}" f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {entity_id}"
) )
result = await self._data.delete_one({"_id": entity_id}) result = await self._data.delete_one({"_id": entity_id})
if result.deleted_count > 0: if result.deleted_count > 0:
logger.debug(f"Successfully deleted entity {entity_name}") logger.debug(
f"[{self.workspace}] Successfully deleted entity {entity_name}"
)
else: else:
logger.debug(f"Entity {entity_name} not found in storage") logger.debug(
f"[{self.workspace}] Entity {entity_name} not found in storage"
)
except PyMongoError as e: except PyMongoError as e:
logger.error(f"Error deleting entity {entity_name}: {str(e)}") logger.error(
f"[{self.workspace}] Error deleting entity {entity_name}: {str(e)}"
)
async def delete_entity_relation(self, entity_name: str) -> None: async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations associated with an entity """Delete all relations associated with an entity
@ -1799,23 +1867,31 @@ class MongoVectorDBStorage(BaseVectorStorage):
relations = await relations_cursor.to_list(length=None) relations = await relations_cursor.to_list(length=None)
if not relations: if not relations:
logger.debug(f"No relations found for entity {entity_name}") logger.debug(
f"[{self.workspace}] No relations found for entity {entity_name}"
)
return return
# Extract IDs of relations to delete # Extract IDs of relations to delete
relation_ids = [relation["_id"] for relation in relations] relation_ids = [relation["_id"] for relation in relations]
logger.debug( logger.debug(
f"Found {len(relation_ids)} relations for entity {entity_name}" f"[{self.workspace}] Found {len(relation_ids)} relations for entity {entity_name}"
) )
# Delete the relations # Delete the relations
result = await self._data.delete_many({"_id": {"$in": relation_ids}}) result = await self._data.delete_many({"_id": {"$in": relation_ids}})
logger.debug(f"Deleted {result.deleted_count} relations for {entity_name}") logger.debug(
f"[{self.workspace}] Deleted {result.deleted_count} relations for {entity_name}"
)
except PyMongoError as e: except PyMongoError as e:
logger.error(f"Error deleting relations for {entity_name}: {str(e)}") logger.error(
f"[{self.workspace}] Error deleting relations for {entity_name}: {str(e)}"
)
except PyMongoError as e: except PyMongoError as e:
logger.error(f"Error searching by prefix in {self.namespace}: {str(e)}") logger.error(
f"[{self.workspace}] Error searching by prefix in {self.namespace}: {str(e)}"
)
return [] return []
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> dict[str, Any] | None:
@ -1838,7 +1914,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
return result_dict return result_dict
return None return None
except Exception as e: except Exception as e:
logger.error(f"Error retrieving vector data for ID {id}: {e}") logger.error(
f"[{self.workspace}] Error retrieving vector data for ID {id}: {e}"
)
return None return None
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@ -1868,7 +1946,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
return formatted_results return formatted_results
except Exception as e: except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}") logger.error(
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
)
return [] return []
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
@ -1886,14 +1966,16 @@ class MongoVectorDBStorage(BaseVectorStorage):
await self.create_vector_index_if_not_exists() await self.create_vector_index_if_not_exists()
logger.info( logger.info(
f"Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index" f"[{self.workspace}] Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index"
) )
return { return {
"status": "success", "status": "success",
"message": f"{deleted_count} documents dropped and vector index recreated", "message": f"{deleted_count} documents dropped and vector index recreated",
} }
except PyMongoError as e: except PyMongoError as e:
logger.error(f"Error dropping vector storage {self._collection_name}: {e}") logger.error(
f"[{self.workspace}] Error dropping vector storage {self._collection_name}: {e}"
)
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}

View file

@ -48,23 +48,26 @@ logging.getLogger("neo4j").setLevel(logging.ERROR)
@dataclass @dataclass
class Neo4JStorage(BaseGraphStorage): class Neo4JStorage(BaseGraphStorage):
def __init__(self, namespace, global_config, embedding_func, workspace=None): def __init__(self, namespace, global_config, embedding_func, workspace=None):
# Check NEO4J_WORKSPACE environment variable and override workspace if set # Read env and override the arg if present
neo4j_workspace = os.environ.get("NEO4J_WORKSPACE") neo4j_workspace = os.environ.get("NEO4J_WORKSPACE")
if neo4j_workspace and neo4j_workspace.strip(): if neo4j_workspace and neo4j_workspace.strip():
workspace = neo4j_workspace workspace = neo4j_workspace
# Default to 'base' when both arg and env are empty
if not workspace or not str(workspace).strip():
workspace = "base"
super().__init__( super().__init__(
namespace=namespace, namespace=namespace,
workspace=workspace or "", workspace=workspace,
global_config=global_config, global_config=global_config,
embedding_func=embedding_func, embedding_func=embedding_func,
) )
self._driver = None self._driver = None
def _get_workspace_label(self) -> str: def _get_workspace_label(self) -> str:
"""Get workspace label, return 'base' for compatibility when workspace is empty""" """Return workspace label (guaranteed non-empty during initialization)"""
workspace = getattr(self, "workspace", None) return self.workspace
return workspace if workspace else "base"
async def initialize(self): async def initialize(self):
URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None)) URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
@ -117,6 +120,7 @@ class Neo4JStorage(BaseGraphStorage):
DATABASE = os.environ.get( DATABASE = os.environ.get(
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace) "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
) )
"""The default value approach for the DATABASE is only intended to maintain compatibility with legacy practices."""
self._driver: AsyncDriver = AsyncGraphDatabase.driver( self._driver: AsyncDriver = AsyncGraphDatabase.driver(
URI, URI,
@ -140,20 +144,26 @@ class Neo4JStorage(BaseGraphStorage):
try: try:
result = await session.run("MATCH (n) RETURN n LIMIT 0") result = await session.run("MATCH (n) RETURN n LIMIT 0")
await result.consume() # Ensure result is consumed await result.consume() # Ensure result is consumed
logger.info(f"Connected to {database} at {URI}") logger.info(
f"[{self.workspace}] Connected to {database} at {URI}"
)
connected = True connected = True
except neo4jExceptions.ServiceUnavailable as e: except neo4jExceptions.ServiceUnavailable as e:
logger.error( logger.error(
f"{database} at {URI} is not available".capitalize() f"[{self.workspace}] "
+ f"{database} at {URI} is not available".capitalize()
) )
raise e raise e
except neo4jExceptions.AuthError as e: except neo4jExceptions.AuthError as e:
logger.error(f"Authentication failed for {database} at {URI}") logger.error(
f"[{self.workspace}] Authentication failed for {database} at {URI}"
)
raise e raise e
except neo4jExceptions.ClientError as e: except neo4jExceptions.ClientError as e:
if e.code == "Neo.ClientError.Database.DatabaseNotFound": if e.code == "Neo.ClientError.Database.DatabaseNotFound":
logger.info( logger.info(
f"{database} at {URI} not found. Try to create specified database.".capitalize() f"[{self.workspace}] "
+ f"{database} at {URI} not found. Try to create specified database.".capitalize()
) )
try: try:
async with self._driver.session() as session: async with self._driver.session() as session:
@ -161,7 +171,10 @@ class Neo4JStorage(BaseGraphStorage):
f"CREATE DATABASE `{database}` IF NOT EXISTS" f"CREATE DATABASE `{database}` IF NOT EXISTS"
) )
await result.consume() # Ensure result is consumed await result.consume() # Ensure result is consumed
logger.info(f"{database} at {URI} created".capitalize()) logger.info(
f"[{self.workspace}] "
+ f"{database} at {URI} created".capitalize()
)
connected = True connected = True
except ( except (
neo4jExceptions.ClientError, neo4jExceptions.ClientError,
@ -173,10 +186,12 @@ class Neo4JStorage(BaseGraphStorage):
) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"): ) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
if database is not None: if database is not None:
logger.warning( logger.warning(
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database." f"[{self.workspace}] This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
) )
if database is None: if database is None:
logger.error(f"Failed to create {database} at {URI}") logger.error(
f"[{self.workspace}] Failed to create {database} at {URI}"
)
raise e raise e
if connected: if connected:
@ -204,7 +219,7 @@ class Neo4JStorage(BaseGraphStorage):
) )
await result.consume() await result.consume()
logger.info( logger.info(
f"Created index for {workspace_label} nodes on entity_id in {database}" f"[{self.workspace}] Created index for {workspace_label} nodes on entity_id in {database}"
) )
except Exception: except Exception:
# Fallback if db.indexes() is not supported in this Neo4j version # Fallback if db.indexes() is not supported in this Neo4j version
@ -213,7 +228,9 @@ class Neo4JStorage(BaseGraphStorage):
) )
await result.consume() await result.consume()
except Exception as e: except Exception as e:
logger.warning(f"Failed to create index: {str(e)}") logger.warning(
f"[{self.workspace}] Failed to create index: {str(e)}"
)
break break
async def finalize(self): async def finalize(self):
@ -255,7 +272,9 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
return single_result["node_exists"] return single_result["node_exists"]
except Exception as e: except Exception as e:
logger.error(f"Error checking node existence for {node_id}: {str(e)}") logger.error(
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
)
await result.consume() # Ensure results are consumed even on error await result.consume() # Ensure results are consumed even on error
raise raise
@ -293,7 +312,7 @@ class Neo4JStorage(BaseGraphStorage):
return single_result["edgeExists"] return single_result["edgeExists"]
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
) )
await result.consume() # Ensure results are consumed even on error await result.consume() # Ensure results are consumed even on error
raise raise
@ -328,7 +347,7 @@ class Neo4JStorage(BaseGraphStorage):
if len(records) > 1: if len(records) > 1:
logger.warning( logger.warning(
f"Multiple nodes found with label '{node_id}'. Using first node." f"[{self.workspace}] Multiple nodes found with label '{node_id}'. Using first node."
) )
if records: if records:
node = records[0]["n"] node = records[0]["n"]
@ -346,7 +365,9 @@ class Neo4JStorage(BaseGraphStorage):
finally: finally:
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
except Exception as e: except Exception as e:
logger.error(f"Error getting node for {node_id}: {str(e)}") logger.error(
f"[{self.workspace}] Error getting node for {node_id}: {str(e)}"
)
raise raise
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
@ -415,18 +436,22 @@ class Neo4JStorage(BaseGraphStorage):
record = await result.single() record = await result.single()
if not record: if not record:
logger.warning(f"No node found with label '{node_id}'") logger.warning(
f"[{self.workspace}] No node found with label '{node_id}'"
)
return 0 return 0
degree = record["degree"] degree = record["degree"]
# logger.debug( # logger.debug(
# f"Neo4j query node degree for {node_id} return: {degree}" # f"[{self.workspace}] Neo4j query node degree for {node_id} return: {degree}"
# ) # )
return degree return degree
finally: finally:
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
except Exception as e: except Exception as e:
logger.error(f"Error getting node degree for {node_id}: {str(e)}") logger.error(
f"[{self.workspace}] Error getting node degree for {node_id}: {str(e)}"
)
raise raise
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
@ -459,10 +484,12 @@ class Neo4JStorage(BaseGraphStorage):
# For any node_id that did not return a record, set degree to 0. # For any node_id that did not return a record, set degree to 0.
for nid in node_ids: for nid in node_ids:
if nid not in degrees: if nid not in degrees:
logger.warning(f"No node found with label '{nid}'") logger.warning(
f"[{self.workspace}] No node found with label '{nid}'"
)
degrees[nid] = 0 degrees[nid] = 0
# logger.debug(f"Neo4j batch node degree query returned: {degrees}") # logger.debug(f"[{self.workspace}] Neo4j batch node degree query returned: {degrees}")
return degrees return degrees
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
@ -546,7 +573,7 @@ class Neo4JStorage(BaseGraphStorage):
if len(records) > 1: if len(records) > 1:
logger.warning( logger.warning(
f"Multiple edges found between '{source_node_id}' and '{target_node_id}'. Using first edge." f"[{self.workspace}] Multiple edges found between '{source_node_id}' and '{target_node_id}'. Using first edge."
) )
if records: if records:
try: try:
@ -563,7 +590,7 @@ class Neo4JStorage(BaseGraphStorage):
if key not in edge_result: if key not in edge_result:
edge_result[key] = default_value edge_result[key] = default_value
logger.warning( logger.warning(
f"Edge between {source_node_id} and {target_node_id} " f"[{self.workspace}] Edge between {source_node_id} and {target_node_id} "
f"missing {key}, using default: {default_value}" f"missing {key}, using default: {default_value}"
) )
@ -573,7 +600,7 @@ class Neo4JStorage(BaseGraphStorage):
return edge_result return edge_result
except (KeyError, TypeError, ValueError) as e: except (KeyError, TypeError, ValueError) as e:
logger.error( logger.error(
f"Error processing edge properties between {source_node_id} " f"[{self.workspace}] Error processing edge properties between {source_node_id} "
f"and {target_node_id}: {str(e)}" f"and {target_node_id}: {str(e)}"
) )
# Return default edge properties on error # Return default edge properties on error
@ -594,7 +621,7 @@ class Neo4JStorage(BaseGraphStorage):
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}" f"[{self.workspace}] Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
) )
raise raise
@ -701,12 +728,14 @@ class Neo4JStorage(BaseGraphStorage):
return edges return edges
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error getting edges for node {source_node_id}: {str(e)}" f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
) )
await results.consume() # Ensure results are consumed even on error await results.consume() # Ensure results are consumed even on error
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}") logger.error(
f"[{self.workspace}] Error in get_node_edges for {source_node_id}: {str(e)}"
)
raise raise
async def get_nodes_edges_batch( async def get_nodes_edges_batch(
@ -856,7 +885,7 @@ class Neo4JStorage(BaseGraphStorage):
await session.execute_write(execute_upsert) await session.execute_write(execute_upsert)
except Exception as e: except Exception as e:
logger.error(f"Error during upsert: {str(e)}") logger.error(f"[{self.workspace}] Error during upsert: {str(e)}")
raise raise
@retry( @retry(
@ -917,7 +946,7 @@ class Neo4JStorage(BaseGraphStorage):
await session.execute_write(execute_upsert) await session.execute_write(execute_upsert)
except Exception as e: except Exception as e:
logger.error(f"Error during edge upsert: {str(e)}") logger.error(f"[{self.workspace}] Error during edge upsert: {str(e)}")
raise raise
async def get_knowledge_graph( async def get_knowledge_graph(
@ -967,7 +996,7 @@ class Neo4JStorage(BaseGraphStorage):
if count_record and count_record["total"] > max_nodes: if count_record and count_record["total"] > max_nodes:
result.is_truncated = True result.is_truncated = True
logger.info( logger.info(
f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}" f"[{self.workspace}] Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
) )
finally: finally:
if count_result: if count_result:
@ -1034,7 +1063,9 @@ class Neo4JStorage(BaseGraphStorage):
# If no record found, return empty KnowledgeGraph # If no record found, return empty KnowledgeGraph
if not full_record: if not full_record:
logger.debug(f"No nodes found for entity_id: {node_label}") logger.debug(
f"[{self.workspace}] No nodes found for entity_id: {node_label}"
)
return result return result
# If record found, check node count # If record found, check node count
@ -1043,14 +1074,14 @@ class Neo4JStorage(BaseGraphStorage):
if total_nodes <= max_nodes: if total_nodes <= max_nodes:
# If node count is within limit, use full result directly # If node count is within limit, use full result directly
logger.debug( logger.debug(
f"Using full result with {total_nodes} nodes (no truncation needed)" f"[{self.workspace}] Using full result with {total_nodes} nodes (no truncation needed)"
) )
record = full_record record = full_record
else: else:
# If node count exceeds limit, set truncated flag and run limited query # If node count exceeds limit, set truncated flag and run limited query
result.is_truncated = True result.is_truncated = True
logger.info( logger.info(
f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}" f"[{self.workspace}] Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
) )
# Run limited query # Run limited query
@ -1122,19 +1153,19 @@ class Neo4JStorage(BaseGraphStorage):
seen_edges.add(edge_id) seen_edges.add(edge_id)
logger.info( logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" f"[{self.workspace}] Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
) )
except neo4jExceptions.ClientError as e: except neo4jExceptions.ClientError as e:
logger.warning(f"APOC plugin error: {str(e)}") logger.warning(f"[{self.workspace}] APOC plugin error: {str(e)}")
if node_label != "*": if node_label != "*":
logger.warning( logger.warning(
"Neo4j: falling back to basic Cypher recursive search..." f"[{self.workspace}] Neo4j: falling back to basic Cypher recursive search..."
) )
return await self._robust_fallback(node_label, max_depth, max_nodes) return await self._robust_fallback(node_label, max_depth, max_nodes)
else: else:
logger.warning( logger.warning(
"Neo4j: APOC plugin error with wildcard query, returning empty result" f"[{self.workspace}] Neo4j: APOC plugin error with wildcard query, returning empty result"
) )
return result return result
@ -1193,7 +1224,7 @@ class Neo4JStorage(BaseGraphStorage):
if current_depth > max_depth: if current_depth > max_depth:
logger.debug( logger.debug(
f"Skipping node at depth {current_depth} (max_depth: {max_depth})" f"[{self.workspace}] Skipping node at depth {current_depth} (max_depth: {max_depth})"
) )
continue continue
@ -1210,7 +1241,7 @@ class Neo4JStorage(BaseGraphStorage):
if len(visited_nodes) >= max_nodes: if len(visited_nodes) >= max_nodes:
result.is_truncated = True result.is_truncated = True
logger.info( logger.info(
f"Graph truncated: breadth-first search limited to: {max_nodes} nodes" f"[{self.workspace}] Graph truncated: breadth-first search limited to: {max_nodes} nodes"
) )
break break
@ -1281,20 +1312,20 @@ class Neo4JStorage(BaseGraphStorage):
# At max depth, we've already added the edge but we don't add the node # At max depth, we've already added the edge but we don't add the node
# This prevents adding nodes beyond max_depth to the result # This prevents adding nodes beyond max_depth to the result
logger.debug( logger.debug(
f"Node {target_id} beyond max depth {max_depth}, edge added but node not included" f"[{self.workspace}] Node {target_id} beyond max depth {max_depth}, edge added but node not included"
) )
else: else:
# If target node already exists in result, we don't need to add it again # If target node already exists in result, we don't need to add it again
logger.debug( logger.debug(
f"Node {target_id} already visited, edge added but node not queued" f"[{self.workspace}] Node {target_id} already visited, edge added but node not queued"
) )
else: else:
logger.warning( logger.warning(
f"Skipping edge {edge_id} due to missing entity_id on target node" f"[{self.workspace}] Skipping edge {edge_id} due to missing entity_id on target node"
) )
logger.info( logger.info(
f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" f"[{self.workspace}] BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
) )
return result return result
@ -1358,14 +1389,14 @@ class Neo4JStorage(BaseGraphStorage):
DETACH DELETE n DETACH DELETE n
""" """
result = await tx.run(query, entity_id=node_id) result = await tx.run(query, entity_id=node_id)
logger.debug(f"Deleted node with label '{node_id}'") logger.debug(f"[{self.workspace}] Deleted node with label '{node_id}'")
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
try: try:
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_delete) await session.execute_write(_do_delete)
except Exception as e: except Exception as e:
logger.error(f"Error during node deletion: {str(e)}") logger.error(f"[{self.workspace}] Error during node deletion: {str(e)}")
raise raise
@retry( @retry(
@ -1424,14 +1455,16 @@ class Neo4JStorage(BaseGraphStorage):
result = await tx.run( result = await tx.run(
query, source_entity_id=source, target_entity_id=target query, source_entity_id=source, target_entity_id=target
) )
logger.debug(f"Deleted edge from '{source}' to '{target}'") logger.debug(
f"[{self.workspace}] Deleted edge from '{source}' to '{target}'"
)
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
try: try:
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_delete_edge) await session.execute_write(_do_delete_edge)
except Exception as e: except Exception as e:
logger.error(f"Error during edge deletion: {str(e)}") logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}")
raise raise
async def get_all_nodes(self) -> list[dict]: async def get_all_nodes(self) -> list[dict]:
@ -1501,15 +1534,15 @@ class Neo4JStorage(BaseGraphStorage):
result = await session.run(query) result = await session.run(query)
await result.consume() # Ensure result is fully consumed await result.consume() # Ensure result is fully consumed
logger.info( # logger.debug(
f"Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}" # f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}"
) # )
return { return {
"status": "success", "status": "success",
"message": f"workspace '{workspace_label}' data dropped", "message": f"workspace '{workspace_label}' data dropped",
} }
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}" f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}

View file

@ -88,11 +88,18 @@ class QdrantVectorDBStorage(BaseVectorStorage):
f"Using passed workspace parameter: '{effective_workspace}'" f"Using passed workspace parameter: '{effective_workspace}'"
) )
# Build namespace with workspace prefix for data isolation # Build final_namespace with workspace prefix for data isolation
# Keep original namespace unchanged for type detection logic
if effective_workspace: if effective_workspace:
self.namespace = f"{effective_workspace}_{self.namespace}" self.final_namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") logger.debug(
# When workspace is empty, keep the original namespace unchanged f"Final namespace with workspace prefix: '{self.final_namespace}'"
)
else:
# When workspace is empty, final_namespace equals original namespace
self.final_namespace = self.namespace
self.workspace = "_"
logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold") cosine_threshold = kwargs.get("cosine_better_than_threshold")
@ -113,14 +120,14 @@ class QdrantVectorDBStorage(BaseVectorStorage):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
QdrantVectorDBStorage.create_collection_if_not_exist( QdrantVectorDBStorage.create_collection_if_not_exist(
self._client, self._client,
self.namespace, self.final_namespace,
vectors_config=models.VectorParams( vectors_config=models.VectorParams(
size=self.embedding_func.embedding_dim, distance=models.Distance.COSINE size=self.embedding_func.embedding_dim, distance=models.Distance.COSINE
), ),
) )
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"Inserting {len(data)} to {self.namespace}") logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
if not data: if not data:
return return
@ -158,7 +165,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
) )
results = self._client.upsert( results = self._client.upsert(
collection_name=self.namespace, points=list_points, wait=True collection_name=self.final_namespace, points=list_points, wait=True
) )
return results return results
@ -169,14 +176,14 @@ class QdrantVectorDBStorage(BaseVectorStorage):
[query], _priority=5 [query], _priority=5
) # higher priority for query ) # higher priority for query
results = self._client.search( results = self._client.search(
collection_name=self.namespace, collection_name=self.final_namespace,
query_vector=embedding[0], query_vector=embedding[0],
limit=top_k, limit=top_k,
with_payload=True, with_payload=True,
score_threshold=self.cosine_better_than_threshold, score_threshold=self.cosine_better_than_threshold,
) )
logger.debug(f"query result: {results}") logger.debug(f"[{self.workspace}] query result: {results}")
return [ return [
{ {
@ -202,17 +209,19 @@ class QdrantVectorDBStorage(BaseVectorStorage):
qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids] qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]
# Delete points from the collection # Delete points from the collection
self._client.delete( self._client.delete(
collection_name=self.namespace, collection_name=self.final_namespace,
points_selector=models.PointIdsList( points_selector=models.PointIdsList(
points=qdrant_ids, points=qdrant_ids,
), ),
wait=True, wait=True,
) )
logger.debug( logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}" f"[{self.workspace}] Successfully deleted {len(ids)} vectors from {self.namespace}"
) )
except Exception as e: except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}") logger.error(
f"[{self.workspace}] Error while deleting vectors from {self.namespace}: {e}"
)
async def delete_entity(self, entity_name: str) -> None: async def delete_entity(self, entity_name: str) -> None:
"""Delete an entity by name """Delete an entity by name
@ -224,20 +233,22 @@ class QdrantVectorDBStorage(BaseVectorStorage):
# Generate the entity ID # Generate the entity ID
entity_id = compute_mdhash_id_for_qdrant(entity_name, prefix="ent-") entity_id = compute_mdhash_id_for_qdrant(entity_name, prefix="ent-")
logger.debug( logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}" f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {entity_id}"
) )
# Delete the entity point from the collection # Delete the entity point from the collection
self._client.delete( self._client.delete(
collection_name=self.namespace, collection_name=self.final_namespace,
points_selector=models.PointIdsList( points_selector=models.PointIdsList(
points=[entity_id], points=[entity_id],
), ),
wait=True, wait=True,
) )
logger.debug(f"Successfully deleted entity {entity_name}") logger.debug(
f"[{self.workspace}] Successfully deleted entity {entity_name}"
)
except Exception as e: except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}") logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None: async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations associated with an entity """Delete all relations associated with an entity
@ -248,7 +259,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
try: try:
# Find relations where the entity is either source or target # Find relations where the entity is either source or target
results = self._client.scroll( results = self._client.scroll(
collection_name=self.namespace, collection_name=self.final_namespace,
scroll_filter=models.Filter( scroll_filter=models.Filter(
should=[ should=[
models.FieldCondition( models.FieldCondition(
@ -270,19 +281,23 @@ class QdrantVectorDBStorage(BaseVectorStorage):
if ids_to_delete: if ids_to_delete:
# Delete the relations # Delete the relations
self._client.delete( self._client.delete(
collection_name=self.namespace, collection_name=self.final_namespace,
points_selector=models.PointIdsList( points_selector=models.PointIdsList(
points=ids_to_delete, points=ids_to_delete,
), ),
wait=True, wait=True,
) )
logger.debug( logger.debug(
f"Deleted {len(ids_to_delete)} relations for {entity_name}" f"[{self.workspace}] Deleted {len(ids_to_delete)} relations for {entity_name}"
) )
else: else:
logger.debug(f"No relations found for entity {entity_name}") logger.debug(
f"[{self.workspace}] No relations found for entity {entity_name}"
)
except Exception as e: except Exception as e:
logger.error(f"Error deleting relations for {entity_name}: {e}") logger.error(
f"[{self.workspace}] Error deleting relations for {entity_name}: {e}"
)
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get vector data by its ID """Get vector data by its ID
@ -299,7 +314,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
# Retrieve the point by ID # Retrieve the point by ID
result = self._client.retrieve( result = self._client.retrieve(
collection_name=self.namespace, collection_name=self.final_namespace,
ids=[qdrant_id], ids=[qdrant_id],
with_payload=True, with_payload=True,
) )
@ -314,7 +329,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
return payload return payload
except Exception as e: except Exception as e:
logger.error(f"Error retrieving vector data for ID {id}: {e}") logger.error(
f"[{self.workspace}] Error retrieving vector data for ID {id}: {e}"
)
return None return None
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@ -335,7 +352,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
# Retrieve the points by IDs # Retrieve the points by IDs
results = self._client.retrieve( results = self._client.retrieve(
collection_name=self.namespace, collection_name=self.final_namespace,
ids=qdrant_ids, ids=qdrant_ids,
with_payload=True, with_payload=True,
) )
@ -350,7 +367,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
return payloads return payloads
except Exception as e: except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}") logger.error(
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
)
return [] return []
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
@ -365,13 +384,13 @@ class QdrantVectorDBStorage(BaseVectorStorage):
""" """
try: try:
# Delete the collection and recreate it # Delete the collection and recreate it
if self._client.collection_exists(self.namespace): if self._client.collection_exists(self.final_namespace):
self._client.delete_collection(self.namespace) self._client.delete_collection(self.final_namespace)
# Recreate the collection # Recreate the collection
QdrantVectorDBStorage.create_collection_if_not_exist( QdrantVectorDBStorage.create_collection_if_not_exist(
self._client, self._client,
self.namespace, self.final_namespace,
vectors_config=models.VectorParams( vectors_config=models.VectorParams(
size=self.embedding_func.embedding_dim, size=self.embedding_func.embedding_dim,
distance=models.Distance.COSINE, distance=models.Distance.COSINE,
@ -379,9 +398,11 @@ class QdrantVectorDBStorage(BaseVectorStorage):
) )
logger.info( logger.info(
f"Process {os.getpid()} drop Qdrant collection {self.namespace}" f"[{self.workspace}] Process {os.getpid()} drop Qdrant collection {self.namespace}"
) )
return {"status": "success", "message": "data dropped"} return {"status": "success", "message": "data dropped"}
except Exception as e: except Exception as e:
logger.error(f"Error dropping Qdrant collection {self.namespace}: {e}") logger.error(
f"[{self.workspace}] Error dropping Qdrant collection {self.namespace}: {e}"
)
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}

View file

@ -141,11 +141,18 @@ class RedisKVStorage(BaseKVStorage):
f"Using passed workspace parameter: '{effective_workspace}'" f"Using passed workspace parameter: '{effective_workspace}'"
) )
# Build namespace with workspace prefix for data isolation # Build final_namespace with workspace prefix for data isolation
# Keep original namespace unchanged for type detection logic
if effective_workspace: if effective_workspace:
self.namespace = f"{effective_workspace}_{self.namespace}" self.final_namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") logger.debug(
# When workspace is empty, keep the original namespace unchanged f"Final namespace with workspace prefix: '{self.final_namespace}'"
)
else:
# When workspace is empty, final_namespace equals original namespace
self.final_namespace = self.namespace
self.workspace = "_"
logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
self._redis_url = os.environ.get( self._redis_url = os.environ.get(
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379") "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
@ -159,13 +166,15 @@ class RedisKVStorage(BaseKVStorage):
self._pool = RedisConnectionManager.get_pool(self._redis_url) self._pool = RedisConnectionManager.get_pool(self._redis_url)
self._redis = Redis(connection_pool=self._pool) self._redis = Redis(connection_pool=self._pool)
logger.info( logger.info(
f"Initialized Redis KV storage for {self.namespace} using shared connection pool" f"[{self.workspace}] Initialized Redis KV storage for {self.namespace} using shared connection pool"
) )
except Exception as e: except Exception as e:
# Clean up on initialization failure # Clean up on initialization failure
if self._redis_url: if self._redis_url:
RedisConnectionManager.release_pool(self._redis_url) RedisConnectionManager.release_pool(self._redis_url)
logger.error(f"Failed to initialize Redis KV storage: {e}") logger.error(
f"[{self.workspace}] Failed to initialize Redis KV storage: {e}"
)
raise raise
async def initialize(self): async def initialize(self):
@ -177,10 +186,12 @@ class RedisKVStorage(BaseKVStorage):
try: try:
async with self._get_redis_connection() as redis: async with self._get_redis_connection() as redis:
await redis.ping() await redis.ping()
logger.info(f"Connected to Redis for namespace {self.namespace}") logger.info(
f"[{self.workspace}] Connected to Redis for namespace {self.namespace}"
)
self._initialized = True self._initialized = True
except Exception as e: except Exception as e:
logger.error(f"Failed to connect to Redis: {e}") logger.error(f"[{self.workspace}] Failed to connect to Redis: {e}")
# Clean up on connection failure # Clean up on connection failure
await self.close() await self.close()
raise raise
@ -190,7 +201,9 @@ class RedisKVStorage(BaseKVStorage):
try: try:
await self._migrate_legacy_cache_structure() await self._migrate_legacy_cache_structure()
except Exception as e: except Exception as e:
logger.error(f"Failed to migrate legacy cache structure: {e}") logger.error(
f"[{self.workspace}] Failed to migrate legacy cache structure: {e}"
)
# Don't fail initialization for migration errors, just log them # Don't fail initialization for migration errors, just log them
@asynccontextmanager @asynccontextmanager
@ -203,14 +216,18 @@ class RedisKVStorage(BaseKVStorage):
# Use the existing Redis instance with shared pool # Use the existing Redis instance with shared pool
yield self._redis yield self._redis
except ConnectionError as e: except ConnectionError as e:
logger.error(f"Redis connection error in {self.namespace}: {e}") logger.error(
f"[{self.workspace}] Redis connection error in {self.namespace}: {e}"
)
raise raise
except RedisError as e: except RedisError as e:
logger.error(f"Redis operation error in {self.namespace}: {e}") logger.error(
f"[{self.workspace}] Redis operation error in {self.namespace}: {e}"
)
raise raise
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Unexpected error in Redis operation for {self.namespace}: {e}" f"[{self.workspace}] Unexpected error in Redis operation for {self.namespace}: {e}"
) )
raise raise
@ -219,9 +236,11 @@ class RedisKVStorage(BaseKVStorage):
if hasattr(self, "_redis") and self._redis: if hasattr(self, "_redis") and self._redis:
try: try:
await self._redis.close() await self._redis.close()
logger.debug(f"Closed Redis connection for {self.namespace}") logger.debug(
f"[{self.workspace}] Closed Redis connection for {self.namespace}"
)
except Exception as e: except Exception as e:
logger.error(f"Error closing Redis connection: {e}") logger.error(f"[{self.workspace}] Error closing Redis connection: {e}")
finally: finally:
self._redis = None self._redis = None
@ -230,7 +249,7 @@ class RedisKVStorage(BaseKVStorage):
RedisConnectionManager.release_pool(self._redis_url) RedisConnectionManager.release_pool(self._redis_url)
self._pool = None self._pool = None
logger.debug( logger.debug(
f"Released Redis connection pool reference for {self.namespace}" f"[{self.workspace}] Released Redis connection pool reference for {self.namespace}"
) )
async def __aenter__(self): async def __aenter__(self):
@ -245,7 +264,7 @@ class RedisKVStorage(BaseKVStorage):
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> dict[str, Any] | None:
async with self._get_redis_connection() as redis: async with self._get_redis_connection() as redis:
try: try:
data = await redis.get(f"{self.namespace}:{id}") data = await redis.get(f"{self.final_namespace}:{id}")
if data: if data:
result = json.loads(data) result = json.loads(data)
# Ensure time fields are present, provide default values for old data # Ensure time fields are present, provide default values for old data
@ -254,7 +273,7 @@ class RedisKVStorage(BaseKVStorage):
return result return result
return None return None
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"JSON decode error for id {id}: {e}") logger.error(f"[{self.workspace}] JSON decode error for id {id}: {e}")
return None return None
@redis_retry @redis_retry
@ -263,7 +282,7 @@ class RedisKVStorage(BaseKVStorage):
try: try:
pipe = redis.pipeline() pipe = redis.pipeline()
for id in ids: for id in ids:
pipe.get(f"{self.namespace}:{id}") pipe.get(f"{self.final_namespace}:{id}")
results = await pipe.execute() results = await pipe.execute()
processed_results = [] processed_results = []
@ -279,7 +298,7 @@ class RedisKVStorage(BaseKVStorage):
return processed_results return processed_results
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"JSON decode error in batch get: {e}") logger.error(f"[{self.workspace}] JSON decode error in batch get: {e}")
return [None] * len(ids) return [None] * len(ids)
async def get_all(self) -> dict[str, Any]: async def get_all(self) -> dict[str, Any]:
@ -291,7 +310,7 @@ class RedisKVStorage(BaseKVStorage):
async with self._get_redis_connection() as redis: async with self._get_redis_connection() as redis:
try: try:
# Get all keys for this namespace # Get all keys for this namespace
keys = await redis.keys(f"{self.namespace}:*") keys = await redis.keys(f"{self.final_namespace}:*")
if not keys: if not keys:
return {} return {}
@ -315,12 +334,16 @@ class RedisKVStorage(BaseKVStorage):
data.setdefault("update_time", 0) data.setdefault("update_time", 0)
result[key_id] = data result[key_id] = data
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"JSON decode error for key {key}: {e}") logger.error(
f"[{self.workspace}] JSON decode error for key {key}: {e}"
)
continue continue
return result return result
except Exception as e: except Exception as e:
logger.error(f"Error getting all data from Redis: {e}") logger.error(
f"[{self.workspace}] Error getting all data from Redis: {e}"
)
return {} return {}
async def filter_keys(self, keys: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
@ -328,7 +351,7 @@ class RedisKVStorage(BaseKVStorage):
pipe = redis.pipeline() pipe = redis.pipeline()
keys_list = list(keys) # Convert set to list for indexing keys_list = list(keys) # Convert set to list for indexing
for key in keys_list: for key in keys_list:
pipe.exists(f"{self.namespace}:{key}") pipe.exists(f"{self.final_namespace}:{key}")
results = await pipe.execute() results = await pipe.execute()
existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists} existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
@ -348,13 +371,13 @@ class RedisKVStorage(BaseKVStorage):
# Check which keys already exist to determine create vs update # Check which keys already exist to determine create vs update
pipe = redis.pipeline() pipe = redis.pipeline()
for k in data.keys(): for k in data.keys():
pipe.exists(f"{self.namespace}:{k}") pipe.exists(f"{self.final_namespace}:{k}")
exists_results = await pipe.execute() exists_results = await pipe.execute()
# Add timestamps to data # Add timestamps to data
for i, (k, v) in enumerate(data.items()): for i, (k, v) in enumerate(data.items()):
# For text_chunks namespace, ensure llm_cache_list field exists # For text_chunks namespace, ensure llm_cache_list field exists
if "text_chunks" in self.namespace: if self.namespace.endswith("text_chunks"):
if "llm_cache_list" not in v: if "llm_cache_list" not in v:
v["llm_cache_list"] = [] v["llm_cache_list"] = []
@ -370,11 +393,11 @@ class RedisKVStorage(BaseKVStorage):
# Store the data # Store the data
pipe = redis.pipeline() pipe = redis.pipeline()
for k, v in data.items(): for k, v in data.items():
pipe.set(f"{self.namespace}:{k}", json.dumps(v)) pipe.set(f"{self.final_namespace}:{k}", json.dumps(v))
await pipe.execute() await pipe.execute()
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"JSON decode error during upsert: {e}") logger.error(f"[{self.workspace}] JSON decode error during upsert: {e}")
raise raise
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
@ -389,12 +412,12 @@ class RedisKVStorage(BaseKVStorage):
async with self._get_redis_connection() as redis: async with self._get_redis_connection() as redis:
pipe = redis.pipeline() pipe = redis.pipeline()
for id in ids: for id in ids:
pipe.delete(f"{self.namespace}:{id}") pipe.delete(f"{self.final_namespace}:{id}")
results = await pipe.execute() results = await pipe.execute()
deleted_count = sum(results) deleted_count = sum(results)
logger.info( logger.info(
f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}" f"[{self.workspace}] Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
) )
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
@ -406,7 +429,7 @@ class RedisKVStorage(BaseKVStorage):
async with self._get_redis_connection() as redis: async with self._get_redis_connection() as redis:
try: try:
# Use SCAN to find all keys with the namespace prefix # Use SCAN to find all keys with the namespace prefix
pattern = f"{self.namespace}:*" pattern = f"{self.final_namespace}:*"
cursor = 0 cursor = 0
deleted_count = 0 deleted_count = 0
@ -423,14 +446,18 @@ class RedisKVStorage(BaseKVStorage):
if cursor == 0: if cursor == 0:
break break
logger.info(f"Dropped {deleted_count} keys from {self.namespace}") logger.info(
f"[{self.workspace}] Dropped {deleted_count} keys from {self.namespace}"
)
return { return {
"status": "success", "status": "success",
"message": f"{deleted_count} keys dropped", "message": f"{deleted_count} keys dropped",
} }
except Exception as e: except Exception as e:
logger.error(f"Error dropping keys from {self.namespace}: {e}") logger.error(
f"[{self.workspace}] Error dropping keys from {self.namespace}: {e}"
)
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
async def _migrate_legacy_cache_structure(self): async def _migrate_legacy_cache_structure(self):
@ -445,7 +472,7 @@ class RedisKVStorage(BaseKVStorage):
async with self._get_redis_connection() as redis: async with self._get_redis_connection() as redis:
# Get all keys for this namespace # Get all keys for this namespace
keys = await redis.keys(f"{self.namespace}:*") keys = await redis.keys(f"{self.final_namespace}:*")
if not keys: if not keys:
return return
@ -480,7 +507,7 @@ class RedisKVStorage(BaseKVStorage):
# If we found any flattened keys, assume migration is already done # If we found any flattened keys, assume migration is already done
if has_flattened_keys: if has_flattened_keys:
logger.debug( logger.debug(
f"Found flattened cache keys in {self.namespace}, skipping migration" f"[{self.workspace}] Found flattened cache keys in {self.namespace}, skipping migration"
) )
return return
@ -499,7 +526,7 @@ class RedisKVStorage(BaseKVStorage):
for cache_hash, cache_entry in nested_data.items(): for cache_hash, cache_entry in nested_data.items():
cache_type = cache_entry.get("cache_type", "extract") cache_type = cache_entry.get("cache_type", "extract")
flattened_key = generate_cache_key(mode, cache_type, cache_hash) flattened_key = generate_cache_key(mode, cache_type, cache_hash)
full_key = f"{self.namespace}:{flattened_key}" full_key = f"{self.final_namespace}:{flattened_key}"
pipe.set(full_key, json.dumps(cache_entry)) pipe.set(full_key, json.dumps(cache_entry))
migration_count += 1 migration_count += 1
@ -507,7 +534,7 @@ class RedisKVStorage(BaseKVStorage):
if migration_count > 0: if migration_count > 0:
logger.info( logger.info(
f"Migrated {migration_count} legacy cache entries to flattened structure in Redis" f"[{self.workspace}] Migrated {migration_count} legacy cache entries to flattened structure in Redis"
) )
@ -534,11 +561,20 @@ class RedisDocStatusStorage(DocStatusStorage):
f"Using passed workspace parameter: '{effective_workspace}'" f"Using passed workspace parameter: '{effective_workspace}'"
) )
# Build namespace with workspace prefix for data isolation # Build final_namespace with workspace prefix for data isolation
# Keep original namespace unchanged for type detection logic
if effective_workspace: if effective_workspace:
self.namespace = f"{effective_workspace}_{self.namespace}" self.final_namespace = f"{effective_workspace}_{self.namespace}"
logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") logger.debug(
# When workspace is empty, keep the original namespace unchanged f"[{self.workspace}] Final namespace with workspace prefix: '{self.namespace}'"
)
else:
# When workspace is empty, final_namespace equals original namespace
self.final_namespace = self.namespace
self.workspace = "_"
logger.debug(
f"[{self.workspace}] Final namespace (no workspace): '{self.namespace}'"
)
self._redis_url = os.environ.get( self._redis_url = os.environ.get(
"REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379") "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
@ -552,13 +588,15 @@ class RedisDocStatusStorage(DocStatusStorage):
self._pool = RedisConnectionManager.get_pool(self._redis_url) self._pool = RedisConnectionManager.get_pool(self._redis_url)
self._redis = Redis(connection_pool=self._pool) self._redis = Redis(connection_pool=self._pool)
logger.info( logger.info(
f"Initialized Redis doc status storage for {self.namespace} using shared connection pool" f"[{self.workspace}] Initialized Redis doc status storage for {self.namespace} using shared connection pool"
) )
except Exception as e: except Exception as e:
# Clean up on initialization failure # Clean up on initialization failure
if self._redis_url: if self._redis_url:
RedisConnectionManager.release_pool(self._redis_url) RedisConnectionManager.release_pool(self._redis_url)
logger.error(f"Failed to initialize Redis doc status storage: {e}") logger.error(
f"[{self.workspace}] Failed to initialize Redis doc status storage: {e}"
)
raise raise
async def initialize(self): async def initialize(self):
@ -570,11 +608,13 @@ class RedisDocStatusStorage(DocStatusStorage):
async with self._get_redis_connection() as redis: async with self._get_redis_connection() as redis:
await redis.ping() await redis.ping()
logger.info( logger.info(
f"Connected to Redis for doc status namespace {self.namespace}" f"[{self.workspace}] Connected to Redis for doc status namespace {self.namespace}"
) )
self._initialized = True self._initialized = True
except Exception as e: except Exception as e:
logger.error(f"Failed to connect to Redis for doc status: {e}") logger.error(
f"[{self.workspace}] Failed to connect to Redis for doc status: {e}"
)
# Clean up on connection failure # Clean up on connection failure
await self.close() await self.close()
raise raise
@ -589,14 +629,18 @@ class RedisDocStatusStorage(DocStatusStorage):
# Use the existing Redis instance with shared pool # Use the existing Redis instance with shared pool
yield self._redis yield self._redis
except ConnectionError as e: except ConnectionError as e:
logger.error(f"Redis connection error in doc status {self.namespace}: {e}") logger.error(
f"[{self.workspace}] Redis connection error in doc status {self.namespace}: {e}"
)
raise raise
except RedisError as e: except RedisError as e:
logger.error(f"Redis operation error in doc status {self.namespace}: {e}") logger.error(
f"[{self.workspace}] Redis operation error in doc status {self.namespace}: {e}"
)
raise raise
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Unexpected error in Redis doc status operation for {self.namespace}: {e}" f"[{self.workspace}] Unexpected error in Redis doc status operation for {self.namespace}: {e}"
) )
raise raise
@ -605,9 +649,11 @@ class RedisDocStatusStorage(DocStatusStorage):
if hasattr(self, "_redis") and self._redis: if hasattr(self, "_redis") and self._redis:
try: try:
await self._redis.close() await self._redis.close()
logger.debug(f"Closed Redis connection for doc status {self.namespace}") logger.debug(
f"[{self.workspace}] Closed Redis connection for doc status {self.namespace}"
)
except Exception as e: except Exception as e:
logger.error(f"Error closing Redis connection: {e}") logger.error(f"[{self.workspace}] Error closing Redis connection: {e}")
finally: finally:
self._redis = None self._redis = None
@ -616,7 +662,7 @@ class RedisDocStatusStorage(DocStatusStorage):
RedisConnectionManager.release_pool(self._redis_url) RedisConnectionManager.release_pool(self._redis_url)
self._pool = None self._pool = None
logger.debug( logger.debug(
f"Released Redis connection pool reference for doc status {self.namespace}" f"[{self.workspace}] Released Redis connection pool reference for doc status {self.namespace}"
) )
async def __aenter__(self): async def __aenter__(self):
@ -633,7 +679,7 @@ class RedisDocStatusStorage(DocStatusStorage):
pipe = redis.pipeline() pipe = redis.pipeline()
keys_list = list(keys) keys_list = list(keys)
for key in keys_list: for key in keys_list:
pipe.exists(f"{self.namespace}:{key}") pipe.exists(f"{self.final_namespace}:{key}")
results = await pipe.execute() results = await pipe.execute()
existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists} existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
@ -645,7 +691,7 @@ class RedisDocStatusStorage(DocStatusStorage):
try: try:
pipe = redis.pipeline() pipe = redis.pipeline()
for id in ids: for id in ids:
pipe.get(f"{self.namespace}:{id}") pipe.get(f"{self.final_namespace}:{id}")
results = await pipe.execute() results = await pipe.execute()
for result_data in results: for result_data in results:
@ -653,10 +699,12 @@ class RedisDocStatusStorage(DocStatusStorage):
try: try:
result.append(json.loads(result_data)) result.append(json.loads(result_data))
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"JSON decode error in get_by_ids: {e}") logger.error(
f"[{self.workspace}] JSON decode error in get_by_ids: {e}"
)
continue continue
except Exception as e: except Exception as e:
logger.error(f"Error in get_by_ids: {e}") logger.error(f"[{self.workspace}] Error in get_by_ids: {e}")
return result return result
async def get_status_counts(self) -> dict[str, int]: async def get_status_counts(self) -> dict[str, int]:
@ -668,7 +716,7 @@ class RedisDocStatusStorage(DocStatusStorage):
cursor = 0 cursor = 0
while True: while True:
cursor, keys = await redis.scan( cursor, keys = await redis.scan(
cursor, match=f"{self.namespace}:*", count=1000 cursor, match=f"{self.final_namespace}:*", count=1000
) )
if keys: if keys:
# Get all values in batch # Get all values in batch
@ -691,7 +739,7 @@ class RedisDocStatusStorage(DocStatusStorage):
if cursor == 0: if cursor == 0:
break break
except Exception as e: except Exception as e:
logger.error(f"Error getting status counts: {e}") logger.error(f"[{self.workspace}] Error getting status counts: {e}")
return counts return counts
@ -706,7 +754,7 @@ class RedisDocStatusStorage(DocStatusStorage):
cursor = 0 cursor = 0
while True: while True:
cursor, keys = await redis.scan( cursor, keys = await redis.scan(
cursor, match=f"{self.namespace}:*", count=1000 cursor, match=f"{self.final_namespace}:*", count=1000
) )
if keys: if keys:
# Get all values in batch # Get all values in batch
@ -740,14 +788,14 @@ class RedisDocStatusStorage(DocStatusStorage):
result[doc_id] = DocProcessingStatus(**data) result[doc_id] = DocProcessingStatus(**data)
except (json.JSONDecodeError, KeyError) as e: except (json.JSONDecodeError, KeyError) as e:
logger.error( logger.error(
f"Error processing document {key}: {e}" f"[{self.workspace}] Error processing document {key}: {e}"
) )
continue continue
if cursor == 0: if cursor == 0:
break break
except Exception as e: except Exception as e:
logger.error(f"Error getting docs by status: {e}") logger.error(f"[{self.workspace}] Error getting docs by status: {e}")
return result return result
@ -762,7 +810,7 @@ class RedisDocStatusStorage(DocStatusStorage):
cursor = 0 cursor = 0
while True: while True:
cursor, keys = await redis.scan( cursor, keys = await redis.scan(
cursor, match=f"{self.namespace}:*", count=1000 cursor, match=f"{self.final_namespace}:*", count=1000
) )
if keys: if keys:
# Get all values in batch # Get all values in batch
@ -796,14 +844,14 @@ class RedisDocStatusStorage(DocStatusStorage):
result[doc_id] = DocProcessingStatus(**data) result[doc_id] = DocProcessingStatus(**data)
except (json.JSONDecodeError, KeyError) as e: except (json.JSONDecodeError, KeyError) as e:
logger.error( logger.error(
f"Error processing document {key}: {e}" f"[{self.workspace}] Error processing document {key}: {e}"
) )
continue continue
if cursor == 0: if cursor == 0:
break break
except Exception as e: except Exception as e:
logger.error(f"Error getting docs by track_id: {e}") logger.error(f"[{self.workspace}] Error getting docs by track_id: {e}")
return result return result
@ -817,7 +865,9 @@ class RedisDocStatusStorage(DocStatusStorage):
if not data: if not data:
return return
logger.debug(f"Inserting {len(data)} records to {self.namespace}") logger.debug(
f"[{self.workspace}] Inserting {len(data)} records to {self.namespace}"
)
async with self._get_redis_connection() as redis: async with self._get_redis_connection() as redis:
try: try:
# Ensure chunks_list field exists for new documents # Ensure chunks_list field exists for new documents
@ -827,20 +877,20 @@ class RedisDocStatusStorage(DocStatusStorage):
pipe = redis.pipeline() pipe = redis.pipeline()
for k, v in data.items(): for k, v in data.items():
pipe.set(f"{self.namespace}:{k}", json.dumps(v)) pipe.set(f"{self.final_namespace}:{k}", json.dumps(v))
await pipe.execute() await pipe.execute()
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"JSON decode error during upsert: {e}") logger.error(f"[{self.workspace}] JSON decode error during upsert: {e}")
raise raise
@redis_retry @redis_retry
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
async with self._get_redis_connection() as redis: async with self._get_redis_connection() as redis:
try: try:
data = await redis.get(f"{self.namespace}:{id}") data = await redis.get(f"{self.final_namespace}:{id}")
return json.loads(data) if data else None return json.loads(data) if data else None
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"JSON decode error for id {id}: {e}") logger.error(f"[{self.workspace}] JSON decode error for id {id}: {e}")
return None return None
async def delete(self, doc_ids: list[str]) -> None: async def delete(self, doc_ids: list[str]) -> None:
@ -851,12 +901,12 @@ class RedisDocStatusStorage(DocStatusStorage):
async with self._get_redis_connection() as redis: async with self._get_redis_connection() as redis:
pipe = redis.pipeline() pipe = redis.pipeline()
for doc_id in doc_ids: for doc_id in doc_ids:
pipe.delete(f"{self.namespace}:{doc_id}") pipe.delete(f"{self.final_namespace}:{doc_id}")
results = await pipe.execute() results = await pipe.execute()
deleted_count = sum(results) deleted_count = sum(results)
logger.info( logger.info(
f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}" f"[{self.workspace}] Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}"
) )
async def get_docs_paginated( async def get_docs_paginated(
@ -903,7 +953,7 @@ class RedisDocStatusStorage(DocStatusStorage):
cursor = 0 cursor = 0
while True: while True:
cursor, keys = await redis.scan( cursor, keys = await redis.scan(
cursor, match=f"{self.namespace}:*", count=1000 cursor, match=f"{self.final_namespace}:*", count=1000
) )
if keys: if keys:
# Get all values in batch # Get all values in batch
@ -950,7 +1000,7 @@ class RedisDocStatusStorage(DocStatusStorage):
except (json.JSONDecodeError, KeyError) as e: except (json.JSONDecodeError, KeyError) as e:
logger.error( logger.error(
f"Error processing document {key}: {e}" f"[{self.workspace}] Error processing document {key}: {e}"
) )
continue continue
@ -958,7 +1008,7 @@ class RedisDocStatusStorage(DocStatusStorage):
break break
except Exception as e: except Exception as e:
logger.error(f"Error getting paginated docs: {e}") logger.error(f"[{self.workspace}] Error getting paginated docs: {e}")
return [], 0 return [], 0
# Sort documents using the separate sort key # Sort documents using the separate sort key
@ -996,7 +1046,7 @@ class RedisDocStatusStorage(DocStatusStorage):
try: try:
async with self._get_redis_connection() as redis: async with self._get_redis_connection() as redis:
# Use SCAN to find all keys with the namespace prefix # Use SCAN to find all keys with the namespace prefix
pattern = f"{self.namespace}:*" pattern = f"{self.final_namespace}:*"
cursor = 0 cursor = 0
deleted_count = 0 deleted_count = 0
@ -1014,9 +1064,11 @@ class RedisDocStatusStorage(DocStatusStorage):
break break
logger.info( logger.info(
f"Dropped {deleted_count} doc status keys from {self.namespace}" f"[{self.workspace}] Dropped {deleted_count} doc status keys from {self.namespace}"
) )
return {"status": "success", "message": "data dropped"} return {"status": "success", "message": "data dropped"}
except Exception as e: except Exception as e:
logger.error(f"Error dropping doc status {self.namespace}: {e}") logger.error(
f"[{self.workspace}] Error dropping doc status {self.namespace}: {e}"
)
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}

View file

@ -3,6 +3,7 @@ from __future__ import annotations
from typing import Iterable from typing import Iterable
# All namespace should not be changed
class NameSpace: class NameSpace:
KV_STORE_FULL_DOCS = "full_docs" KV_STORE_FULL_DOCS = "full_docs"
KV_STORE_TEXT_CHUNKS = "text_chunks" KV_STORE_TEXT_CHUNKS = "text_chunks"

View file

@ -1248,7 +1248,7 @@ async def merge_nodes_and_edges(
semaphore = asyncio.Semaphore(graph_max_async) semaphore = asyncio.Semaphore(graph_max_async)
# ===== Phase 1: Process all entities concurrently ===== # ===== Phase 1: Process all entities concurrently =====
log_message = f"Phase 1: Processing {total_entities_count} entities (async: {graph_max_async})" log_message = f"Phase 1: Processing {total_entities_count} entities from {doc_id} (async: {graph_max_async})"
logger.info(log_message) logger.info(log_message)
async with pipeline_status_lock: async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message pipeline_status["latest_message"] = log_message
@ -1312,7 +1312,7 @@ async def merge_nodes_and_edges(
processed_entities = [task.result() for task in entity_tasks] processed_entities = [task.result() for task in entity_tasks]
# ===== Phase 2: Process all relationships concurrently ===== # ===== Phase 2: Process all relationships concurrently =====
log_message = f"Phase 2: Processing {total_relations_count} relations (async: {graph_max_async})" log_message = f"Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})"
logger.info(log_message) logger.info(log_message)
async with pipeline_status_lock: async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message pipeline_status["latest_message"] = log_message
@ -1422,6 +1422,12 @@ async def merge_nodes_and_edges(
relation_pair = tuple(sorted([src_id, tgt_id])) relation_pair = tuple(sorted([src_id, tgt_id]))
final_relation_pairs.add(relation_pair) final_relation_pairs.add(relation_pair)
log_message = f"Phase 3: Updating final {len(final_entity_names)}({len(processed_entities)}+{len(all_added_entities)}) entities and {len(final_relation_pairs)} relations from {doc_id}"
logger.info(log_message)
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
# Update storage # Update storage
if final_entity_names: if final_entity_names:
await full_entities_storage.upsert( await full_entities_storage.upsert(