feat: Implement optional neo4j metrics and improve tests [cog-1262] (#556)
<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

- **New Features**
  - Enhanced graph analytics now offer detailed metrics—including shortest path lengths, diameter, and clustering coefficients—to provide deeper insights.
  - Added new functions for creating connected test graphs and validating metrics against predefined ground truth values.
  - Introduced a new JSON file containing metrics for connected and disconnected graph structures.
- **Improvements**
  - Updated how graphs are projected to consistently use undirected representations, ensuring more accurate and reliable metric calculations.
  - Streamlined metric consistency checks across different graph processing methods for robust, reliable results.
  - Simplified testing logic by consolidating metric assertions into a single function call.
- **Chores**
  - Removed unnecessary secret variables from the workflow configuration, potentially affecting access to certain resources.
  - Updated secret management to include the new `OPENAI_API_KEY`.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
parent
2a167fa1ab
commit
e56d86b410
10 changed files with 229 additions and 187 deletions
|
|
@ -16,13 +16,7 @@ jobs:
|
|||
with:
|
||||
example-location: ./cognee/tests/tasks/descriptive_metrics/networkx_metrics_test.py
|
||||
secrets:
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
||||
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
||||
|
|
|
|||
|
|
@ -57,4 +57,5 @@ class GraphDBInterface(Protocol):
|
|||
|
||||
@abstractmethod
async def get_graph_metrics(self, include_optional=False):
    """Return descriptive metrics for the stored graph.

    Metric definitions:
    https://docs.cognee.ai/core_concepts/graph_generation/descriptive_metrics

    Args:
        include_optional: Flag passed through to the backend. It defaults to
            False, the same default the Neo4j adapter's implementation uses,
            so callers may omit it against any adapter.

    Raises:
        NotImplementedError: Always; concrete adapters must override this.
    """
    raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -13,6 +13,14 @@ from neo4j.exceptions import Neo4jError
|
|||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
from .neo4j_metrics_utils import (
|
||||
get_avg_clustering,
|
||||
get_edge_density,
|
||||
get_num_connected_components,
|
||||
get_shortest_path_lengths,
|
||||
get_size_of_connected_components,
|
||||
count_self_loops,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("Neo4jAdapter")
|
||||
|
||||
|
|
@ -543,34 +551,49 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
graph_names = result[0]["graphNames"] if result else []
|
||||
return graph_name in graph_names
|
||||
|
||||
async def project_entire_graph(self, graph_name="myGraph"):
|
||||
"""
|
||||
Projects all node labels and all relationship types into an in-memory GDS graph.
|
||||
"""
|
||||
if await self.graph_exists(graph_name):
|
||||
return
|
||||
|
||||
async def get_node_labels_string(self):
    """Return every node label in the database as a Cypher list literal.

    Returns:
        str: A string such as ``"['Entity', 'Document']"`` that can be
        interpolated directly into a Cypher query.

    Raises:
        ValueError: If the database reports no node labels.
    """
    node_labels_query = "CALL db.labels() YIELD label RETURN collect(label) AS labels;"
    node_labels_result = await self.query(node_labels_query)
    node_labels = node_labels_result[0]["labels"] if node_labels_result else []

    if not node_labels:
        raise ValueError("No node labels found in the database")

    def _quote(label):
        # Escape the backslash first so the backslash added for the quote is
        # not doubled; an unescaped ' would terminate the Cypher literal early.
        escaped = label.replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    return "[" + ", ".join(_quote(label) for label in node_labels) + "]"
|
||||
|
||||
async def get_relationship_labels_string(self):
    """Return a Cypher map projecting every relationship type as UNDIRECTED.

    Returns:
        str: A map literal such as
        ``"{`KNOWS`: {orientation: 'UNDIRECTED'}}"`` with one entry per
        relationship type reported by ``db.relationshipTypes()``.

    Raises:
        ValueError: If the database reports no relationship types.
    """
    relationship_types_query = "CALL db.relationshipTypes() YIELD relationshipType RETURN collect(relationshipType) AS relationships;"
    relationship_types_result = await self.query(relationship_types_query)
    relationship_types = (
        relationship_types_result[0]["relationships"] if relationship_types_result else []
    )

    # Only relationship types are known here; node labels are validated by
    # the caller's separate label lookup, so no check on them belongs here.
    if not relationship_types:
        raise ValueError("No relationship types found in the database.")

    # Backtick-quote each key (doubling embedded backticks) so types that are
    # not plain identifiers still form a valid map key.
    projections = (
        "`" + rel.replace("`", "``") + "`: {orientation: 'UNDIRECTED'}"
        for rel in relationship_types
    )
    return "{" + ", ".join(projections) + "}"
|
||||
|
||||
async def project_entire_graph(self, graph_name="myGraph"):
|
||||
"""
|
||||
Projects all node labels and all relationship types into an undirected in-memory GDS graph.
|
||||
"""
|
||||
if await self.graph_exists(graph_name):
|
||||
return
|
||||
|
||||
node_labels_str = await self.get_node_labels_string()
|
||||
relationship_types_undirected_str = await self.get_relationship_labels_string()
|
||||
|
||||
query = f"""
|
||||
CALL gds.graph.project(
|
||||
'{graph_name}',
|
||||
{node_labels_str},
|
||||
{relationship_types_str}
|
||||
{relationship_types_undirected_str}
|
||||
) YIELD graphName;
|
||||
"""
|
||||
|
||||
|
|
@ -582,74 +605,14 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
await self.query(drop_query)
|
||||
|
||||
async def get_graph_metrics(self, include_optional=False):
|
||||
"""For the definition of these metrics, please refer to
|
||||
https://docs.cognee.ai/core_concepts/graph_generation/descriptive_metrics"""
|
||||
|
||||
nodes, edges = await self.get_model_independent_graph_data()
|
||||
graph_name = "myGraph"
|
||||
await self.drop_graph(graph_name)
|
||||
await self.project_entire_graph(graph_name)
|
||||
|
||||
async def _get_edge_density():
|
||||
query = """
|
||||
MATCH (n)
|
||||
WITH count(n) AS num_nodes
|
||||
MATCH ()-[r]->()
|
||||
WITH num_nodes, count(r) AS num_edges
|
||||
RETURN CASE
|
||||
WHEN num_nodes < 2 THEN 0
|
||||
ELSE num_edges * 1.0 / (num_nodes * (num_nodes - 1))
|
||||
END AS edge_density;
|
||||
"""
|
||||
result = await self.query(query)
|
||||
return result[0]["edge_density"] if result else 0
|
||||
|
||||
async def _get_num_connected_components():
|
||||
await self.drop_graph(graph_name)
|
||||
await self.project_entire_graph(graph_name)
|
||||
|
||||
query = f"""
|
||||
CALL gds.wcc.stream('{graph_name}')
|
||||
YIELD componentId
|
||||
RETURN count(DISTINCT componentId) AS num_connected_components;
|
||||
"""
|
||||
|
||||
result = await self.query(query)
|
||||
return result[0]["num_connected_components"] if result else 0
|
||||
|
||||
async def _get_size_of_connected_components():
|
||||
await self.drop_graph(graph_name)
|
||||
await self.project_entire_graph(graph_name)
|
||||
|
||||
query = f"""
|
||||
CALL gds.wcc.stream('{graph_name}')
|
||||
YIELD componentId
|
||||
RETURN componentId, count(*) AS size
|
||||
ORDER BY size DESC;
|
||||
"""
|
||||
|
||||
result = await self.query(query)
|
||||
return [record["size"] for record in result] if result else []
|
||||
|
||||
async def _count_self_loops():
|
||||
query = """
|
||||
MATCH (n)-[r]->(n)
|
||||
RETURN count(r) AS self_loop_count;
|
||||
"""
|
||||
result = await self.query(query)
|
||||
return result[0]["self_loop_count"] if result else 0
|
||||
|
||||
async def _get_diameter():
|
||||
logging.warning("Diameter calculation is not implemented for neo4j.")
|
||||
return -1
|
||||
|
||||
async def _get_avg_shortest_path_length():
|
||||
logging.warning(
|
||||
"Average shortest path length calculation is not implemented for neo4j."
|
||||
)
|
||||
return -1
|
||||
|
||||
async def _get_avg_clustering():
|
||||
logging.warning("Average clustering calculation is not implemented for neo4j.")
|
||||
return -1
|
||||
|
||||
num_nodes = len(nodes[0]["nodes"])
|
||||
num_edges = len(edges[0]["elements"])
|
||||
|
||||
|
|
@ -657,17 +620,22 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
"num_nodes": num_nodes,
|
||||
"num_edges": num_edges,
|
||||
"mean_degree": (2 * num_edges) / num_nodes if num_nodes != 0 else None,
|
||||
"edge_density": await _get_edge_density(),
|
||||
"num_connected_components": await _get_num_connected_components(),
|
||||
"sizes_of_connected_components": await _get_size_of_connected_components(),
|
||||
"edge_density": await get_edge_density(self),
|
||||
"num_connected_components": await get_num_connected_components(self, graph_name),
|
||||
"sizes_of_connected_components": await get_size_of_connected_components(
|
||||
self, graph_name
|
||||
),
|
||||
}
|
||||
|
||||
if include_optional:
|
||||
shortest_path_lengths = await get_shortest_path_lengths(self, graph_name)
|
||||
optional_metrics = {
|
||||
"num_selfloops": await _count_self_loops(),
|
||||
"diameter": await _get_diameter(),
|
||||
"avg_shortest_path_length": await _get_avg_shortest_path_length(),
|
||||
"avg_clustering": await _get_avg_clustering(),
|
||||
"num_selfloops": await count_self_loops(self),
|
||||
"diameter": max(shortest_path_lengths) if shortest_path_lengths else -1,
|
||||
"avg_shortest_path_length": sum(shortest_path_lengths) / len(shortest_path_lengths)
|
||||
if shortest_path_lengths
|
||||
else -1,
|
||||
"avg_clustering": await get_avg_clustering(self, graph_name),
|
||||
}
|
||||
else:
|
||||
optional_metrics = {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,74 @@
|
|||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cognee.infrastructure.databases.graph.neo4j_driver.adapter import Neo4jAdapter
|
||||
|
||||
|
||||
async def get_edge_density(adapter: Neo4jAdapter):
    """Compute the directed edge density of the whole database graph.

    The query divides the relationship count by ``n * (n - 1)`` and yields 0
    when there are fewer than two nodes.

    Args:
        adapter: Object whose awaitable ``query`` method runs the Cypher.

    Returns:
        The ``edge_density`` value of the first row, or 0 if no row came back.
    """
    density_query = """
    MATCH (n)
    WITH count(n) AS num_nodes
    MATCH ()-[r]->()
    WITH num_nodes, count(r) AS num_edges
    RETURN CASE
        WHEN num_nodes < 2 THEN 0
        ELSE num_edges * 1.0 / (num_nodes * (num_nodes - 1))
    END AS edge_density;
    """
    rows = await adapter.query(density_query)
    if not rows:
        return 0
    return rows[0]["edge_density"]
|
||||
|
||||
|
||||
async def get_num_connected_components(adapter: Neo4jAdapter, graph_name: str):
    """Count the distinct WCC component ids of a projected GDS graph.

    Args:
        adapter: Object whose awaitable ``query`` method runs the Cypher.
        graph_name: Name of the in-memory GDS graph passed to ``gds.wcc.stream``.

    Returns:
        The ``num_connected_components`` value of the first row, or 0 if the
        query returned nothing.
    """
    wcc_count_query = f"""
    CALL gds.wcc.stream('{graph_name}')
    YIELD componentId
    RETURN count(DISTINCT componentId) AS num_connected_components;
    """

    rows = await adapter.query(wcc_count_query)
    if not rows:
        return 0
    return rows[0]["num_connected_components"]
|
||||
|
||||
|
||||
async def get_size_of_connected_components(adapter: Neo4jAdapter, graph_name: str):
    """List the node count of each WCC component, largest first.

    Args:
        adapter: Object whose awaitable ``query`` method runs the Cypher.
        graph_name: Name of the in-memory GDS graph passed to ``gds.wcc.stream``.

    Returns:
        A list with the ``size`` of every returned row, in the order the query
        produced them (``ORDER BY size DESC``); empty if nothing was returned.
    """
    wcc_sizes_query = f"""
    CALL gds.wcc.stream('{graph_name}')
    YIELD componentId
    RETURN componentId, count(*) AS size
    ORDER BY size DESC;
    """

    rows = await adapter.query(wcc_sizes_query)
    sizes = []
    # A falsy result (None or an empty list) is treated as "no components".
    for row in rows or ():
        sizes.append(row["size"])
    return sizes
|
||||
|
||||
|
||||
async def count_self_loops(adapter: Neo4jAdapter):
    """Count relationships whose start and end node are the same node.

    Args:
        adapter: Object whose awaitable ``query`` method runs the Cypher.

    Returns:
        The ``self_loop_count`` value of the first row, or 0 if the query
        returned nothing.
    """
    # The alias names what is being counted (self-loops), not the adapter.
    query = """
    MATCH (n)-[r]->(n)
    RETURN count(r) AS self_loop_count;
    """
    result = await adapter.query(query)
    return result[0]["self_loop_count"] if result else 0
|
||||
|
||||
|
||||
async def get_shortest_path_lengths(adapter: Neo4jAdapter, graph_name: str):
    """Return the shortest-path length of every connected, distinct node pair.

    Args:
        adapter: Object whose awaitable ``query`` method runs the Cypher.
        graph_name: Name of the in-memory GDS graph passed to
            ``gds.allShortestPaths.stream``.

    Returns:
        A list with one finite ``distance`` per unordered pair of distinct
        nodes that can reach each other; empty if the query returned nothing.
    """
    # GDS documents that pairs with no path are reported with an infinite
    # distance; taking max()/mean over those would poison diameter and the
    # average path length. `sourceNodeId < targetNodeId` keeps one row per
    # unordered pair and drops any (node, node) rows.
    query = f"""
    CALL gds.allShortestPaths.stream('{graph_name}')
    YIELD sourceNodeId, targetNodeId, distance
    WHERE sourceNodeId < targetNodeId AND gds.util.isFinite(distance)
    RETURN distance;
    """

    result = await adapter.query(query)
    return [res["distance"] for res in result] if result else []
|
||||
|
||||
|
||||
async def get_avg_clustering(adapter: Neo4jAdapter, graph_name: str):
    """Average the local clustering coefficient over a projected GDS graph.

    Args:
        adapter: Object whose awaitable ``query`` method runs the Cypher.
        graph_name: Name of the in-memory GDS graph passed to
            ``gds.localClusteringCoefficient.stream``.

    Returns:
        The ``avg_clustering`` value of the first row, or 0 if the query
        returned nothing.
    """
    clustering_query = f"""
    CALL gds.localClusteringCoefficient.stream('{graph_name}')
    YIELD localClusteringCoefficient
    RETURN avg(localClusteringCoefficient) AS avg_clustering;
    """

    rows = await adapter.query(clustering_query)
    if not rows:
        return 0
    return rows[0]["avg_clustering"]
|
||||
|
|
@ -423,7 +423,7 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
|
||||
def _get_avg_clustering(graph):
    """Return the average clustering coefficient of ``graph`` without edge direction.

    Returns:
        The value computed by ``nx.average_clustering``, or None (after
        logging a warning) if networkx raises.
    """
    try:
        # Only the undirected form is computed; a leftover directed `return`
        # ahead of this line would make it unreachable.
        return nx.average_clustering(nx.DiGraph(graph.to_undirected()))
    except Exception as e:
        logger.warning("Failed to calculate clustering coefficient: %s", e)
        return None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
{
|
||||
"connected": {
|
||||
"num_nodes": 5,
|
||||
"num_edges": 6,
|
||||
"mean_degree": 2.4,
|
||||
"edge_density": 0.3,
|
||||
"num_connected_components": 1,
|
||||
"sizes_of_connected_components": [5],
|
||||
"num_selfloops": 1,
|
||||
"diameter": 3,
|
||||
"avg_shortest_path_length": 1.6,
|
||||
"avg_clustering": 0
|
||||
},
|
||||
"disconnected": {
|
||||
"num_nodes": 9,
|
||||
"num_edges": 8,
|
||||
"mean_degree": 1.7777777777777777,
|
||||
"edge_density": 0.1111111111111111,
|
||||
"num_connected_components": 2,
|
||||
"sizes_of_connected_components": [5, 4],
|
||||
"num_selfloops": -1,
|
||||
"diameter": -1,
|
||||
"avg_shortest_path_length": -1,
|
||||
"avg_clustering": -1
|
||||
}
|
||||
}
|
||||
|
|
@ -1,13 +1,22 @@
|
|||
from cognee.tests.tasks.descriptive_metrics.networkx_metrics_test import get_networkx_metrics
|
||||
from cognee.tests.tasks.descriptive_metrics.neo4j_metrics_test import get_neo4j_metrics
|
||||
from cognee.tests.tasks.descriptive_metrics.metrics_test_utils import get_metrics
|
||||
|
||||
import asyncio
|
||||
|
||||
|
||||
async def check_graph_metrics_consistency_across_adapters(include_optional=False):
    """Assert that the neo4j and networkx adapters report identical metrics.

    Both backends are driven through ``get_metrics`` with the same
    ``include_optional`` flag, so they are seeded with the same test graph.

    Args:
        include_optional: Forwarded to ``get_metrics`` for both providers.

    Raises:
        AssertionError: If the two metric dictionaries have different keys or
            any shared key maps to different values.
    """
    neo4j_metrics = await get_metrics(provider="neo4j", include_optional=include_optional)
    networkx_metrics = await get_metrics(provider="networkx", include_optional=include_optional)

    # Report key mismatches up front; otherwise the per-key loop below would
    # fail with a bare KeyError instead of a readable message.
    diff_keys = set(neo4j_metrics.keys()).symmetric_difference(set(networkx_metrics.keys()))
    if diff_keys:
        raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}")

    for key, neo4j_value in neo4j_metrics.items():
        assert networkx_metrics[key] == neo4j_value, (
            f"Difference in '{key}': got {neo4j_value} with neo4j and {networkx_metrics[key]} with networkx"
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(check_graph_metrics_consistency_across_adapters())
|
||||
asyncio.run(check_graph_metrics_consistency_across_adapters(include_optional=True))
|
||||
asyncio.run(check_graph_metrics_consistency_across_adapters(include_optional=False))
|
||||
|
|
|
|||
|
|
@ -5,6 +5,11 @@ from cognee.tests.unit.interfaces.graph.get_graph_from_model_test import (
|
|||
EntityType,
|
||||
)
|
||||
from cognee.tasks.storage.add_data_points import add_data_points
|
||||
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
|
||||
import cognee
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
async def create_disconnected_test_graph():
|
||||
|
|
@ -13,8 +18,7 @@ async def create_disconnected_test_graph():
|
|||
entity_type = EntityType(name="Person")
|
||||
entity = Entity(name="Alice", is_type=entity_type)
|
||||
entity2 = Entity(name="Alice2", is_type=entity_type)
|
||||
# the following self-loop is intentional and serves the purpose of testing the self-loop counting functionality
|
||||
doc_chunk.contains.extend([entity, entity2, doc_chunk])
|
||||
doc_chunk.contains.extend([entity, entity2])
|
||||
|
||||
doc2 = Document(path="test/path2")
|
||||
doc_chunk2 = DocumentChunk(part_of=doc2, text="This is a chunk of text", contains=[])
|
||||
|
|
@ -23,3 +27,50 @@ async def create_disconnected_test_graph():
|
|||
doc_chunk2.contains.extend([entity3])
|
||||
|
||||
await add_data_points([doc_chunk, doc_chunk2])
|
||||
|
||||
|
||||
async def create_connected_test_graph():
    """Store a small test graph built around a single document chunk.

    The chunk belongs to one document and contains two entities that share
    one entity type, plus a reference to the chunk itself.
    """
    document = Document(path="test/path")
    chunk = DocumentChunk(part_of=document, text="This is a chunk of text", contains=[])
    person_type = EntityType(name="Person")
    alice = Entity(name="Alice", is_type=person_type)
    alice2 = Entity(name="Alice2", is_type=person_type)

    # The chunk lists itself on purpose: the resulting self-loop exercises
    # the self-loop counting metric.
    contained = [alice, alice2, chunk]
    chunk.contains.extend(contained)

    await add_data_points([chunk])
|
||||
|
||||
|
||||
async def get_metrics(provider: str, include_optional=True):
    """Seed a fresh test graph on ``provider`` and return its graph metrics.

    Args:
        provider: Graph database provider name passed to
            ``cognee.config.set_graph_database_provider``.
        include_optional: When true the connected test graph is created,
            otherwise the disconnected one; the flag is also forwarded to
            ``get_graph_metrics``.

    Returns:
        Whatever the engine's ``get_graph_metrics`` returns.
    """
    # Drop the cached engine so the provider switch below takes effect.
    create_graph_engine.cache_clear()
    cognee.config.set_graph_database_provider(provider)

    engine = await get_graph_engine()
    await engine.delete_graph()

    build_graph = (
        create_connected_test_graph if include_optional else create_disconnected_test_graph
    )
    await build_graph()

    return await engine.get_graph_metrics(include_optional=include_optional)
|
||||
|
||||
|
||||
async def assert_metrics(provider, include_optional=True):
    """Assert that ``provider``'s graph metrics match the stored ground truth.

    Expected values are read from ``ground_truth_metrics.json`` next to this
    file: the ``"connected"`` entry when ``include_optional`` is true, the
    ``"disconnected"`` entry otherwise.

    Args:
        provider: Graph database provider name forwarded to ``get_metrics``.
        include_optional: Forwarded to ``get_metrics``; also selects which
            ground-truth entry is compared.

    Raises:
        AssertionError: If the key sets differ or any metric does not match.
    """
    # Function-scope import keeps this edit independent of the module's import block.
    import math

    def _matches(actual, expected):
        """Compare real numbers within a tight tolerance, everything else exactly."""
        real = (int, float)
        both_real = (
            isinstance(actual, real)
            and isinstance(expected, real)
            and not isinstance(actual, bool)
            and not isinstance(expected, bool)
        )
        if both_real:
            # Ratios such as mean degree or edge density are computed by
            # different engines; the last float bit is not guaranteed to agree
            # with the JSON literal.
            return math.isclose(actual, expected, rel_tol=1e-9, abs_tol=1e-12)
        return actual == expected

    metrics = await get_metrics(provider=provider, include_optional=include_optional)

    gt_path = Path(__file__).parent / "ground_truth_metrics.json"
    with open(gt_path, "r", encoding="utf-8") as file:
        ground_truth_metrics = json.load(file)

    if include_optional:
        ground_truth_metrics = ground_truth_metrics["connected"]
    else:
        ground_truth_metrics = ground_truth_metrics["disconnected"]

    diff_keys = set(metrics.keys()).symmetric_difference(set(ground_truth_metrics.keys()))
    if diff_keys:
        raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}")

    for key, ground_truth_value in ground_truth_metrics.items():
        assert _matches(metrics[key], ground_truth_value), (
            f"Expected {ground_truth_value} for '{key}' with {provider}, got {metrics[key]}"
        )
|
||||
|
|
|
|||
|
|
@ -1,42 +1,7 @@
|
|||
from cognee.tests.tasks.descriptive_metrics.metrics_test_utils import create_disconnected_test_graph
|
||||
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
import cognee
|
||||
from cognee.tests.tasks.descriptive_metrics.metrics_test_utils import assert_metrics
|
||||
import asyncio
|
||||
import pytest
|
||||
|
||||
|
||||
async def get_neo4j_metrics(include_optional=True):
|
||||
create_graph_engine.cache_clear()
|
||||
cognee.config.set_graph_database_provider("neo4j")
|
||||
graph_engine = await get_graph_engine()
|
||||
await graph_engine.delete_graph()
|
||||
await create_disconnected_test_graph()
|
||||
neo4j_graph_metrics = await graph_engine.get_graph_metrics(include_optional=include_optional)
|
||||
return neo4j_graph_metrics
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_neo4j_metrics():
|
||||
neo4j_metrics = await get_neo4j_metrics(include_optional=True)
|
||||
assert neo4j_metrics["num_nodes"] == 9, f"Expected 9 nodes, got {neo4j_metrics['num_nodes']}"
|
||||
assert neo4j_metrics["num_edges"] == 9, f"Expected 9 edges, got {neo4j_metrics['num_edges']}"
|
||||
assert neo4j_metrics["mean_degree"] == 2, (
|
||||
f"Expected mean degree is 2, got {neo4j_metrics['mean_degree']}"
|
||||
)
|
||||
assert neo4j_metrics["edge_density"] == 0.125, (
|
||||
f"Expected edge density is 0.125, got {neo4j_metrics['edge_density']}"
|
||||
)
|
||||
assert neo4j_metrics["num_connected_components"] == 2, (
|
||||
f"Expected 2 connected components, got {neo4j_metrics['num_connected_components']}"
|
||||
)
|
||||
assert neo4j_metrics["sizes_of_connected_components"] == [5, 4], (
|
||||
f"Expected connected components of size [5, 4], got {neo4j_metrics['sizes_of_connected_components']}"
|
||||
)
|
||||
assert neo4j_metrics["num_selfloops"] == 1, (
|
||||
f"Expected 1 self-loop, got {neo4j_metrics['num_selfloops']}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_neo4j_metrics())
|
||||
asyncio.run(assert_metrics(provider="neo4j", include_optional=False))
|
||||
asyncio.run(assert_metrics(provider="neo4j", include_optional=True))
|
||||
|
|
|
|||
|
|
@ -1,53 +1,7 @@
|
|||
from cognee.tests.tasks.descriptive_metrics.metrics_test_utils import create_disconnected_test_graph
|
||||
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
import cognee
|
||||
from cognee.tests.tasks.descriptive_metrics.metrics_test_utils import assert_metrics
|
||||
import asyncio
|
||||
|
||||
|
||||
async def get_networkx_metrics(include_optional=True):
|
||||
create_graph_engine.cache_clear()
|
||||
cognee.config.set_graph_database_provider("networkx")
|
||||
graph_engine = await get_graph_engine()
|
||||
await graph_engine.delete_graph()
|
||||
await create_disconnected_test_graph()
|
||||
networkx_graph_metrics = await graph_engine.get_graph_metrics(include_optional=include_optional)
|
||||
return networkx_graph_metrics
|
||||
|
||||
|
||||
async def assert_networkx_metrics():
|
||||
networkx_metrics = await get_networkx_metrics(include_optional=True)
|
||||
assert networkx_metrics["num_nodes"] == 9, (
|
||||
f"Expected 9 nodes, got {networkx_metrics['num_nodes']}"
|
||||
)
|
||||
assert networkx_metrics["num_edges"] == 9, (
|
||||
f"Expected 9 edges, got {networkx_metrics['num_edges']}"
|
||||
)
|
||||
assert networkx_metrics["mean_degree"] == 2, (
|
||||
f"Expected mean degree is 2, got {networkx_metrics['mean_degree']}"
|
||||
)
|
||||
assert networkx_metrics["edge_density"] == 0.125, (
|
||||
f"Expected edge density is 0.125, got {networkx_metrics['edge_density']}"
|
||||
)
|
||||
assert networkx_metrics["num_connected_components"] == 2, (
|
||||
f"Expected 2 connected components, got {networkx_metrics['num_connected_components']}"
|
||||
)
|
||||
assert networkx_metrics["sizes_of_connected_components"] == [5, 4], (
|
||||
f"Expected connected components of size [5, 4], got {networkx_metrics['sizes_of_connected_components']}"
|
||||
)
|
||||
assert networkx_metrics["num_selfloops"] == 1, (
|
||||
f"Expected 1 self-loop, got {networkx_metrics['num_selfloops']}"
|
||||
)
|
||||
assert networkx_metrics["diameter"] is None, (
|
||||
f"Diameter should be None for disconnected graphs, got {networkx_metrics['diameter']}"
|
||||
)
|
||||
assert networkx_metrics["avg_shortest_path_length"] is None, (
|
||||
f"Average shortest path should be None for disconnected graphs, got {networkx_metrics['avg_shortest_path_length']}"
|
||||
)
|
||||
assert networkx_metrics["avg_clustering"] == 0, (
|
||||
f"Expected 0 average clustering, got {networkx_metrics['avg_clustering']}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(assert_networkx_metrics())
|
||||
asyncio.run(assert_metrics(provider="networkx", include_optional=False))
|
||||
asyncio.run(assert_metrics(provider="networkx", include_optional=True))
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue