<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **New Features**
  - Enhanced graph analytics now offer detailed metrics—including shortest path lengths, diameter, and clustering coefficients—to provide deeper insights.
  - Added new functions for creating connected test graphs and validating metrics against predefined ground truth values.
  - Introduced a new JSON file containing metrics for connected and disconnected graph structures.

- **Improvements**
  - Updated how graphs are projected to consistently use undirected representations, ensuring more accurate and reliable metric calculations.
  - Streamlined metric consistency checks across different graph processing methods for robust, reliable results.
  - Simplified testing logic by consolidating metric assertions into a single function call.

- **Chores**
  - Removed unnecessary secret variables from the workflow configuration, potentially affecting access to certain resources.
  - Updated secret management to include the new `OPENAI_API_KEY`.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
76 lines
2.9 KiB
Python
76 lines
2.9 KiB
Python
from cognee.tests.unit.interfaces.graph.get_graph_from_model_test import (
|
|
Document,
|
|
DocumentChunk,
|
|
Entity,
|
|
EntityType,
|
|
)
|
|
from cognee.tasks.storage.add_data_points import add_data_points
|
|
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
|
|
import cognee
|
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
|
import json
|
|
from pathlib import Path
|
|
|
|
|
|
async def create_disconnected_test_graph():
    """Store two document chunks that reference no common objects.

    The first chunk belongs to the document at "test/path" and contains the
    entities "Alice" and "Alice2", which share one "Person" entity type.
    The second chunk belongs to the document at "test/path2" and contains
    the single entity "Bob" with its own, separately created "Person"
    entity type. Both chunks are handed to ``add_data_points`` in one call.
    """
    first_chunk = DocumentChunk(
        part_of=Document(path="test/path"),
        text="This is a chunk of text",
        contains=[],
    )
    person_type = EntityType(name="Person")
    alice = Entity(name="Alice", is_type=person_type)
    alice2 = Entity(name="Alice2", is_type=person_type)
    first_chunk.contains.extend([alice, alice2])

    second_chunk = DocumentChunk(
        part_of=Document(path="test/path2"),
        text="This is a chunk of text",
        contains=[],
    )
    # A second EntityType instance is created on purpose: reusing the first
    # one would make both chunks reference the same object.
    other_person_type = EntityType(name="Person")
    bob = Entity(name="Bob", is_type=other_person_type)
    second_chunk.contains.extend([bob])

    await add_data_points([first_chunk, second_chunk])
|
|
|
|
|
|
async def create_connected_test_graph():
    """Store a single document chunk whose contents include the chunk itself.

    The chunk belongs to the document at "test/path" and contains the
    entities "Alice" and "Alice2" (sharing one "Person" entity type) plus a
    reference to itself. The chunk is then passed to ``add_data_points``.
    """
    chunk = DocumentChunk(
        part_of=Document(path="test/path"),
        text="This is a chunk of text",
        contains=[],
    )
    person_type = EntityType(name="Person")
    alice = Entity(name="Alice", is_type=person_type)
    alice2 = Entity(name="Alice2", is_type=person_type)

    # the following self-loop is intentional and serves the purpose of testing the self-loop counting functionality
    members = [alice, alice2, chunk]
    chunk.contains.extend(members)

    await add_data_points([chunk])
|
|
|
|
|
|
async def get_metrics(provider: str, include_optional=True):
    """Rebuild a test graph on the given provider and return its metrics.

    Args:
        provider: Name passed to ``cognee.config.set_graph_database_provider``.
        include_optional: When truthy, the connected test graph is created;
            otherwise the disconnected one is. The same flag is forwarded to
            ``graph_engine.get_graph_metrics``.

    Returns:
        Whatever ``get_graph_metrics`` returns for the freshly built graph.
    """
    # Clear the cached engine before switching providers
    # (presumably so the next get_graph_engine() call honours the new setting).
    create_graph_engine.cache_clear()
    cognee.config.set_graph_database_provider(provider)

    engine = await get_graph_engine()
    # Start from an empty graph so earlier runs cannot affect the metrics.
    await engine.delete_graph()

    build_graph = (
        create_connected_test_graph if include_optional else create_disconnected_test_graph
    )
    await build_graph()

    return await engine.get_graph_metrics(include_optional=include_optional)
|
|
|
|
|
|
async def assert_metrics(provider, include_optional=True):
    """Check a provider's graph metrics against the stored ground truth.

    Metrics are obtained from ``get_metrics`` and compared with the matching
    section of ``ground_truth_metrics.json`` located next to this file.

    Args:
        provider: Graph database provider name, forwarded to ``get_metrics``
            and included in mismatch messages.
        include_optional: Forwarded to ``get_metrics``. When truthy the
            "connected" ground-truth section is used, otherwise the
            "disconnected" one.

    Raises:
        AssertionError: If the two dictionaries do not have the same keys, or
            if any metric value differs from its ground-truth value.
    """
    metrics = await get_metrics(provider=provider, include_optional=include_optional)

    gt_path = Path(__file__).parent / "ground_truth_metrics.json"
    # JSON is UTF-8; state the encoding explicitly so decoding does not depend
    # on the locale of the machine running the tests.
    with open(gt_path, "r", encoding="utf-8") as file:
        all_ground_truth = json.load(file)

    section = "connected" if include_optional else "disconnected"
    ground_truth_metrics = all_ground_truth[section]

    diff_keys = set(metrics.keys()).symmetric_difference(set(ground_truth_metrics.keys()))
    if diff_keys:
        raise AssertionError(f"Metrics dictionaries have different keys: {diff_keys}")

    for key, ground_truth_value in ground_truth_metrics.items():
        # Raise explicitly instead of using `assert`, so the comparison still
        # runs under `python -O` and matches the key check above.
        if metrics[key] != ground_truth_value:
            raise AssertionError(
                f"Expected {ground_truth_value} for '{key}' with {provider}, got {metrics[key]}"
            )
|