feat: metrics in neo4j adapter [COG-1082] (#487)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced graph management capabilities allow users to verify graph existence, project complete graphs, and remove graphs, delivering more comprehensive graph insights. - **Refactor** - Adjusted default task behavior for streamlined performance. - Updated timestamp handling to ensure accurate and consistent record tracking. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com>
This commit is contained in:
parent
0c42c10f64
commit
8396fed9a1
3 changed files with 141 additions and 14 deletions
|
|
@ -165,7 +165,6 @@ async def get_default_tasks(
|
|||
task_config={"batch_size": 10},
|
||||
),
|
||||
Task(add_data_points, task_config={"batch_size": 10}),
|
||||
Task(store_descriptive_metrics, include_optional=True),
|
||||
]
|
||||
except Exception as error:
|
||||
send_telemetry("cognee.cognify DEFAULT TASKS CREATION ERRORED", user.id)
|
||||
|
|
|
|||
|
|
@ -531,16 +531,144 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return (nodes, edges)
|
||||
|
||||
async def graph_exists(self, graph_name="myGraph"):
    """Return True if a GDS in-memory graph named *graph_name* is currently projected."""
    list_query = "CALL gds.graph.list() YIELD graphName RETURN collect(graphName) AS graphNames;"
    rows = await self.query(list_query)
    projected_names = rows[0]["graphNames"] if rows else []
    return graph_name in projected_names
|
||||
|
||||
async def project_entire_graph(self, graph_name="myGraph"):
    """
    Projects all node labels and all relationship types into an in-memory GDS graph.

    No-op when a graph with *graph_name* is already projected.

    Raises:
        ValueError: if the database contains no node labels or no relationship types.
    """
    if await self.graph_exists(graph_name):
        return

    labels_rows = await self.query(
        "CALL db.labels() YIELD label RETURN collect(label) AS labels;"
    )
    labels = labels_rows[0]["labels"] if labels_rows else []

    rel_rows = await self.query(
        "CALL db.relationshipTypes() YIELD relationshipType RETURN collect(relationshipType) AS relationships;"
    )
    rel_types = rel_rows[0]["relationships"] if rel_rows else []

    # An empty projection is useless, so fail loudly instead.
    if not labels or not rel_types:
        raise ValueError("No node labels or relationship types found in the database.")

    # Render the collected names as Cypher list literals for gds.graph.project.
    labels_literal = "[" + ", ".join(f"'{label}'" for label in labels) + "]"
    rels_literal = "[" + ", ".join(f"'{rel}'" for rel in rel_types) + "]"

    projection_query = f"""
        CALL gds.graph.project(
            '{graph_name}',
            {labels_literal},
            {rels_literal}
        ) YIELD graphName;
    """
    await self.query(projection_query)
|
||||
|
||||
async def drop_graph(self, graph_name="myGraph"):
    """Drop the projected in-memory GDS graph *graph_name*, if one exists."""
    if not await self.graph_exists(graph_name):
        return
    await self.query(f"CALL gds.graph.drop('{graph_name}');")
|
||||
|
||||
async def get_graph_metrics(self, include_optional=False):
    """
    Compute structural metrics for the stored graph.

    Parameters:
        include_optional: when True, also compute the optional metrics
            (self-loop count, diameter, average shortest path length,
            average clustering). When False those keys are reported as -1.

    Returns:
        dict mapping metric name to value. Diameter, average shortest path
        length and average clustering are not implemented for Neo4j and are
        always -1 (a warning is logged when they are requested).
    """
    nodes, edges = await self.get_model_independent_graph_data()

    # Project the in-memory GDS graph once up front so every helper below
    # sees the current database state; re-projecting inside each helper
    # (as earlier revisions did) only repeats this work for the same result.
    graph_name = "myGraph"
    await self.drop_graph(graph_name)
    await self.project_entire_graph(graph_name)

    async def _get_edge_density():
        # Directed-graph density: |E| / (|V| * (|V| - 1)); 0 for < 2 nodes.
        query = """
            MATCH (n)
            WITH count(n) AS num_nodes
            MATCH ()-[r]->()
            WITH num_nodes, count(r) AS num_edges
            RETURN CASE
                WHEN num_nodes < 2 THEN 0
                ELSE num_edges * 1.0 / (num_nodes * (num_nodes - 1))
            END AS edge_density;
        """
        result = await self.query(query)
        return result[0]["edge_density"] if result else 0

    async def _get_num_connected_components():
        # Weakly connected components via GDS on the projected graph.
        query = f"""
            CALL gds.wcc.stream('{graph_name}')
            YIELD componentId
            RETURN count(DISTINCT componentId) AS num_connected_components;
        """
        result = await self.query(query)
        return result[0]["num_connected_components"] if result else 0

    async def _get_size_of_connected_components():
        # Component sizes, largest first.
        query = f"""
            CALL gds.wcc.stream('{graph_name}')
            YIELD componentId
            RETURN componentId, count(*) AS size
            ORDER BY size DESC;
        """
        result = await self.query(query)
        return [record["size"] for record in result] if result else []

    async def _count_self_loops():
        query = """
            MATCH (n)-[r]->(n)
            RETURN count(r) AS self_loop_count;
        """
        result = await self.query(query)
        return result[0]["self_loop_count"] if result else 0

    async def _get_diameter():
        logging.warning("Diameter calculation is not implemented for neo4j.")
        return -1

    async def _get_avg_shortest_path_length():
        logging.warning(
            "Average shortest path length calculation is not implemented for neo4j."
        )
        return -1

    async def _get_avg_clustering():
        logging.warning("Average clustering calculation is not implemented for neo4j.")
        return -1

    # NOTE(review): assumes get_model_independent_graph_data returns
    # single-row results shaped [{"nodes": [...]}], [{"elements": [...]}] —
    # confirm against that method's implementation.
    num_nodes = len(nodes[0]["nodes"])
    num_edges = len(edges[0]["elements"])

    mandatory_metrics = {
        "num_nodes": num_nodes,
        "num_edges": num_edges,
        # Undirected mean degree; None (not 0) for an empty graph.
        "mean_degree": (2 * num_edges) / num_nodes if num_nodes != 0 else None,
        "edge_density": await _get_edge_density(),
        "num_connected_components": await _get_num_connected_components(),
        "sizes_of_connected_components": await _get_size_of_connected_components(),
    }

    if include_optional:
        optional_metrics = {
            "num_selfloops": await _count_self_loops(),
            "diameter": await _get_diameter(),
            "avg_shortest_path_length": await _get_avg_shortest_path_length(),
            "avg_clustering": await _get_avg_clustering(),
        }
    else:
        optional_metrics = {
            "num_selfloops": -1,
            "diameter": -1,
            "avg_shortest_path_length": -1,
            "avg_clustering": -1,
        }

    return mandatory_metrics | optional_metrics
|
||||
|
|
|
|||
|
|
@ -24,5 +24,5 @@ class GraphMetrics(Base):
|
|||
avg_shortest_path_length = Column(Float, nullable=True)
|
||||
avg_clustering = Column(Float, nullable=True)
|
||||
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue