cognee/cognee/tests/test_edge_ingestion.py
lxobr 6223ecf05b
feat: optimize repeated entity extraction (#1682)
<!-- .github/pull_request_template.md -->

## Description
<!--
Please provide a clear, human-generated description of the changes in
this PR.
DO NOT use AI-generated descriptions. We want to understand your thought
process and reasoning.
-->

- Added an `edge_text` field to edges that auto-fills from
`relationship_type` if not provided.
- Contains edges now store descriptions for better embedding
- Updated and refactored indexing so that edge_text gets embedded and
exposed
- Updated retrieval to use the new embeddings 
- Added a test to verify edge_text exists in the graph with the correct
format.

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [x] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
2025-10-30 13:56:06 +01:00

115 lines
4.3 KiB
Python
Executable file

import os
import asyncio
import cognee
import pathlib
from cognee.infrastructure.databases.graph import get_graph_engine
from collections import Counter
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
logger = get_logger()
async def test_edge_ingestion():
"""
Tests whether we ingest additional entity to entity edges
"""
data_directory_path = str(
pathlib.Path(
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_edge_ingestion")
).resolve()
)
cognee_directory_path = str(
pathlib.Path(
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_edge_ingestion")
).resolve()
)
cognee.config.data_root_directory(data_directory_path)
cognee.config.system_root_directory(cognee_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
basic_nested_edges = ["is_a", "is_part_of", "contains", "made_from"]
entity_to_entity_edges = ["likes", "prefers", "watches"]
text1 = "Dave watches Dexter Resurrection"
text2 = "Ana likes apples"
text3 = "Bob prefers Cognee over other solutions"
await cognee.add([text1, text2, text3], dataset_name="edge_ingestion_test")
user = await get_default_user()
await cognee.cognify(["edge_ingestion_test"], user=user)
graph_engine = await get_graph_engine()
graph = await graph_engine.get_graph_data()
edge_type_counts = Counter(edge_type[2] for edge_type in graph[1])
"Tests edge_text presence and format"
contains_edges = [edge for edge in graph[1] if edge[2] == "contains"]
assert len(contains_edges) > 0, "Expected at least one contains edge for edge_text verification"
edge_properties = contains_edges[0][3]
assert "edge_text" in edge_properties, "Expected edge_text in edge properties"
edge_text = edge_properties["edge_text"]
assert "relationship_name: contains" in edge_text, (
f"Expected 'relationship_name: contains' in edge_text, got: {edge_text}"
)
assert "entity_name:" in edge_text, f"Expected 'entity_name:' in edge_text, got: {edge_text}"
assert "entity_description:" in edge_text, (
f"Expected 'entity_description:' in edge_text, got: {edge_text}"
)
all_edge_texts = [
edge[3].get("edge_text", "") for edge in contains_edges if "edge_text" in edge[3]
]
expected_entities = ["dave", "ana", "bob", "dexter", "apples", "cognee"]
found_entity = any(
any(entity in text.lower() for entity in expected_entities) for text in all_edge_texts
)
assert found_entity, (
f"Expected to find at least one entity name in edge_text: {all_edge_texts[:3]}"
)
"Tests the presence of basic nested edges"
for basic_nested_edge in basic_nested_edges:
assert edge_type_counts.get(basic_nested_edge, 0) >= 1, (
f"Expected at least one {basic_nested_edge} edge, but found {edge_type_counts.get(basic_nested_edge, 0)}"
)
"Tests the presence of additional entity to entity edges"
assert len(edge_type_counts) > 4, (
f"Expected at least {5} edges (4 structural plus entity to entity edges), but found only {len(edge_type_counts)}"
)
"Tests the consistency of basic nested edges"
assert edge_type_counts.get("made_from", 0) == edge_type_counts.get("is_part_of", 0), (
f"Number of made_from and is_part_of edges are not matching, found {edge_type_counts.get('made_from', 0)} made from and {edge_type_counts.get('is_part_of', 0)} is_part_of."
)
"Tests whether we generate is_a for all entity that is contained by a chunk"
assert edge_type_counts.get("contains", 0) == edge_type_counts.get("is_a", 0), (
f"Number of contains and is_a edges are not matching, found {edge_type_counts.get('is_a', 0)} is_a and {edge_type_counts.get('is_part_of', 0)} contains."
)
found_edges = 0
for entity_to_entity_edge in entity_to_entity_edges:
if entity_to_entity_edge in edge_type_counts:
found_edges = found_edges + 1
"Tests the presence of extected entity to entity edges"
assert found_edges >= 2, (
f"Expected at least 2 entity to entity edges, but found only {found_edges}"
)
if __name__ == "__main__":
asyncio.run(test_edge_ingestion())