cognee/cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py
lxobr 6223ecf05b
feat: optimize repeated entity extraction (#1682)
<!-- .github/pull_request_template.md -->

## Description
<!--
Please provide a clear, human-generated description of the changes in
this PR.
DO NOT use AI-generated descriptions. We want to understand your thought
process and reasoning.
-->

- Added an `edge_text` field to edges that auto-fills from
`relationship_type` if not provided.
- Contains edges now store descriptions for better embedding
- Updated and refactored indexing so that edge_text gets embedded and
exposed
- Updated retrieval to use the new embeddings 
- Added a test to verify edge_text exists in the graph with the correct
format.

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [x] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
2025-10-30 13:56:06 +01:00

70 lines
2.3 KiB
Python

import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from cognee.tasks.storage.index_graph_edges import index_graph_edges
@pytest.mark.asyncio
async def test_index_graph_edges_success():
    """Check the happy path: edges are fetched and passed on to index_data_points.

    The mocked graph engine returns two edge entries holding three
    relationship dicts in total ("rel1" twice, "rel2" once). The test expects
    exactly one awaited call to ``index_data_points`` whose first positional
    argument contains two items, each exposing a ``relationship_name``
    attribute (presumably one item per distinct relationship name -- confirm
    against ``index_graph_edges``).
    """
    edges = [
        [{"relationship_name": "rel1"}, {"relationship_name": "rel1"}],
        [{"relationship_name": "rel2"}],
    ]
    graph_engine = AsyncMock()
    graph_engine.get_graph_data.return_value = (None, edges)
    index_data_points_mock = AsyncMock()

    # Replace the collaborators directly in the globals of the function under
    # test; patch.dict restores the original entries when the block exits.
    replacements = {
        "get_graph_engine": AsyncMock(return_value=graph_engine),
        "index_data_points": index_data_points_mock,
    }
    with patch.dict(index_graph_edges.__globals__, replacements):
        await index_graph_edges()

    graph_engine.get_graph_data.assert_awaited_once()
    index_data_points_mock.assert_awaited_once()

    # First positional argument of the single index_data_points call.
    indexed_items = index_data_points_mock.call_args[0][0]
    assert len(indexed_items) == 2
    for item in indexed_items:
        assert hasattr(item, "relationship_name")
@pytest.mark.asyncio
async def test_index_graph_edges_no_relationships():
    """Check that a graph without edges still results in one (empty) indexing call.

    The mocked graph engine reports an empty edge list. The test expects
    ``index_data_points`` to be awaited exactly once and its first positional
    argument to be empty.
    """
    graph_engine = AsyncMock()
    graph_engine.get_graph_data.return_value = (None, [])
    index_data_points_mock = AsyncMock()

    # Replace the collaborators directly in the globals of the function under
    # test; patch.dict restores the original entries when the block exits.
    replacements = {
        "get_graph_engine": AsyncMock(return_value=graph_engine),
        "index_data_points": index_data_points_mock,
    }
    with patch.dict(index_graph_edges.__globals__, replacements):
        await index_graph_edges()

    graph_engine.get_graph_data.assert_awaited_once()
    index_data_points_mock.assert_awaited_once()

    # First positional argument of the single index_data_points call.
    indexed_items = index_data_points_mock.call_args[0][0]
    assert len(indexed_items) == 0
@pytest.mark.asyncio
async def test_index_graph_edges_initialization_error():
    """Check that a failing graph-engine lookup surfaces as a RuntimeError.

    ``get_graph_engine`` is replaced with a mock that raises a plain
    ``Exception``; the test expects ``index_graph_edges`` to raise a
    ``RuntimeError`` whose message matches "Initialization error".
    """
    failing_get_graph_engine = AsyncMock(side_effect=Exception("Graph engine failed"))

    # Both names are injected into the globals of the function under test;
    # patch.dict undoes the changes on exit, including removal of keys that
    # were not present beforehand.
    replacements = {
        "get_graph_engine": failing_get_graph_engine,
        "get_vector_engine": lambda: AsyncMock(),
    }
    with patch.dict(index_graph_edges.__globals__, replacements):
        with pytest.raises(RuntimeError, match="Initialization error"):
            await index_graph_edges()