Merge pull request #1904 from HKUDS/optimize-doc-delete
feat(performance): Optimize Document Deletion with Entity/Relation Indexing
This commit is contained in:
commit
7de0276a57
12 changed files with 957 additions and 118 deletions
|
|
@ -1 +1 @@
|
||||||
__api_version__ = "0196"
|
__api_version__ = "0198"
|
||||||
|
|
|
||||||
|
|
@ -151,6 +151,7 @@ def create_app(args):
|
||||||
try:
|
try:
|
||||||
# Initialize database connections
|
# Initialize database connections
|
||||||
await rag.initialize_storages()
|
await rag.initialize_storages()
|
||||||
|
await rag.check_and_migrate_data()
|
||||||
|
|
||||||
await initialize_pipeline_status()
|
await initialize_pipeline_status()
|
||||||
pipeline_status = await get_namespace_data("pipeline_status")
|
pipeline_status = await get_namespace_data("pipeline_status")
|
||||||
|
|
@ -401,7 +402,6 @@ def create_app(args):
|
||||||
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
|
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
|
||||||
enable_llm_cache=args.enable_llm_cache,
|
enable_llm_cache=args.enable_llm_cache,
|
||||||
rerank_model_func=rerank_model_func,
|
rerank_model_func=rerank_model_func,
|
||||||
auto_manage_storages_states=False,
|
|
||||||
max_parallel_insert=args.max_parallel_insert,
|
max_parallel_insert=args.max_parallel_insert,
|
||||||
max_graph_nodes=args.max_graph_nodes,
|
max_graph_nodes=args.max_graph_nodes,
|
||||||
addon_params={"language": args.summary_language},
|
addon_params={"language": args.summary_language},
|
||||||
|
|
@ -431,7 +431,6 @@ def create_app(args):
|
||||||
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
|
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
|
||||||
enable_llm_cache=args.enable_llm_cache,
|
enable_llm_cache=args.enable_llm_cache,
|
||||||
rerank_model_func=rerank_model_func,
|
rerank_model_func=rerank_model_func,
|
||||||
auto_manage_storages_states=False,
|
|
||||||
max_parallel_insert=args.max_parallel_insert,
|
max_parallel_insert=args.max_parallel_insert,
|
||||||
max_graph_nodes=args.max_graph_nodes,
|
max_graph_nodes=args.max_graph_nodes,
|
||||||
addon_params={"language": args.summary_language},
|
addon_params={"language": args.summary_language},
|
||||||
|
|
|
||||||
|
|
@ -1473,6 +1473,8 @@ def create_document_routes(
|
||||||
storages = [
|
storages = [
|
||||||
rag.text_chunks,
|
rag.text_chunks,
|
||||||
rag.full_docs,
|
rag.full_docs,
|
||||||
|
rag.full_entities,
|
||||||
|
rag.full_relations,
|
||||||
rag.entities_vdb,
|
rag.entities_vdb,
|
||||||
rag.relationships_vdb,
|
rag.relationships_vdb,
|
||||||
rag.chunks_vdb,
|
rag.chunks_vdb,
|
||||||
|
|
|
||||||
|
|
@ -654,6 +654,23 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||||
indicating whether the graph was truncated due to max_nodes limit
|
indicating whether the graph was truncated due to max_nodes limit
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_all_nodes(self) -> list[dict]:
|
||||||
|
"""Get all nodes in the graph.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of all nodes, where each node is a dictionary of its properties
|
||||||
|
(Edges are bidirectional in some storage implementations; deduplication must be handled by the caller)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_all_edges(self) -> list[dict]:
|
||||||
|
"""Get all edges in the graph.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of all edges, where each edge is a dictionary of its properties
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class DocStatus(str, Enum):
|
class DocStatus(str, Enum):
|
||||||
"""Document processing status"""
|
"""Document processing status"""
|
||||||
|
|
|
||||||
|
|
@ -997,3 +997,60 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
logger.warning(f"Memgraph error during subgraph query: {str(e)}")
|
logger.warning(f"Memgraph error during subgraph query: {str(e)}")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
async def get_all_nodes(self) -> list[dict]:
|
||||||
|
"""Get all nodes in the graph.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of all nodes, where each node is a dictionary of its properties
|
||||||
|
"""
|
||||||
|
if self._driver is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
||||||
|
)
|
||||||
|
workspace_label = self._get_workspace_label()
|
||||||
|
async with self._driver.session(
|
||||||
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
|
) as session:
|
||||||
|
query = f"""
|
||||||
|
MATCH (n:`{workspace_label}`)
|
||||||
|
RETURN n
|
||||||
|
"""
|
||||||
|
result = await session.run(query)
|
||||||
|
nodes = []
|
||||||
|
async for record in result:
|
||||||
|
node = record["n"]
|
||||||
|
node_dict = dict(node)
|
||||||
|
# Add node id (entity_id) to the dictionary for easier access
|
||||||
|
node_dict["id"] = node_dict.get("entity_id")
|
||||||
|
nodes.append(node_dict)
|
||||||
|
await result.consume()
|
||||||
|
return nodes
|
||||||
|
|
||||||
|
async def get_all_edges(self) -> list[dict]:
|
||||||
|
"""Get all edges in the graph.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of all edges, where each edge is a dictionary of its properties
|
||||||
|
"""
|
||||||
|
if self._driver is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
||||||
|
)
|
||||||
|
workspace_label = self._get_workspace_label()
|
||||||
|
async with self._driver.session(
|
||||||
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
|
) as session:
|
||||||
|
query = f"""
|
||||||
|
MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
|
||||||
|
RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
|
||||||
|
"""
|
||||||
|
result = await session.run(query)
|
||||||
|
edges = []
|
||||||
|
async for record in result:
|
||||||
|
edge_properties = record["properties"]
|
||||||
|
edge_properties["source"] = record["source"]
|
||||||
|
edge_properties["target"] = record["target"]
|
||||||
|
edges.append(edge_properties)
|
||||||
|
await result.consume()
|
||||||
|
return edges
|
||||||
|
|
|
||||||
|
|
@ -1508,6 +1508,36 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
|
|
||||||
logger.debug(f"Successfully deleted edges: {edges}")
|
logger.debug(f"Successfully deleted edges: {edges}")
|
||||||
|
|
||||||
|
async def get_all_nodes(self) -> list[dict]:
|
||||||
|
"""Get all nodes in the graph.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of all nodes, where each node is a dictionary of its properties
|
||||||
|
"""
|
||||||
|
cursor = self.collection.find({})
|
||||||
|
nodes = []
|
||||||
|
async for node in cursor:
|
||||||
|
node_dict = dict(node)
|
||||||
|
# Add node id (_id) to the dictionary for easier access
|
||||||
|
node_dict["id"] = node_dict.get("_id")
|
||||||
|
nodes.append(node_dict)
|
||||||
|
return nodes
|
||||||
|
|
||||||
|
async def get_all_edges(self) -> list[dict]:
|
||||||
|
"""Get all edges in the graph.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of all edges, where each edge is a dictionary of its properties
|
||||||
|
"""
|
||||||
|
cursor = self.edge_collection.find({})
|
||||||
|
edges = []
|
||||||
|
async for edge in cursor:
|
||||||
|
edge_dict = dict(edge)
|
||||||
|
edge_dict["source"] = edge_dict.get("source_node_id")
|
||||||
|
edge_dict["target"] = edge_dict.get("target_node_id")
|
||||||
|
edges.append(edge_dict)
|
||||||
|
return edges
|
||||||
|
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> dict[str, str]:
|
||||||
"""Drop the storage by removing all documents in the collection.
|
"""Drop the storage by removing all documents in the collection.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1400,6 +1400,55 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
logger.error(f"Error during edge deletion: {str(e)}")
|
logger.error(f"Error during edge deletion: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def get_all_nodes(self) -> list[dict]:
|
||||||
|
"""Get all nodes in the graph.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of all nodes, where each node is a dictionary of its properties
|
||||||
|
"""
|
||||||
|
workspace_label = self._get_workspace_label()
|
||||||
|
async with self._driver.session(
|
||||||
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
|
) as session:
|
||||||
|
query = f"""
|
||||||
|
MATCH (n:`{workspace_label}`)
|
||||||
|
RETURN n
|
||||||
|
"""
|
||||||
|
result = await session.run(query)
|
||||||
|
nodes = []
|
||||||
|
async for record in result:
|
||||||
|
node = record["n"]
|
||||||
|
node_dict = dict(node)
|
||||||
|
# Add node id (entity_id) to the dictionary for easier access
|
||||||
|
node_dict["id"] = node_dict.get("entity_id")
|
||||||
|
nodes.append(node_dict)
|
||||||
|
await result.consume()
|
||||||
|
return nodes
|
||||||
|
|
||||||
|
async def get_all_edges(self) -> list[dict]:
|
||||||
|
"""Get all edges in the graph.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of all edges, where each edge is a dictionary of its properties
|
||||||
|
"""
|
||||||
|
workspace_label = self._get_workspace_label()
|
||||||
|
async with self._driver.session(
|
||||||
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
|
) as session:
|
||||||
|
query = f"""
|
||||||
|
MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
|
||||||
|
RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
|
||||||
|
"""
|
||||||
|
result = await session.run(query)
|
||||||
|
edges = []
|
||||||
|
async for record in result:
|
||||||
|
edge_properties = record["properties"]
|
||||||
|
edge_properties["source"] = record["source"]
|
||||||
|
edge_properties["target"] = record["target"]
|
||||||
|
edges.append(edge_properties)
|
||||||
|
await result.consume()
|
||||||
|
return edges
|
||||||
|
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> dict[str, str]:
|
||||||
"""Drop all data from current workspace storage and clean up resources
|
"""Drop all data from current workspace storage and clean up resources
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -393,6 +393,35 @@ class NetworkXStorage(BaseGraphStorage):
|
||||||
matching_edges.append(edge_data_with_nodes)
|
matching_edges.append(edge_data_with_nodes)
|
||||||
return matching_edges
|
return matching_edges
|
||||||
|
|
||||||
|
async def get_all_nodes(self) -> list[dict]:
|
||||||
|
"""Get all nodes in the graph.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of all nodes, where each node is a dictionary of its properties
|
||||||
|
"""
|
||||||
|
graph = await self._get_graph()
|
||||||
|
all_nodes = []
|
||||||
|
for node_id, node_data in graph.nodes(data=True):
|
||||||
|
node_data_with_id = node_data.copy()
|
||||||
|
node_data_with_id["id"] = node_id
|
||||||
|
all_nodes.append(node_data_with_id)
|
||||||
|
return all_nodes
|
||||||
|
|
||||||
|
async def get_all_edges(self) -> list[dict]:
|
||||||
|
"""Get all edges in the graph.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of all edges, where each edge is a dictionary of its properties
|
||||||
|
"""
|
||||||
|
graph = await self._get_graph()
|
||||||
|
all_edges = []
|
||||||
|
for u, v, edge_data in graph.edges(data=True):
|
||||||
|
edge_data_with_nodes = edge_data.copy()
|
||||||
|
edge_data_with_nodes["source"] = u
|
||||||
|
edge_data_with_nodes["target"] = v
|
||||||
|
all_edges.append(edge_data_with_nodes)
|
||||||
|
return all_edges
|
||||||
|
|
||||||
async def index_done_callback(self) -> bool:
|
async def index_done_callback(self) -> bool:
|
||||||
"""Save data to disk"""
|
"""Save data to disk"""
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
|
|
|
||||||
|
|
@ -920,6 +920,80 @@ class PostgreSQLDB:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"PostgreSQL, Failed to create pagination indexes: {e}")
|
logger.error(f"PostgreSQL, Failed to create pagination indexes: {e}")
|
||||||
|
|
||||||
|
# Migrate to ensure new tables LIGHTRAG_FULL_ENTITIES and LIGHTRAG_FULL_RELATIONS exist
|
||||||
|
try:
|
||||||
|
await self._migrate_create_full_entities_relations_tables()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"PostgreSQL, Failed to create full entities/relations tables: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _migrate_create_full_entities_relations_tables(self):
|
||||||
|
"""Create LIGHTRAG_FULL_ENTITIES and LIGHTRAG_FULL_RELATIONS tables if they don't exist"""
|
||||||
|
tables_to_check = [
|
||||||
|
{
|
||||||
|
"name": "LIGHTRAG_FULL_ENTITIES",
|
||||||
|
"ddl": TABLES["LIGHTRAG_FULL_ENTITIES"]["ddl"],
|
||||||
|
"description": "Full entities storage table",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "LIGHTRAG_FULL_RELATIONS",
|
||||||
|
"ddl": TABLES["LIGHTRAG_FULL_RELATIONS"]["ddl"],
|
||||||
|
"description": "Full relations storage table",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
for table_info in tables_to_check:
|
||||||
|
table_name = table_info["name"]
|
||||||
|
try:
|
||||||
|
# Check if table exists
|
||||||
|
check_table_sql = """
|
||||||
|
SELECT table_name
|
||||||
|
FROM information_schema.tables
|
||||||
|
WHERE table_name = $1
|
||||||
|
AND table_schema = 'public'
|
||||||
|
"""
|
||||||
|
|
||||||
|
table_exists = await self.query(
|
||||||
|
check_table_sql, {"table_name": table_name.lower()}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not table_exists:
|
||||||
|
logger.info(f"Creating table {table_name}")
|
||||||
|
await self.execute(table_info["ddl"])
|
||||||
|
logger.info(
|
||||||
|
f"Successfully created {table_info['description']}: {table_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create basic indexes for the new table
|
||||||
|
try:
|
||||||
|
# Create index for id column
|
||||||
|
index_name = f"idx_{table_name.lower()}_id"
|
||||||
|
create_index_sql = (
|
||||||
|
f"CREATE INDEX {index_name} ON {table_name}(id)"
|
||||||
|
)
|
||||||
|
await self.execute(create_index_sql)
|
||||||
|
logger.info(f"Created index {index_name} on table {table_name}")
|
||||||
|
|
||||||
|
# Create composite index for (workspace, id) columns
|
||||||
|
composite_index_name = f"idx_{table_name.lower()}_workspace_id"
|
||||||
|
create_composite_index_sql = f"CREATE INDEX {composite_index_name} ON {table_name}(workspace, id)"
|
||||||
|
await self.execute(create_composite_index_sql)
|
||||||
|
logger.info(
|
||||||
|
f"Created composite index {composite_index_name} on table {table_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to create indexes for table {table_name}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.debug(f"Table {table_name} already exists")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create table {table_name}: {e}")
|
||||||
|
|
||||||
async def _create_pagination_indexes(self):
|
async def _create_pagination_indexes(self):
|
||||||
"""Create indexes to optimize pagination queries for LIGHTRAG_DOC_STATUS"""
|
"""Create indexes to optimize pagination queries for LIGHTRAG_DOC_STATUS"""
|
||||||
indexes = [
|
indexes = [
|
||||||
|
|
@ -1233,6 +1307,46 @@ class PGKVStorage(BaseKVStorage):
|
||||||
processed_results[row["id"]] = row
|
processed_results[row["id"]] = row
|
||||||
return processed_results
|
return processed_results
|
||||||
|
|
||||||
|
# For FULL_ENTITIES namespace, parse entity_names JSON string back to list
|
||||||
|
if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES):
|
||||||
|
processed_results = {}
|
||||||
|
for row in results:
|
||||||
|
entity_names = row.get("entity_names", [])
|
||||||
|
if isinstance(entity_names, str):
|
||||||
|
try:
|
||||||
|
entity_names = json.loads(entity_names)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
entity_names = []
|
||||||
|
row["entity_names"] = entity_names
|
||||||
|
create_time = row.get("create_time", 0)
|
||||||
|
update_time = row.get("update_time", 0)
|
||||||
|
row["create_time"] = create_time
|
||||||
|
row["update_time"] = (
|
||||||
|
create_time if update_time == 0 else update_time
|
||||||
|
)
|
||||||
|
processed_results[row["id"]] = row
|
||||||
|
return processed_results
|
||||||
|
|
||||||
|
# For FULL_RELATIONS namespace, parse relation_pairs JSON string back to list
|
||||||
|
if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_RELATIONS):
|
||||||
|
processed_results = {}
|
||||||
|
for row in results:
|
||||||
|
relation_pairs = row.get("relation_pairs", [])
|
||||||
|
if isinstance(relation_pairs, str):
|
||||||
|
try:
|
||||||
|
relation_pairs = json.loads(relation_pairs)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
relation_pairs = []
|
||||||
|
row["relation_pairs"] = relation_pairs
|
||||||
|
create_time = row.get("create_time", 0)
|
||||||
|
update_time = row.get("update_time", 0)
|
||||||
|
row["create_time"] = create_time
|
||||||
|
row["update_time"] = (
|
||||||
|
create_time if update_time == 0 else update_time
|
||||||
|
)
|
||||||
|
processed_results[row["id"]] = row
|
||||||
|
return processed_results
|
||||||
|
|
||||||
# For other namespaces, return as-is
|
# For other namespaces, return as-is
|
||||||
return {row["id"]: row for row in results}
|
return {row["id"]: row for row in results}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -1277,6 +1391,36 @@ class PGKVStorage(BaseKVStorage):
|
||||||
"update_time": create_time if update_time == 0 else update_time,
|
"update_time": create_time if update_time == 0 else update_time,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Special handling for FULL_ENTITIES namespace
|
||||||
|
if response and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES):
|
||||||
|
# Parse entity_names JSON string back to list
|
||||||
|
entity_names = response.get("entity_names", [])
|
||||||
|
if isinstance(entity_names, str):
|
||||||
|
try:
|
||||||
|
entity_names = json.loads(entity_names)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
entity_names = []
|
||||||
|
response["entity_names"] = entity_names
|
||||||
|
create_time = response.get("create_time", 0)
|
||||||
|
update_time = response.get("update_time", 0)
|
||||||
|
response["create_time"] = create_time
|
||||||
|
response["update_time"] = create_time if update_time == 0 else update_time
|
||||||
|
|
||||||
|
# Special handling for FULL_RELATIONS namespace
|
||||||
|
if response and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_RELATIONS):
|
||||||
|
# Parse relation_pairs JSON string back to list
|
||||||
|
relation_pairs = response.get("relation_pairs", [])
|
||||||
|
if isinstance(relation_pairs, str):
|
||||||
|
try:
|
||||||
|
relation_pairs = json.loads(relation_pairs)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
relation_pairs = []
|
||||||
|
response["relation_pairs"] = relation_pairs
|
||||||
|
create_time = response.get("create_time", 0)
|
||||||
|
update_time = response.get("update_time", 0)
|
||||||
|
response["create_time"] = create_time
|
||||||
|
response["update_time"] = create_time if update_time == 0 else update_time
|
||||||
|
|
||||||
return response if response else None
|
return response if response else None
|
||||||
|
|
||||||
# Query by id
|
# Query by id
|
||||||
|
|
@ -1325,6 +1469,38 @@ class PGKVStorage(BaseKVStorage):
|
||||||
processed_results.append(processed_row)
|
processed_results.append(processed_row)
|
||||||
return processed_results
|
return processed_results
|
||||||
|
|
||||||
|
# Special handling for FULL_ENTITIES namespace
|
||||||
|
if results and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES):
|
||||||
|
for result in results:
|
||||||
|
# Parse entity_names JSON string back to list
|
||||||
|
entity_names = result.get("entity_names", [])
|
||||||
|
if isinstance(entity_names, str):
|
||||||
|
try:
|
||||||
|
entity_names = json.loads(entity_names)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
entity_names = []
|
||||||
|
result["entity_names"] = entity_names
|
||||||
|
create_time = result.get("create_time", 0)
|
||||||
|
update_time = result.get("update_time", 0)
|
||||||
|
result["create_time"] = create_time
|
||||||
|
result["update_time"] = create_time if update_time == 0 else update_time
|
||||||
|
|
||||||
|
# Special handling for FULL_RELATIONS namespace
|
||||||
|
if results and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_RELATIONS):
|
||||||
|
for result in results:
|
||||||
|
# Parse relation_pairs JSON string back to list
|
||||||
|
relation_pairs = result.get("relation_pairs", [])
|
||||||
|
if isinstance(relation_pairs, str):
|
||||||
|
try:
|
||||||
|
relation_pairs = json.loads(relation_pairs)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
relation_pairs = []
|
||||||
|
result["relation_pairs"] = relation_pairs
|
||||||
|
create_time = result.get("create_time", 0)
|
||||||
|
update_time = result.get("update_time", 0)
|
||||||
|
result["create_time"] = create_time
|
||||||
|
result["update_time"] = create_time if update_time == 0 else update_time
|
||||||
|
|
||||||
return results if results else []
|
return results if results else []
|
||||||
|
|
||||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||||
|
|
@ -1397,6 +1573,34 @@ class PGKVStorage(BaseKVStorage):
|
||||||
}
|
}
|
||||||
|
|
||||||
await self.db.execute(upsert_sql, _data)
|
await self.db.execute(upsert_sql, _data)
|
||||||
|
elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES):
|
||||||
|
# Get current UTC time and convert to naive datetime for database storage
|
||||||
|
current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None)
|
||||||
|
for k, v in data.items():
|
||||||
|
upsert_sql = SQL_TEMPLATES["upsert_full_entities"]
|
||||||
|
_data = {
|
||||||
|
"workspace": self.db.workspace,
|
||||||
|
"id": k,
|
||||||
|
"entity_names": json.dumps(v["entity_names"]),
|
||||||
|
"count": v["count"],
|
||||||
|
"create_time": current_time,
|
||||||
|
"update_time": current_time,
|
||||||
|
}
|
||||||
|
await self.db.execute(upsert_sql, _data)
|
||||||
|
elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_RELATIONS):
|
||||||
|
# Get current UTC time and convert to naive datetime for database storage
|
||||||
|
current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None)
|
||||||
|
for k, v in data.items():
|
||||||
|
upsert_sql = SQL_TEMPLATES["upsert_full_relations"]
|
||||||
|
_data = {
|
||||||
|
"workspace": self.db.workspace,
|
||||||
|
"id": k,
|
||||||
|
"relation_pairs": json.dumps(v["relation_pairs"]),
|
||||||
|
"count": v["count"],
|
||||||
|
"create_time": current_time,
|
||||||
|
"update_time": current_time,
|
||||||
|
}
|
||||||
|
await self.db.execute(upsert_sql, _data)
|
||||||
|
|
||||||
async def index_done_callback(self) -> None:
|
async def index_done_callback(self) -> None:
|
||||||
# PG handles persistence automatically
|
# PG handles persistence automatically
|
||||||
|
|
@ -3669,6 +3873,67 @@ class PGGraphStorage(BaseGraphStorage):
|
||||||
|
|
||||||
return kg
|
return kg
|
||||||
|
|
||||||
|
async def get_all_nodes(self) -> list[dict]:
|
||||||
|
"""Get all nodes in the graph.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of all nodes, where each node is a dictionary of its properties
|
||||||
|
"""
|
||||||
|
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||||
|
MATCH (n:base)
|
||||||
|
RETURN n
|
||||||
|
$$) AS (n agtype)"""
|
||||||
|
|
||||||
|
results = await self._query(query)
|
||||||
|
nodes = []
|
||||||
|
for result in results:
|
||||||
|
if result["n"]:
|
||||||
|
node_dict = result["n"]["properties"]
|
||||||
|
|
||||||
|
# Process string result, parse it to JSON dictionary
|
||||||
|
if isinstance(node_dict, str):
|
||||||
|
try:
|
||||||
|
node_dict = json.loads(node_dict)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(f"Failed to parse node string: {node_dict}")
|
||||||
|
|
||||||
|
# Add node id (entity_id) to the dictionary for easier access
|
||||||
|
node_dict["id"] = node_dict.get("entity_id")
|
||||||
|
nodes.append(node_dict)
|
||||||
|
return nodes
|
||||||
|
|
||||||
|
async def get_all_edges(self) -> list[dict]:
|
||||||
|
"""Get all edges in the graph.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of all edges, where each edge is a dictionary of its properties
|
||||||
|
(Edges are matched in both directions, so each edge may appear twice; deduplication must be handled by the caller)
|
||||||
|
"""
|
||||||
|
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||||
|
MATCH (a:base)-[r]-(b:base)
|
||||||
|
RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
|
||||||
|
$$) AS (source text, target text, properties agtype)"""
|
||||||
|
|
||||||
|
results = await self._query(query)
|
||||||
|
edges = []
|
||||||
|
for result in results:
|
||||||
|
edge_properties = result["properties"]
|
||||||
|
|
||||||
|
# Process string result, parse it to JSON dictionary
|
||||||
|
if isinstance(edge_properties, str):
|
||||||
|
try:
|
||||||
|
edge_properties = json.loads(edge_properties)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to parse edge properties string: {edge_properties}"
|
||||||
|
)
|
||||||
|
edge_properties = {}
|
||||||
|
|
||||||
|
edge_properties["source"] = result["source"]
|
||||||
|
edge_properties["target"] = result["target"]
|
||||||
|
edges.append(edge_properties)
|
||||||
|
return edges
|
||||||
|
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> dict[str, str]:
|
||||||
"""Drop the storage"""
|
"""Drop the storage"""
|
||||||
try:
|
try:
|
||||||
|
|
@ -3687,14 +3952,18 @@ class PGGraphStorage(BaseGraphStorage):
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
# Note: Order matters! More specific namespaces (e.g., "full_entities") must come before
|
||||||
|
# more general ones (e.g., "entities") because is_namespace() uses endswith() matching
|
||||||
NAMESPACE_TABLE_MAP = {
|
NAMESPACE_TABLE_MAP = {
|
||||||
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
|
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
|
||||||
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
|
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
|
||||||
|
NameSpace.KV_STORE_FULL_ENTITIES: "LIGHTRAG_FULL_ENTITIES",
|
||||||
|
NameSpace.KV_STORE_FULL_RELATIONS: "LIGHTRAG_FULL_RELATIONS",
|
||||||
|
NameSpace.KV_STORE_LLM_RESPONSE_CACHE: "LIGHTRAG_LLM_CACHE",
|
||||||
NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_VDB_CHUNKS",
|
NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_VDB_CHUNKS",
|
||||||
NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY",
|
NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY",
|
||||||
NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_VDB_RELATION",
|
NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_VDB_RELATION",
|
||||||
NameSpace.DOC_STATUS: "LIGHTRAG_DOC_STATUS",
|
NameSpace.DOC_STATUS: "LIGHTRAG_DOC_STATUS",
|
||||||
NameSpace.KV_STORE_LLM_RESPONSE_CACHE: "LIGHTRAG_LLM_CACHE",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -3807,6 +4076,28 @@ TABLES = {
|
||||||
CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id)
|
CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id)
|
||||||
)"""
|
)"""
|
||||||
},
|
},
|
||||||
|
"LIGHTRAG_FULL_ENTITIES": {
|
||||||
|
"ddl": """CREATE TABLE LIGHTRAG_FULL_ENTITIES (
|
||||||
|
id VARCHAR(255),
|
||||||
|
workspace VARCHAR(255),
|
||||||
|
entity_names JSONB,
|
||||||
|
count INTEGER,
|
||||||
|
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
CONSTRAINT LIGHTRAG_FULL_ENTITIES_PK PRIMARY KEY (workspace, id)
|
||||||
|
)"""
|
||||||
|
},
|
||||||
|
"LIGHTRAG_FULL_RELATIONS": {
|
||||||
|
"ddl": """CREATE TABLE LIGHTRAG_FULL_RELATIONS (
|
||||||
|
id VARCHAR(255),
|
||||||
|
workspace VARCHAR(255),
|
||||||
|
relation_pairs JSONB,
|
||||||
|
count INTEGER,
|
||||||
|
create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
CONSTRAINT LIGHTRAG_FULL_RELATIONS_PK PRIMARY KEY (workspace, id)
|
||||||
|
)"""
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -3845,6 +4136,26 @@ SQL_TEMPLATES = {
|
||||||
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||||
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
|
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
|
||||||
""",
|
""",
|
||||||
|
"get_by_id_full_entities": """SELECT id, entity_names, count,
|
||||||
|
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
||||||
|
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||||
|
FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id=$2
|
||||||
|
""",
|
||||||
|
"get_by_id_full_relations": """SELECT id, relation_pairs, count,
|
||||||
|
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
||||||
|
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||||
|
FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id=$2
|
||||||
|
""",
|
||||||
|
"get_by_ids_full_entities": """SELECT id, entity_names, count,
|
||||||
|
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
||||||
|
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||||
|
FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id IN ({ids})
|
||||||
|
""",
|
||||||
|
"get_by_ids_full_relations": """SELECT id, relation_pairs, count,
|
||||||
|
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
||||||
|
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||||
|
FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id IN ({ids})
|
||||||
|
""",
|
||||||
"filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
|
"filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
|
||||||
"upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace)
|
"upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace)
|
||||||
VALUES ($1, $2, $3)
|
VALUES ($1, $2, $3)
|
||||||
|
|
@ -3874,6 +4185,22 @@ SQL_TEMPLATES = {
|
||||||
llm_cache_list=EXCLUDED.llm_cache_list,
|
llm_cache_list=EXCLUDED.llm_cache_list,
|
||||||
update_time = EXCLUDED.update_time
|
update_time = EXCLUDED.update_time
|
||||||
""",
|
""",
|
||||||
|
"upsert_full_entities": """INSERT INTO LIGHTRAG_FULL_ENTITIES (workspace, id, entity_names, count,
|
||||||
|
create_time, update_time)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6)
|
||||||
|
ON CONFLICT (workspace,id) DO UPDATE
|
||||||
|
SET entity_names=EXCLUDED.entity_names,
|
||||||
|
count=EXCLUDED.count,
|
||||||
|
update_time = EXCLUDED.update_time
|
||||||
|
""",
|
||||||
|
"upsert_full_relations": """INSERT INTO LIGHTRAG_FULL_RELATIONS (workspace, id, relation_pairs, count,
|
||||||
|
create_time, update_time)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6)
|
||||||
|
ON CONFLICT (workspace,id) DO UPDATE
|
||||||
|
SET relation_pairs=EXCLUDED.relation_pairs,
|
||||||
|
count=EXCLUDED.count,
|
||||||
|
update_time = EXCLUDED.update_time
|
||||||
|
""",
|
||||||
# SQL for VectorStorage
|
# SQL for VectorStorage
|
||||||
"upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens,
|
"upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens,
|
||||||
chunk_order_index, full_doc_id, content, content_vector, file_path,
|
chunk_order_index, full_doc_id, content, content_vector, file_path,
|
||||||
|
|
|
||||||
|
|
@ -334,12 +334,10 @@ class LightRAG:
|
||||||
# Storages Management
|
# Storages Management
|
||||||
# ---
|
# ---
|
||||||
|
|
||||||
auto_manage_storages_states: bool = field(default=True)
|
# TODO: Deprecated (LightRAG will never initialize storage automatically on creation, and finalize should be called before destroying)
|
||||||
|
auto_manage_storages_states: bool = field(default=False)
|
||||||
"""If True, lightrag will automatically call initialize_storages and finalize_storages at the appropriate times."""
|
"""If True, lightrag will automatically call initialize_storages and finalize_storages at the appropriate times."""
|
||||||
|
|
||||||
# Storages Management
|
|
||||||
# ---
|
|
||||||
|
|
||||||
cosine_better_than_threshold: float = field(
|
cosine_better_than_threshold: float = field(
|
||||||
default=float(os.getenv("COSINE_THRESHOLD", 0.2))
|
default=float(os.getenv("COSINE_THRESHOLD", 0.2))
|
||||||
)
|
)
|
||||||
|
|
@ -453,14 +451,26 @@ class LightRAG:
|
||||||
embedding_func=self.embedding_func,
|
embedding_func=self.embedding_func,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
|
||||||
|
namespace=NameSpace.KV_STORE_TEXT_CHUNKS,
|
||||||
|
workspace=self.workspace,
|
||||||
|
embedding_func=self.embedding_func,
|
||||||
|
)
|
||||||
|
|
||||||
self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
|
self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
|
||||||
namespace=NameSpace.KV_STORE_FULL_DOCS,
|
namespace=NameSpace.KV_STORE_FULL_DOCS,
|
||||||
workspace=self.workspace,
|
workspace=self.workspace,
|
||||||
embedding_func=self.embedding_func,
|
embedding_func=self.embedding_func,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
|
self.full_entities: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
|
||||||
namespace=NameSpace.KV_STORE_TEXT_CHUNKS,
|
namespace=NameSpace.KV_STORE_FULL_ENTITIES,
|
||||||
|
workspace=self.workspace,
|
||||||
|
embedding_func=self.embedding_func,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.full_relations: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
|
||||||
|
namespace=NameSpace.KV_STORE_FULL_RELATIONS,
|
||||||
workspace=self.workspace,
|
workspace=self.workspace,
|
||||||
embedding_func=self.embedding_func,
|
embedding_func=self.embedding_func,
|
||||||
)
|
)
|
||||||
|
|
@ -519,32 +529,6 @@ class LightRAG:
|
||||||
|
|
||||||
self._storages_status = StoragesStatus.CREATED
|
self._storages_status = StoragesStatus.CREATED
|
||||||
|
|
||||||
if self.auto_manage_storages_states:
|
|
||||||
self._run_async_safely(self.initialize_storages, "Storage Initialization")
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
if self.auto_manage_storages_states:
|
|
||||||
self._run_async_safely(self.finalize_storages, "Storage Finalization")
|
|
||||||
|
|
||||||
def _run_async_safely(self, async_func, action_name=""):
|
|
||||||
"""Safely execute an async function, avoiding event loop conflicts."""
|
|
||||||
try:
|
|
||||||
loop = always_get_an_event_loop()
|
|
||||||
if loop.is_running():
|
|
||||||
task = loop.create_task(async_func())
|
|
||||||
task.add_done_callback(
|
|
||||||
lambda t: logger.info(f"{action_name} completed!")
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
loop.run_until_complete(async_func())
|
|
||||||
except RuntimeError:
|
|
||||||
logger.warning(
|
|
||||||
f"No running event loop, creating a new loop for {action_name}."
|
|
||||||
)
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
loop.run_until_complete(async_func())
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
async def initialize_storages(self):
|
async def initialize_storages(self):
|
||||||
"""Asynchronously initialize the storages"""
|
"""Asynchronously initialize the storages"""
|
||||||
if self._storages_status == StoragesStatus.CREATED:
|
if self._storages_status == StoragesStatus.CREATED:
|
||||||
|
|
@ -553,6 +537,8 @@ class LightRAG:
|
||||||
for storage in (
|
for storage in (
|
||||||
self.full_docs,
|
self.full_docs,
|
||||||
self.text_chunks,
|
self.text_chunks,
|
||||||
|
self.full_entities,
|
||||||
|
self.full_relations,
|
||||||
self.entities_vdb,
|
self.entities_vdb,
|
||||||
self.relationships_vdb,
|
self.relationships_vdb,
|
||||||
self.chunks_vdb,
|
self.chunks_vdb,
|
||||||
|
|
@ -569,27 +555,207 @@ class LightRAG:
|
||||||
logger.debug("All storage types initialized")
|
logger.debug("All storage types initialized")
|
||||||
|
|
||||||
async def finalize_storages(self):
|
async def finalize_storages(self):
|
||||||
"""Asynchronously finalize the storages"""
|
"""Asynchronously finalize the storages with improved error handling"""
|
||||||
if self._storages_status == StoragesStatus.INITIALIZED:
|
if self._storages_status == StoragesStatus.INITIALIZED:
|
||||||
tasks = []
|
storages = [
|
||||||
|
("full_docs", self.full_docs),
|
||||||
|
("text_chunks", self.text_chunks),
|
||||||
|
("full_entities", self.full_entities),
|
||||||
|
("full_relations", self.full_relations),
|
||||||
|
("entities_vdb", self.entities_vdb),
|
||||||
|
("relationships_vdb", self.relationships_vdb),
|
||||||
|
("chunks_vdb", self.chunks_vdb),
|
||||||
|
("chunk_entity_relation_graph", self.chunk_entity_relation_graph),
|
||||||
|
("llm_response_cache", self.llm_response_cache),
|
||||||
|
("doc_status", self.doc_status),
|
||||||
|
]
|
||||||
|
|
||||||
for storage in (
|
# Finalize each storage individually to ensure one failure doesn't prevent others from closing
|
||||||
self.full_docs,
|
successful_finalizations = []
|
||||||
self.text_chunks,
|
failed_finalizations = []
|
||||||
self.entities_vdb,
|
|
||||||
self.relationships_vdb,
|
for storage_name, storage in storages:
|
||||||
self.chunks_vdb,
|
|
||||||
self.chunk_entity_relation_graph,
|
|
||||||
self.llm_response_cache,
|
|
||||||
self.doc_status,
|
|
||||||
):
|
|
||||||
if storage:
|
if storage:
|
||||||
tasks.append(storage.finalize())
|
try:
|
||||||
|
await storage.finalize()
|
||||||
|
successful_finalizations.append(storage_name)
|
||||||
|
logger.debug(f"Successfully finalized {storage_name}")
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Failed to finalize {storage_name}: {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
failed_finalizations.append(storage_name)
|
||||||
|
|
||||||
await asyncio.gather(*tasks)
|
# Log summary of finalization results
|
||||||
|
if successful_finalizations:
|
||||||
|
logger.info(
|
||||||
|
f"Successfully finalized {len(successful_finalizations)} storages"
|
||||||
|
)
|
||||||
|
|
||||||
|
if failed_finalizations:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to finalize {len(failed_finalizations)} storages: {', '.join(failed_finalizations)}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug("All storages finalized successfully")
|
||||||
|
|
||||||
self._storages_status = StoragesStatus.FINALIZED
|
self._storages_status = StoragesStatus.FINALIZED
|
||||||
logger.debug("Finalized Storages")
|
|
||||||
|
async def check_and_migrate_data(self):
|
||||||
|
"""Check if data migration is needed and perform migration if necessary"""
|
||||||
|
try:
|
||||||
|
# Check if migration is needed:
|
||||||
|
# 1. chunk_entity_relation_graph has entities and relations (count > 0)
|
||||||
|
# 2. full_entities and full_relations are empty
|
||||||
|
|
||||||
|
# Get all entity labels from graph
|
||||||
|
all_entity_labels = await self.chunk_entity_relation_graph.get_all_labels()
|
||||||
|
|
||||||
|
if not all_entity_labels:
|
||||||
|
logger.debug("No entities found in graph, skipping migration check")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if full_entities and full_relations are empty
|
||||||
|
# Get all processed documents to check their entity/relation data
|
||||||
|
try:
|
||||||
|
processed_docs = await self.doc_status.get_docs_by_status(
|
||||||
|
DocStatus.PROCESSED
|
||||||
|
)
|
||||||
|
|
||||||
|
if not processed_docs:
|
||||||
|
logger.debug("No processed documents found, skipping migration")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check first few documents to see if they have full_entities/full_relations data
|
||||||
|
migration_needed = True
|
||||||
|
checked_count = 0
|
||||||
|
max_check = min(5, len(processed_docs)) # Check up to 5 documents
|
||||||
|
|
||||||
|
for doc_id in list(processed_docs.keys())[:max_check]:
|
||||||
|
checked_count += 1
|
||||||
|
entity_data = await self.full_entities.get_by_id(doc_id)
|
||||||
|
relation_data = await self.full_relations.get_by_id(doc_id)
|
||||||
|
|
||||||
|
if entity_data or relation_data:
|
||||||
|
migration_needed = False
|
||||||
|
break
|
||||||
|
|
||||||
|
if not migration_needed:
|
||||||
|
logger.debug(
|
||||||
|
"Full entities/relations data already exists, no migration needed"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Data migration needed: found {len(all_entity_labels)} entities in graph but no full_entities/full_relations data"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perform migration
|
||||||
|
await self._migrate_entity_relation_data(processed_docs)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during migration check: {e}")
|
||||||
|
# Don't raise the error, just log it to avoid breaking initialization
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in data migration check: {e}")
|
||||||
|
# Don't raise the error to avoid breaking initialization
|
||||||
|
|
||||||
|
async def _migrate_entity_relation_data(self, processed_docs: dict):
|
||||||
|
"""Migrate existing entity and relation data to full_entities and full_relations storage"""
|
||||||
|
logger.info(f"Starting data migration for {len(processed_docs)} documents")
|
||||||
|
|
||||||
|
# Create mapping from chunk_id to doc_id
|
||||||
|
chunk_to_doc = {}
|
||||||
|
for doc_id, doc_status in processed_docs.items():
|
||||||
|
chunk_ids = (
|
||||||
|
doc_status.chunks_list
|
||||||
|
if hasattr(doc_status, "chunks_list") and doc_status.chunks_list
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
for chunk_id in chunk_ids:
|
||||||
|
chunk_to_doc[chunk_id] = doc_id
|
||||||
|
|
||||||
|
# Initialize document entity and relation mappings
|
||||||
|
doc_entities = {} # doc_id -> set of entity_names
|
||||||
|
doc_relations = {} # doc_id -> set of relation_pairs (as tuples)
|
||||||
|
|
||||||
|
# Get all nodes and edges from graph
|
||||||
|
all_nodes = await self.chunk_entity_relation_graph.get_all_nodes()
|
||||||
|
all_edges = await self.chunk_entity_relation_graph.get_all_edges()
|
||||||
|
|
||||||
|
# Process all nodes once
|
||||||
|
for node in all_nodes:
|
||||||
|
if "source_id" in node:
|
||||||
|
entity_id = node.get("entity_id") or node.get("id")
|
||||||
|
if not entity_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get chunk IDs from source_id
|
||||||
|
source_ids = node["source_id"].split(GRAPH_FIELD_SEP)
|
||||||
|
|
||||||
|
# Find which documents this entity belongs to
|
||||||
|
for chunk_id in source_ids:
|
||||||
|
doc_id = chunk_to_doc.get(chunk_id)
|
||||||
|
if doc_id:
|
||||||
|
if doc_id not in doc_entities:
|
||||||
|
doc_entities[doc_id] = set()
|
||||||
|
doc_entities[doc_id].add(entity_id)
|
||||||
|
|
||||||
|
# Process all edges once
|
||||||
|
for edge in all_edges:
|
||||||
|
if "source_id" in edge:
|
||||||
|
src = edge.get("source")
|
||||||
|
tgt = edge.get("target")
|
||||||
|
if not src or not tgt:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get chunk IDs from source_id
|
||||||
|
source_ids = edge["source_id"].split(GRAPH_FIELD_SEP)
|
||||||
|
|
||||||
|
# Find which documents this relation belongs to
|
||||||
|
for chunk_id in source_ids:
|
||||||
|
doc_id = chunk_to_doc.get(chunk_id)
|
||||||
|
if doc_id:
|
||||||
|
if doc_id not in doc_relations:
|
||||||
|
doc_relations[doc_id] = set()
|
||||||
|
# Use tuple for set operations, convert to list later
|
||||||
|
doc_relations[doc_id].add(tuple(sorted((src, tgt))))
|
||||||
|
|
||||||
|
# Store the results in full_entities and full_relations
|
||||||
|
migration_count = 0
|
||||||
|
|
||||||
|
# Store entities
|
||||||
|
if doc_entities:
|
||||||
|
entities_data = {}
|
||||||
|
for doc_id, entity_set in doc_entities.items():
|
||||||
|
entities_data[doc_id] = {
|
||||||
|
"entity_names": list(entity_set),
|
||||||
|
"count": len(entity_set),
|
||||||
|
}
|
||||||
|
await self.full_entities.upsert(entities_data)
|
||||||
|
|
||||||
|
# Store relations
|
||||||
|
if doc_relations:
|
||||||
|
relations_data = {}
|
||||||
|
for doc_id, relation_set in doc_relations.items():
|
||||||
|
# Convert tuples back to lists
|
||||||
|
relations_data[doc_id] = {
|
||||||
|
"relation_pairs": [list(pair) for pair in relation_set],
|
||||||
|
"count": len(relation_set),
|
||||||
|
}
|
||||||
|
await self.full_relations.upsert(relations_data)
|
||||||
|
|
||||||
|
migration_count = len(
|
||||||
|
set(list(doc_entities.keys()) + list(doc_relations.keys()))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Persist the migrated data
|
||||||
|
await self.full_entities.index_done_callback()
|
||||||
|
await self.full_relations.index_done_callback()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Data migration completed: migrated {migration_count} documents with entities/relations"
|
||||||
|
)
|
||||||
|
|
||||||
async def get_graph_labels(self):
|
async def get_graph_labels(self):
|
||||||
text = await self.chunk_entity_relation_graph.get_all_labels()
|
text = await self.chunk_entity_relation_graph.get_all_labels()
|
||||||
|
|
@ -1229,6 +1395,9 @@ class LightRAG:
|
||||||
entity_vdb=self.entities_vdb,
|
entity_vdb=self.entities_vdb,
|
||||||
relationships_vdb=self.relationships_vdb,
|
relationships_vdb=self.relationships_vdb,
|
||||||
global_config=asdict(self),
|
global_config=asdict(self),
|
||||||
|
full_entities_storage=self.full_entities,
|
||||||
|
full_relations_storage=self.full_relations,
|
||||||
|
doc_id=doc_id,
|
||||||
pipeline_status=pipeline_status,
|
pipeline_status=pipeline_status,
|
||||||
pipeline_status_lock=pipeline_status_lock,
|
pipeline_status_lock=pipeline_status_lock,
|
||||||
llm_response_cache=self.llm_response_cache,
|
llm_response_cache=self.llm_response_cache,
|
||||||
|
|
@ -1401,6 +1570,8 @@ class LightRAG:
|
||||||
self.full_docs,
|
self.full_docs,
|
||||||
self.doc_status,
|
self.doc_status,
|
||||||
self.text_chunks,
|
self.text_chunks,
|
||||||
|
self.full_entities,
|
||||||
|
self.full_relations,
|
||||||
self.llm_response_cache,
|
self.llm_response_cache,
|
||||||
self.entities_vdb,
|
self.entities_vdb,
|
||||||
self.relationships_vdb,
|
self.relationships_vdb,
|
||||||
|
|
@ -1959,21 +2130,54 @@ class LightRAG:
|
||||||
graph_db_lock = get_graph_db_lock(enable_logging=False)
|
graph_db_lock = get_graph_db_lock(enable_logging=False)
|
||||||
async with graph_db_lock:
|
async with graph_db_lock:
|
||||||
try:
|
try:
|
||||||
# Get all affected nodes and edges in batch
|
# Get affected entities and relations from full_entities and full_relations storage
|
||||||
# logger.info(
|
doc_entities_data = await self.full_entities.get_by_id(doc_id)
|
||||||
# f"Analyzing affected entities and relationships for {len(chunk_ids)} chunks"
|
doc_relations_data = await self.full_relations.get_by_id(doc_id)
|
||||||
# )
|
|
||||||
affected_nodes = (
|
|
||||||
await self.chunk_entity_relation_graph.get_nodes_by_chunk_ids(
|
|
||||||
list(chunk_ids)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
affected_edges = (
|
affected_nodes = []
|
||||||
await self.chunk_entity_relation_graph.get_edges_by_chunk_ids(
|
affected_edges = []
|
||||||
list(chunk_ids)
|
|
||||||
|
# Get entity data from graph storage using entity names from full_entities
|
||||||
|
if doc_entities_data and "entity_names" in doc_entities_data:
|
||||||
|
entity_names = doc_entities_data["entity_names"]
|
||||||
|
# get_nodes_batch returns dict[str, dict], need to convert to list[dict]
|
||||||
|
nodes_dict = (
|
||||||
|
await self.chunk_entity_relation_graph.get_nodes_batch(
|
||||||
|
entity_names
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
for entity_name in entity_names:
|
||||||
|
node_data = nodes_dict.get(entity_name)
|
||||||
|
if node_data:
|
||||||
|
# Ensure compatibility with existing logic that expects "id" field
|
||||||
|
if "id" not in node_data:
|
||||||
|
node_data["id"] = entity_name
|
||||||
|
affected_nodes.append(node_data)
|
||||||
|
|
||||||
|
# Get relation data from graph storage using relation pairs from full_relations
|
||||||
|
if doc_relations_data and "relation_pairs" in doc_relations_data:
|
||||||
|
relation_pairs = doc_relations_data["relation_pairs"]
|
||||||
|
edge_pairs_dicts = [
|
||||||
|
{"src": pair[0], "tgt": pair[1]} for pair in relation_pairs
|
||||||
|
]
|
||||||
|
# get_edges_batch returns dict[tuple[str, str], dict], need to convert to list[dict]
|
||||||
|
edges_dict = (
|
||||||
|
await self.chunk_entity_relation_graph.get_edges_batch(
|
||||||
|
edge_pairs_dicts
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for pair in relation_pairs:
|
||||||
|
src, tgt = pair[0], pair[1]
|
||||||
|
edge_key = (src, tgt)
|
||||||
|
edge_data = edges_dict.get(edge_key)
|
||||||
|
if edge_data:
|
||||||
|
# Ensure compatibility with existing logic that expects "source" and "target" fields
|
||||||
|
if "source" not in edge_data:
|
||||||
|
edge_data["source"] = src
|
||||||
|
if "target" not in edge_data:
|
||||||
|
edge_data["target"] = tgt
|
||||||
|
affected_edges.append(edge_data)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to analyze affected graph elements: {e}")
|
logger.error(f"Failed to analyze affected graph elements: {e}")
|
||||||
|
|
@ -2125,7 +2329,17 @@ class LightRAG:
|
||||||
f"Failed to rebuild knowledge graph: {e}"
|
f"Failed to rebuild knowledge graph: {e}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
# 9. Delete original document and status
|
# 9. Delete from full_entities and full_relations storage
|
||||||
|
try:
|
||||||
|
await self.full_entities.delete([doc_id])
|
||||||
|
await self.full_relations.delete([doc_id])
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to delete from full_entities/full_relations: {e}")
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to delete from full_entities/full_relations: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# 10. Delete original document and status
|
||||||
try:
|
try:
|
||||||
await self.full_docs.delete([doc_id])
|
await self.full_docs.delete([doc_id])
|
||||||
await self.doc_status.delete([doc_id])
|
await self.doc_status.delete([doc_id])
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,8 @@ class NameSpace:
|
||||||
KV_STORE_FULL_DOCS = "full_docs"
|
KV_STORE_FULL_DOCS = "full_docs"
|
||||||
KV_STORE_TEXT_CHUNKS = "text_chunks"
|
KV_STORE_TEXT_CHUNKS = "text_chunks"
|
||||||
KV_STORE_LLM_RESPONSE_CACHE = "llm_response_cache"
|
KV_STORE_LLM_RESPONSE_CACHE = "llm_response_cache"
|
||||||
|
KV_STORE_FULL_ENTITIES = "full_entities"
|
||||||
|
KV_STORE_FULL_RELATIONS = "full_relations"
|
||||||
|
|
||||||
VECTOR_STORE_ENTITIES = "entities"
|
VECTOR_STORE_ENTITIES = "entities"
|
||||||
VECTOR_STORE_RELATIONSHIPS = "relationships"
|
VECTOR_STORE_RELATIONSHIPS = "relationships"
|
||||||
|
|
|
||||||
|
|
@ -504,9 +504,6 @@ async def _rebuild_knowledge_from_chunks(
|
||||||
# Re-raise the exception to notify the caller
|
# Re-raise the exception to notify the caller
|
||||||
raise task.exception()
|
raise task.exception()
|
||||||
|
|
||||||
# If all tasks completed successfully, collect results
|
|
||||||
# (No need to collect results since these tasks don't return values)
|
|
||||||
|
|
||||||
# Final status report
|
# Final status report
|
||||||
status_message = f"KG rebuild completed: {rebuilt_entities_count} entities and {rebuilt_relationships_count} relationships rebuilt successfully."
|
status_message = f"KG rebuild completed: {rebuilt_entities_count} entities and {rebuilt_relationships_count} relationships rebuilt successfully."
|
||||||
if failed_entities_count > 0 or failed_relationships_count > 0:
|
if failed_entities_count > 0 or failed_relationships_count > 0:
|
||||||
|
|
@ -1024,6 +1021,7 @@ async def _merge_edges_then_upsert(
|
||||||
pipeline_status: dict = None,
|
pipeline_status: dict = None,
|
||||||
pipeline_status_lock=None,
|
pipeline_status_lock=None,
|
||||||
llm_response_cache: BaseKVStorage | None = None,
|
llm_response_cache: BaseKVStorage | None = None,
|
||||||
|
added_entities: list = None, # New parameter to track entities added during edge processing
|
||||||
):
|
):
|
||||||
if src_id == tgt_id:
|
if src_id == tgt_id:
|
||||||
return None
|
return None
|
||||||
|
|
@ -1105,17 +1103,27 @@ async def _merge_edges_then_upsert(
|
||||||
|
|
||||||
for need_insert_id in [src_id, tgt_id]:
|
for need_insert_id in [src_id, tgt_id]:
|
||||||
if not (await knowledge_graph_inst.has_node(need_insert_id)):
|
if not (await knowledge_graph_inst.has_node(need_insert_id)):
|
||||||
await knowledge_graph_inst.upsert_node(
|
node_data = {
|
||||||
need_insert_id,
|
"entity_id": need_insert_id,
|
||||||
node_data={
|
"source_id": source_id,
|
||||||
"entity_id": need_insert_id,
|
"description": description,
|
||||||
"source_id": source_id,
|
"entity_type": "UNKNOWN",
|
||||||
"description": description,
|
"file_path": file_path,
|
||||||
|
"created_at": int(time.time()),
|
||||||
|
}
|
||||||
|
await knowledge_graph_inst.upsert_node(need_insert_id, node_data=node_data)
|
||||||
|
|
||||||
|
# Track entities added during edge processing
|
||||||
|
if added_entities is not None:
|
||||||
|
entity_data = {
|
||||||
|
"entity_name": need_insert_id,
|
||||||
"entity_type": "UNKNOWN",
|
"entity_type": "UNKNOWN",
|
||||||
|
"description": description,
|
||||||
|
"source_id": source_id,
|
||||||
"file_path": file_path,
|
"file_path": file_path,
|
||||||
"created_at": int(time.time()),
|
"created_at": int(time.time()),
|
||||||
},
|
}
|
||||||
)
|
added_entities.append(entity_data)
|
||||||
|
|
||||||
force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
|
force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
|
||||||
|
|
||||||
|
|
@ -1178,6 +1186,9 @@ async def merge_nodes_and_edges(
|
||||||
entity_vdb: BaseVectorStorage,
|
entity_vdb: BaseVectorStorage,
|
||||||
relationships_vdb: BaseVectorStorage,
|
relationships_vdb: BaseVectorStorage,
|
||||||
global_config: dict[str, str],
|
global_config: dict[str, str],
|
||||||
|
full_entities_storage: BaseKVStorage = None,
|
||||||
|
full_relations_storage: BaseKVStorage = None,
|
||||||
|
doc_id: str = None,
|
||||||
pipeline_status: dict = None,
|
pipeline_status: dict = None,
|
||||||
pipeline_status_lock=None,
|
pipeline_status_lock=None,
|
||||||
llm_response_cache: BaseKVStorage | None = None,
|
llm_response_cache: BaseKVStorage | None = None,
|
||||||
|
|
@ -1185,7 +1196,12 @@ async def merge_nodes_and_edges(
|
||||||
total_files: int = 0,
|
total_files: int = 0,
|
||||||
file_path: str = "unknown_source",
|
file_path: str = "unknown_source",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Merge nodes and edges from extraction results
|
"""Two-phase merge: process all entities first, then all relationships
|
||||||
|
|
||||||
|
This approach ensures data consistency by:
|
||||||
|
1. Phase 1: Process all entities concurrently
|
||||||
|
2. Phase 2: Process all relationships concurrently (may add missing entities)
|
||||||
|
3. Phase 3: Update full_entities and full_relations storage with final results
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chunk_results: List of tuples (maybe_nodes, maybe_edges) containing extracted entities and relationships
|
chunk_results: List of tuples (maybe_nodes, maybe_edges) containing extracted entities and relationships
|
||||||
|
|
@ -1193,9 +1209,15 @@ async def merge_nodes_and_edges(
|
||||||
entity_vdb: Entity vector database
|
entity_vdb: Entity vector database
|
||||||
relationships_vdb: Relationship vector database
|
relationships_vdb: Relationship vector database
|
||||||
global_config: Global configuration
|
global_config: Global configuration
|
||||||
|
full_entities_storage: Storage for document entity lists
|
||||||
|
full_relations_storage: Storage for document relation lists
|
||||||
|
doc_id: Document ID for storage indexing
|
||||||
pipeline_status: Pipeline status dictionary
|
pipeline_status: Pipeline status dictionary
|
||||||
pipeline_status_lock: Lock for pipeline status
|
pipeline_status_lock: Lock for pipeline status
|
||||||
llm_response_cache: LLM response cache
|
llm_response_cache: LLM response cache
|
||||||
|
current_file_number: Current file number for logging
|
||||||
|
total_files: Total files for logging
|
||||||
|
file_path: File path for logging
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Collect all nodes and edges from all chunks
|
# Collect all nodes and edges from all chunks
|
||||||
|
|
@ -1212,11 +1234,9 @@ async def merge_nodes_and_edges(
|
||||||
sorted_edge_key = tuple(sorted(edge_key))
|
sorted_edge_key = tuple(sorted(edge_key))
|
||||||
all_edges[sorted_edge_key].extend(edges)
|
all_edges[sorted_edge_key].extend(edges)
|
||||||
|
|
||||||
# Centralized processing of all nodes and edges
|
|
||||||
total_entities_count = len(all_nodes)
|
total_entities_count = len(all_nodes)
|
||||||
total_relations_count = len(all_edges)
|
total_relations_count = len(all_edges)
|
||||||
|
|
||||||
# Merge nodes and edges
|
|
||||||
log_message = f"Merging stage {current_file_number}/{total_files}: {file_path}"
|
log_message = f"Merging stage {current_file_number}/{total_files}: {file_path}"
|
||||||
logger.info(log_message)
|
logger.info(log_message)
|
||||||
async with pipeline_status_lock:
|
async with pipeline_status_lock:
|
||||||
|
|
@ -1227,8 +1247,8 @@ async def merge_nodes_and_edges(
|
||||||
graph_max_async = global_config.get("llm_model_max_async", 4) * 2
|
graph_max_async = global_config.get("llm_model_max_async", 4) * 2
|
||||||
semaphore = asyncio.Semaphore(graph_max_async)
|
semaphore = asyncio.Semaphore(graph_max_async)
|
||||||
|
|
||||||
# Process and update all entities and relationships in parallel
|
# ===== Phase 1: Process all entities concurrently =====
|
||||||
log_message = f"Processing: {total_entities_count} entities and {total_relations_count} relations (async: {graph_max_async})"
|
log_message = f"Phase 1: Processing {total_entities_count} entities (async: {graph_max_async})"
|
||||||
logger.info(log_message)
|
logger.info(log_message)
|
||||||
async with pipeline_status_lock:
|
async with pipeline_status_lock:
|
||||||
pipeline_status["latest_message"] = log_message
|
pipeline_status["latest_message"] = log_message
|
||||||
|
|
@ -1263,18 +1283,53 @@ async def merge_nodes_and_edges(
|
||||||
await entity_vdb.upsert(data_for_vdb)
|
await entity_vdb.upsert(data_for_vdb)
|
||||||
return entity_data
|
return entity_data
|
||||||
|
|
||||||
|
# Create entity processing tasks
|
||||||
|
entity_tasks = []
|
||||||
|
for entity_name, entities in all_nodes.items():
|
||||||
|
task = asyncio.create_task(_locked_process_entity_name(entity_name, entities))
|
||||||
|
entity_tasks.append(task)
|
||||||
|
|
||||||
|
# Execute entity tasks with error handling
|
||||||
|
processed_entities = []
|
||||||
|
if entity_tasks:
|
||||||
|
done, pending = await asyncio.wait(
|
||||||
|
entity_tasks, return_when=asyncio.FIRST_EXCEPTION
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if any task raised an exception
|
||||||
|
for task in done:
|
||||||
|
if task.exception():
|
||||||
|
# If a task failed, cancel all pending tasks
|
||||||
|
for pending_task in pending:
|
||||||
|
pending_task.cancel()
|
||||||
|
# Wait for cancellation to complete
|
||||||
|
if pending:
|
||||||
|
await asyncio.wait(pending)
|
||||||
|
# Re-raise the exception to notify the caller
|
||||||
|
raise task.exception()
|
||||||
|
|
||||||
|
# If all tasks completed successfully, collect results
|
||||||
|
processed_entities = [task.result() for task in entity_tasks]
|
||||||
|
|
||||||
|
# ===== Phase 2: Process all relationships concurrently =====
|
||||||
|
log_message = f"Phase 2: Processing {total_relations_count} relations (async: {graph_max_async})"
|
||||||
|
logger.info(log_message)
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
pipeline_status["latest_message"] = log_message
|
||||||
|
pipeline_status["history_messages"].append(log_message)
|
||||||
|
|
||||||
async def _locked_process_edges(edge_key, edges):
|
async def _locked_process_edges(edge_key, edges):
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
workspace = global_config.get("workspace", "")
|
workspace = global_config.get("workspace", "")
|
||||||
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
|
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
|
||||||
# Sort the edge_key components to ensure consistent lock key generation
|
|
||||||
sorted_edge_key = sorted([edge_key[0], edge_key[1]])
|
sorted_edge_key = sorted([edge_key[0], edge_key[1]])
|
||||||
# logger.info(f"Processing edge: {sorted_edge_key[0]} - {sorted_edge_key[1]}")
|
|
||||||
async with get_storage_keyed_lock(
|
async with get_storage_keyed_lock(
|
||||||
sorted_edge_key,
|
sorted_edge_key,
|
||||||
namespace=namespace,
|
namespace=namespace,
|
||||||
enable_logging=False,
|
enable_logging=False,
|
||||||
):
|
):
|
||||||
|
added_entities = [] # Track entities added during edge processing
|
||||||
edge_data = await _merge_edges_then_upsert(
|
edge_data = await _merge_edges_then_upsert(
|
||||||
edge_key[0],
|
edge_key[0],
|
||||||
edge_key[1],
|
edge_key[1],
|
||||||
|
|
@ -1284,9 +1339,11 @@ async def merge_nodes_and_edges(
|
||||||
pipeline_status,
|
pipeline_status,
|
||||||
pipeline_status_lock,
|
pipeline_status_lock,
|
||||||
llm_response_cache,
|
llm_response_cache,
|
||||||
|
added_entities, # Pass list to collect added entities
|
||||||
)
|
)
|
||||||
|
|
||||||
if edge_data is None:
|
if edge_data is None:
|
||||||
return None
|
return None, []
|
||||||
|
|
||||||
if relationships_vdb is not None:
|
if relationships_vdb is not None:
|
||||||
data_for_vdb = {
|
data_for_vdb = {
|
||||||
|
|
@ -1303,50 +1360,106 @@ async def merge_nodes_and_edges(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
await relationships_vdb.upsert(data_for_vdb)
|
await relationships_vdb.upsert(data_for_vdb)
|
||||||
return edge_data
|
return edge_data, added_entities
|
||||||
|
|
||||||
# Create a single task queue for both entities and edges
|
# Create relationship processing tasks
|
||||||
tasks = []
|
edge_tasks = []
|
||||||
|
for edge_key, edges in all_edges.items():
|
||||||
|
task = asyncio.create_task(_locked_process_edges(edge_key, edges))
|
||||||
|
edge_tasks.append(task)
|
||||||
|
|
||||||
# Add entity processing tasks
|
# Execute relationship tasks with error handling
|
||||||
for entity_name, entities in all_nodes.items():
|
processed_edges = []
|
||||||
tasks.append(
|
all_added_entities = []
|
||||||
asyncio.create_task(_locked_process_entity_name(entity_name, entities))
|
|
||||||
|
if edge_tasks:
|
||||||
|
done, pending = await asyncio.wait(
|
||||||
|
edge_tasks, return_when=asyncio.FIRST_EXCEPTION
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add edge processing tasks
|
# Check if any task raised an exception
|
||||||
for edge_key, edges in all_edges.items():
|
for task in done:
|
||||||
tasks.append(asyncio.create_task(_locked_process_edges(edge_key, edges)))
|
if task.exception():
|
||||||
|
# If a task failed, cancel all pending tasks
|
||||||
|
for pending_task in pending:
|
||||||
|
pending_task.cancel()
|
||||||
|
# Wait for cancellation to complete
|
||||||
|
if pending:
|
||||||
|
await asyncio.wait(pending)
|
||||||
|
# Re-raise the exception to notify the caller
|
||||||
|
raise task.exception()
|
||||||
|
|
||||||
# Check if there are any tasks to process
|
# If all tasks completed successfully, collect results
|
||||||
if not tasks:
|
for task in edge_tasks:
|
||||||
log_message = f"No entities or relationships to process for {file_path}"
|
edge_data, added_entities = task.result()
|
||||||
logger.info(log_message)
|
if edge_data is not None:
|
||||||
if pipeline_status_lock is not None:
|
processed_edges.append(edge_data)
|
||||||
async with pipeline_status_lock:
|
all_added_entities.extend(added_entities)
|
||||||
pipeline_status["latest_message"] = log_message
|
|
||||||
pipeline_status["history_messages"].append(log_message)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Execute all tasks in parallel with semaphore control and early failure detection
|
# ===== Phase 3: Update full_entities and full_relations storage =====
|
||||||
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
|
if full_entities_storage and full_relations_storage and doc_id:
|
||||||
|
try:
|
||||||
|
# Merge all entities: original entities + entities added during edge processing
|
||||||
|
final_entity_names = set()
|
||||||
|
|
||||||
# Check if any task raised an exception
|
# Add original processed entities
|
||||||
for task in done:
|
for entity_data in processed_entities:
|
||||||
if task.exception():
|
if entity_data and entity_data.get("entity_name"):
|
||||||
# If a task failed, cancel all pending tasks
|
final_entity_names.add(entity_data["entity_name"])
|
||||||
for pending_task in pending:
|
|
||||||
pending_task.cancel()
|
|
||||||
|
|
||||||
# Wait for cancellation to complete
|
# Add entities that were added during relationship processing
|
||||||
if pending:
|
for added_entity in all_added_entities:
|
||||||
await asyncio.wait(pending)
|
if added_entity and added_entity.get("entity_name"):
|
||||||
|
final_entity_names.add(added_entity["entity_name"])
|
||||||
|
|
||||||
# Re-raise the exception to notify the caller
|
# Collect all relation pairs
|
||||||
raise task.exception()
|
final_relation_pairs = set()
|
||||||
|
for edge_data in processed_edges:
|
||||||
|
if edge_data:
|
||||||
|
src_id = edge_data.get("src_id")
|
||||||
|
tgt_id = edge_data.get("tgt_id")
|
||||||
|
if src_id and tgt_id:
|
||||||
|
relation_pair = tuple(sorted([src_id, tgt_id]))
|
||||||
|
final_relation_pairs.add(relation_pair)
|
||||||
|
|
||||||
# If all tasks completed successfully, collect results
|
# Update storage
|
||||||
# (No need to collect results since these tasks don't return values)
|
if final_entity_names:
|
||||||
|
await full_entities_storage.upsert(
|
||||||
|
{
|
||||||
|
doc_id: {
|
||||||
|
"entity_names": list(final_entity_names),
|
||||||
|
"count": len(final_entity_names),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if final_relation_pairs:
|
||||||
|
await full_relations_storage.upsert(
|
||||||
|
{
|
||||||
|
doc_id: {
|
||||||
|
"relation_pairs": [
|
||||||
|
list(pair) for pair in final_relation_pairs
|
||||||
|
],
|
||||||
|
"count": len(final_relation_pairs),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Updated entity-relation index for document {doc_id}: {len(final_entity_names)} entities (original: {len(processed_entities)}, added: {len(all_added_entities)}), {len(final_relation_pairs)} relations"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to update entity-relation index for document {doc_id}: {e}"
|
||||||
|
)
|
||||||
|
# Don't raise exception to avoid affecting main flow
|
||||||
|
|
||||||
|
log_message = f"Completed merging: {len(processed_entities)} entities, {len(all_added_entities)} added entities, {len(processed_edges)} relations"
|
||||||
|
logger.info(log_message)
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
pipeline_status["latest_message"] = log_message
|
||||||
|
pipeline_status["history_messages"].append(log_message)
|
||||||
|
|
||||||
|
|
||||||
async def extract_entities(
|
async def extract_entities(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue