fix: handle queue error

This commit is contained in:
Boris Arzentar 2025-07-07 13:54:22 +02:00
parent fa5ea44345
commit 68adf6877b
No known key found for this signature in database
GPG key ID: D5CC274C784807B7
7 changed files with 48 additions and 29 deletions

View file

@ -296,13 +296,13 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
extract_graph_from_data,
graph_model=graph_model,
ontology_adapter=OntologyResolver(ontology_file=ontology_file_path),
task_config={"batch_size": 25},
task_config={"batch_size": 10},
), # Generate knowledge graphs from the document chunks.
Task(
summarize_text,
task_config={"batch_size": 25},
task_config={"batch_size": 10},
),
Task(add_data_points, task_config={"batch_size": 25}),
Task(add_data_points, task_config={"batch_size": 10}),
]
return default_tasks

View file

@ -798,10 +798,20 @@ class Neo4jAdapter(GraphDBInterface):
The result of the query execution, typically indicating success or failure.
"""
query = """MATCH (node)
DETACH DELETE node;"""
# query = """MATCH (node)
# DETACH DELETE node;"""
return await self.query(query)
# return await self.query(query)
node_labels = await self.get_node_labels()
for label in node_labels:
query = f"""
MATCH (node:`{label}`)
DETACH DELETE node;
"""
await self.query(query)
def serialize_properties(self, properties=dict()):
"""
@ -1031,24 +1041,20 @@ class Neo4jAdapter(GraphDBInterface):
graph_names = result[0]["graphNames"] if result else []
return graph_name in graph_names
async def get_node_labels(self):
    """
    Fetch all node labels from the database and return them.

    Returns:
    --------
    A list of node labels.

    Raises:
    -------
    ValueError
        If the database contains no node labels.
    """
    node_labels_query = "CALL db.labels()"
    node_labels_result = await self.query(node_labels_query)
    # Guard against a None/empty result set so the comprehension cannot raise
    # TypeError when the query returns nothing.
    node_labels = [record["label"] for record in (node_labels_result or [])]
    if not node_labels:
        raise ValueError("No node labels found in the database")
    return node_labels
async def get_relationship_labels_string(self):
"""
@ -1088,13 +1094,13 @@ class Neo4jAdapter(GraphDBInterface):
if await self.graph_exists(graph_name):
return
node_labels_str = await self.get_node_labels_string()
node_labels = await self.get_node_labels()
relationship_types_undirected_str = await self.get_relationship_labels_string()
query = f"""
CALL gds.graph.project(
'{graph_name}',
{node_labels_str},
['{"', '".join(node_labels)}'],
{relationship_types_undirected_str}
) YIELD graphName;
"""

View file

@ -5,6 +5,7 @@ from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy import JSON, Column, Table, select, delete, MetaData
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.exc import ProgrammingError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError
@ -113,8 +114,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
return False
@retry(
retry=retry_if_exception_type((DuplicateTableError, UniqueViolationError)),
stop=stop_after_attempt(3),
retry=retry_if_exception_type(
(DuplicateTableError, UniqueViolationError, ProgrammingError)
),
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def create_collection(self, collection_name: str, payload_schema=None):

View file

@ -21,6 +21,7 @@ os.environ["COGNEE_DISTRIBUTED"] = "True"
async def main():
# Clear queues
await add_nodes_and_edges_queue.clear.aio()
await add_data_points_queue.clear.aio()
number_of_graph_saving_workers = 1 # Total number of graph_saving_worker to spawn
number_of_data_point_saving_workers = 5 # Total number of data_point_saving_worker to spawn

View file

@ -1,12 +1,15 @@
from grpclib import GRPCError
async def queued_add_data_points(collection_name, data_points_batch):
    """
    Enqueue a batch of data points, recursively splitting the batch in half
    when the queue rejects it with a GRPCError (e.g. payload too large).

    Parameters:
    -----------
    collection_name: Name of the target collection.
    data_points_batch: Sequence of data points to enqueue.
    """
    from ..queues import add_data_points_queue

    try:
        await add_data_points_queue.put.aio((collection_name, data_points_batch))
    except GRPCError:
        # A batch of 0 or 1 items cannot be split further: len(batch) // 2 == 0
        # would yield an empty first half and an unchanged second half, causing
        # infinite recursion. Re-raise so the caller sees the real error.
        if len(data_points_batch) <= 1:
            raise
        midpoint = len(data_points_batch) // 2
        # Recurse (rather than put directly) so each half can be split again
        # if it is still too large for the queue.
        await queued_add_data_points(collection_name, data_points_batch[:midpoint])
        await queued_add_data_points(collection_name, data_points_batch[midpoint:])

View file

@ -1,12 +1,15 @@
from grpclib import GRPCError
async def queued_add_edges(edge_batch):
    """
    Enqueue a batch of edges, recursively splitting the batch in half when the
    queue rejects it with a GRPCError (e.g. payload too large).

    Parameters:
    -----------
    edge_batch: Sequence of edges to enqueue (sent as the second tuple element;
                the first element is an empty node list).
    """
    from ..queues import add_nodes_and_edges_queue

    try:
        await add_nodes_and_edges_queue.put.aio(([], edge_batch))
    except GRPCError:
        # A batch of 0 or 1 items cannot be split further: len(batch) // 2 == 0
        # would yield an empty first half and an unchanged second half, causing
        # infinite recursion. Re-raise so the caller sees the real error.
        if len(edge_batch) <= 1:
            raise
        midpoint = len(edge_batch) // 2
        # Recurse (rather than put directly) so each half can be split again
        # if it is still too large for the queue.
        await queued_add_edges(edge_batch[:midpoint])
        await queued_add_edges(edge_batch[midpoint:])

View file

@ -1,12 +1,15 @@
from grpclib import GRPCError
async def queued_add_nodes(node_batch):
    """
    Enqueue a batch of nodes, recursively splitting the batch in half when the
    queue rejects it with a GRPCError (e.g. payload too large).

    Parameters:
    -----------
    node_batch: Sequence of nodes to enqueue (sent as the first tuple element;
                the second element is an empty edge list).
    """
    from ..queues import add_nodes_and_edges_queue

    try:
        await add_nodes_and_edges_queue.put.aio((node_batch, []))
    except GRPCError:
        # A batch of 0 or 1 items cannot be split further: len(batch) // 2 == 0
        # would yield an empty first half and an unchanged second half, causing
        # infinite recursion. Re-raise so the caller sees the real error.
        if len(node_batch) <= 1:
            raise
        midpoint = len(node_batch) // 2
        # Recurse (rather than put directly) so each half can be split again
        # if it is still too large for the queue.
        await queued_add_nodes(node_batch[:midpoint])
        await queued_add_nodes(node_batch[midpoint:])