cognee/cognee/tests/tasks/descriptive_metrics/metrics_test_utils.py
alekszievr e56d86b410
feat: Implement optional neo4j metrics and improve tests [cog-1262] (#556)
<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
  - Enhanced graph analytics now offer detailed metrics—including shortest
    path lengths, diameter, and clustering coefficients—to provide deeper
    insights.
  - Added new functions for creating connected test graphs and validating
    metrics against predefined ground truth values.
  - Introduced a new JSON file containing metrics for connected and
    disconnected graph structures.

- **Improvements**
  - Updated how graphs are projected to consistently use undirected
    representations, ensuring more accurate and reliable metric
    calculations.
  - Streamlined metric consistency checks across different graph
    processing methods for robust, reliable results.
  - Simplified testing logic by consolidating metric assertions into a
    single function call.

- **Chores**
  - Removed unnecessary secret variables from the workflow configuration,
    potentially affecting access to certain resources.
  - Updated secret management to include the new `OPENAI_API_KEY`.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-02-19 16:24:59 +01:00

76 lines
2.9 KiB
Python

import json
import math
from pathlib import Path

import cognee
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
from cognee.tasks.storage.add_data_points import add_data_points
from cognee.tests.unit.interfaces.graph.get_graph_from_model_test import (
    Document,
    DocumentChunk,
    Entity,
    EntityType,
)


async def create_disconnected_test_graph():
    """Store two document chunks that are built from entirely separate objects.

    The first chunk belongs to the document at ``test/path`` and contains the
    entities "Alice" and "Alice2"; the second belongs to ``test/path2`` and
    contains "Bob". Each chunk's entities point at their own ``EntityType``
    instance named "Person". Both chunks are passed to ``add_data_points`` in
    one call.

    NOTE(review): judging by the function name, the stored graph is expected to
    end up with more than one connected component -- confirm against how
    ``add_data_points`` identifies nodes.
    """
    # First group: one document, one chunk, two entities of one type.
    first_document = Document(path="test/path")
    first_chunk = DocumentChunk(
        part_of=first_document,
        text="This is a chunk of text",
        contains=[],
    )
    first_person_type = EntityType(name="Person")
    alice = Entity(name="Alice", is_type=first_person_type)
    alice2 = Entity(name="Alice2", is_type=first_person_type)
    for member in (alice, alice2):
        first_chunk.contains.append(member)

    # Second group: a different document, chunk, type instance and entity.
    second_document = Document(path="test/path2")
    second_chunk = DocumentChunk(
        part_of=second_document,
        text="This is a chunk of text",
        contains=[],
    )
    second_person_type = EntityType(name="Person")
    bob = Entity(name="Bob", is_type=second_person_type)
    second_chunk.contains.append(bob)

    chunks = [first_chunk, second_chunk]
    await add_data_points(chunks)


async def create_connected_test_graph():
    """Store a single document chunk whose ``contains`` list includes itself.

    The chunk belongs to the document at ``test/path`` and contains the
    entities "Alice" and "Alice2" (both of one ``EntityType`` named "Person")
    plus a reference to the chunk itself. The chunk is then passed to
    ``add_data_points``.
    """
    document = Document(path="test/path")
    chunk = DocumentChunk(
        part_of=document,
        text="This is a chunk of text",
        contains=[],
    )
    person_type = EntityType(name="Person")
    alice = Entity(name="Alice", is_type=person_type)
    alice2 = Entity(name="Alice2", is_type=person_type)

    # Appending the chunk to its own `contains` list is deliberate: the
    # resulting self-loop exists to exercise the self-loop counting metric.
    for member in (alice, alice2, chunk):
        chunk.contains.append(member)

    await add_data_points([chunk])


async def get_metrics(provider: str, include_optional=True):
    """Rebuild a test graph on the given provider and return its metrics.

    Args:
        provider: Graph database provider name, passed to
            ``cognee.config.set_graph_database_provider``.
        include_optional: When true, the connected test graph is created;
            otherwise the disconnected one is. The flag is also forwarded to
            ``get_graph_metrics``.

    Returns:
        Whatever ``graph_engine.get_graph_metrics(include_optional=...)``
        returns for the freshly created graph.
    """
    # The engine factory cache is cleared before switching provider
    # (presumably so a previously cached engine is not reused -- confirm).
    create_graph_engine.cache_clear()
    cognee.config.set_graph_database_provider(provider)

    engine = await get_graph_engine()
    await engine.delete_graph()

    build_test_graph = (
        create_connected_test_graph if include_optional else create_disconnected_test_graph
    )
    await build_test_graph()

    return await engine.get_graph_metrics(include_optional=include_optional)


def _metric_matches(actual, expected):
    """Return True when a computed metric agrees with its ground-truth value.

    Numeric pairs in which at least one side is a float are compared with
    ``math.isclose`` so that round-off differences between graph backends do
    not fail the check; everything else (ints, bools, strings, ``None``, ...)
    still requires exact equality.
    """
    numeric = (int, float)
    both_numeric = (
        isinstance(actual, numeric)
        and isinstance(expected, numeric)
        # bool is a subclass of int; flags must match exactly.
        and not isinstance(actual, bool)
        and not isinstance(expected, bool)
    )
    if both_numeric and (isinstance(actual, float) or isinstance(expected, float)):
        return math.isclose(actual, expected, rel_tol=1e-9, abs_tol=1e-12)
    return actual == expected


async def assert_metrics(provider, include_optional=True):
    """Check a provider's graph metrics against the stored ground truth.

    Metrics are obtained from ``get_metrics`` and compared with the entry of
    ``ground_truth_metrics.json`` (located next to this file) selected by
    ``include_optional``: ``"connected"`` when true, ``"disconnected"``
    otherwise.

    Args:
        provider: Graph database provider name forwarded to ``get_metrics``.
        include_optional: Forwarded to ``get_metrics`` and used to select the
            ground-truth section.

    Raises:
        AssertionError: If the two metric dictionaries do not have the same
            keys, or if any metric value differs from its ground truth
            (floats are compared with a small tolerance).
    """
    metrics = await get_metrics(provider=provider, include_optional=include_optional)

    gt_path = Path(__file__).parent / "ground_truth_metrics.json"
    # Explicit encoding so the JSON is decoded the same way on every platform.
    with gt_path.open("r", encoding="utf-8") as file:
        ground_truth_metrics = json.load(file)

    section = "connected" if include_optional else "disconnected"
    ground_truth_metrics = ground_truth_metrics[section]

    diff_keys = set(metrics.keys()).symmetric_difference(set(ground_truth_metrics.keys()))
    if diff_keys:
        raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}")

    for key, ground_truth_value in ground_truth_metrics.items():
        assert _metric_matches(metrics[key], ground_truth_value), (
            f"Expected {ground_truth_value} for '{key}' with {provider}, got {metrics[key]}"
        )