feat: metrics in neo4j adapter [COG-1082] (#487)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Enhanced graph management capabilities allow users to verify graph
existence, project complete graphs, and remove graphs, delivering more
comprehensive graph insights.
  
- **Refactor**
  - Adjusted default task behavior for streamlined performance.
- Updated timestamp handling to ensure accurate and consistent record
tracking.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com>
This commit is contained in:
alekszievr 2025-02-07 15:58:43 +01:00 committed by GitHub
parent 0c42c10f64
commit 8396fed9a1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 141 additions and 14 deletions

View file

@ -165,7 +165,6 @@ async def get_default_tasks(
task_config={"batch_size": 10},
),
Task(add_data_points, task_config={"batch_size": 10}),
Task(store_descriptive_metrics, include_optional=True),
]
except Exception as error:
send_telemetry("cognee.cognify DEFAULT TASKS CREATION ERRORED", user.id)

View file

@ -531,16 +531,144 @@ class Neo4jAdapter(GraphDBInterface):
return (nodes, edges)
async def graph_exists(self, graph_name="myGraph"):
query = "CALL gds.graph.list() YIELD graphName RETURN collect(graphName) AS graphNames;"
result = await self.query(query)
graph_names = result[0]["graphNames"] if result else []
return graph_name in graph_names
async def project_entire_graph(self, graph_name="myGraph"):
"""
Projects all node labels and all relationship types into an in-memory GDS graph.
"""
if await self.graph_exists(graph_name):
return
node_labels_query = "CALL db.labels() YIELD label RETURN collect(label) AS labels;"
node_labels_result = await self.query(node_labels_query)
node_labels = node_labels_result[0]["labels"] if node_labels_result else []
relationship_types_query = "CALL db.relationshipTypes() YIELD relationshipType RETURN collect(relationshipType) AS relationships;"
relationship_types_result = await self.query(relationship_types_query)
relationship_types = (
relationship_types_result[0]["relationships"] if relationship_types_result else []
)
if not node_labels or not relationship_types:
raise ValueError("No node labels or relationship types found in the database.")
node_labels_str = "[" + ", ".join(f"'{label}'" for label in node_labels) + "]"
relationship_types_str = "[" + ", ".join(f"'{rel}'" for rel in relationship_types) + "]"
query = f"""
CALL gds.graph.project(
'{graph_name}',
{node_labels_str},
{relationship_types_str}
) YIELD graphName;
"""
await self.query(query)
async def drop_graph(self, graph_name="myGraph"):
if await self.graph_exists(graph_name):
drop_query = f"CALL gds.graph.drop('{graph_name}');"
await self.query(drop_query)
async def get_graph_metrics(self, include_optional=False):
return {
"num_nodes": -1,
"num_edges": -1,
"mean_degree": -1,
"edge_density": -1,
"num_connected_components": -1,
"sizes_of_connected_components": -1,
"num_selfloops": -1,
"diameter": -1,
"avg_shortest_path_length": -1,
"avg_clustering": -1,
nodes, edges = await self.get_model_independent_graph_data()
graph_name = "myGraph"
await self.drop_graph(graph_name)
await self.project_entire_graph(graph_name)
async def _get_edge_density():
query = """
MATCH (n)
WITH count(n) AS num_nodes
MATCH ()-[r]->()
WITH num_nodes, count(r) AS num_edges
RETURN CASE
WHEN num_nodes < 2 THEN 0
ELSE num_edges * 1.0 / (num_nodes * (num_nodes - 1))
END AS edge_density;
"""
result = await self.query(query)
return result[0]["edge_density"] if result else 0
async def _get_num_connected_components():
await self.drop_graph(graph_name)
await self.project_entire_graph(graph_name)
query = f"""
CALL gds.wcc.stream('{graph_name}')
YIELD componentId
RETURN count(DISTINCT componentId) AS num_connected_components;
"""
result = await self.query(query)
return result[0]["num_connected_components"] if result else 0
async def _get_size_of_connected_components():
await self.drop_graph(graph_name)
await self.project_entire_graph(graph_name)
query = f"""
CALL gds.wcc.stream('{graph_name}')
YIELD componentId
RETURN componentId, count(*) AS size
ORDER BY size DESC;
"""
result = await self.query(query)
return [record["size"] for record in result] if result else []
async def _count_self_loops():
query = """
MATCH (n)-[r]->(n)
RETURN count(r) AS self_loop_count;
"""
result = await self.query(query)
return result[0]["self_loop_count"] if result else 0
async def _get_diameter():
logging.warning("Diameter calculation is not implemented for neo4j.")
return -1
async def _get_avg_shortest_path_length():
logging.warning(
"Average shortest path length calculation is not implemented for neo4j."
)
return -1
async def _get_avg_clustering():
logging.warning("Average clustering calculation is not implemented for neo4j.")
return -1
num_nodes = len(nodes[0]["nodes"])
num_edges = len(edges[0]["elements"])
mandatory_metrics = {
"num_nodes": num_nodes,
"num_edges": num_edges,
"mean_degree": (2 * num_edges) / num_nodes if num_nodes != 0 else None,
"edge_density": await _get_edge_density(),
"num_connected_components": await _get_num_connected_components(),
"sizes_of_connected_components": await _get_size_of_connected_components(),
}
if include_optional:
optional_metrics = {
"num_selfloops": await _count_self_loops(),
"diameter": await _get_diameter(),
"avg_shortest_path_length": await _get_avg_shortest_path_length(),
"avg_clustering": await _get_avg_clustering(),
}
else:
optional_metrics = {
"num_selfloops": -1,
"diameter": -1,
"avg_shortest_path_length": -1,
"avg_clustering": -1,
}
return mandatory_metrics | optional_metrics

View file

@ -24,5 +24,5 @@ class GraphMetrics(Base):
avg_shortest_path_length = Column(Float, nullable=True)
avg_clustering = Column(Float, nullable=True)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())