cognee/cognee/tasks/storage/index_graph_edges.py
lxobr 6223ecf05b
feat: optimize repeated entity extraction (#1682)
<!-- .github/pull_request_template.md -->

## Description
<!--
Please provide a clear, human-generated description of the changes in
this PR.
DO NOT use AI-generated descriptions. We want to understand your thought
process and reasoning.
-->

- Added an `edge_text` field to edges that auto-fills from
`relationship_type` if not provided.
- Contains edges now store descriptions for better embedding
- Updated and refactored indexing so that edge_text gets embedded and
exposed
- Updated retrieval to use the new embeddings 
- Added a test to verify edge_text exists in the graph with the correct
format.

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [x] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
2025-10-30 13:56:06 +01:00

77 lines
2.6 KiB
Python

from collections import Counter
from typing import Optional, Dict, Any, List, Tuple, Union
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.models.EdgeType import EdgeType
from cognee.infrastructure.databases.graph.graph_db_interface import EdgeData
from cognee.tasks.storage.index_data_points import index_data_points
# Module-level logger; index_graph_edges uses it for error and deprecation messages.
logger = get_logger()
def _get_edge_text(item: dict) -> str:
    """Return the text used to embed an edge.

    Prefers the ``edge_text`` field and falls back to ``relationship_name``.
    A key that is present but holds a falsy value (``None`` or ``""``) is
    treated as missing. An unset ``edge_text`` therefore falls through to the
    relationship name instead of producing a ``None`` label.

    Args:
        item: Edge properties dictionary.

    Returns:
        The edge text, or an empty string when neither field provides one.
    """
    edge_text = item.get("edge_text")
    if edge_text:
        return edge_text

    relationship_name = item.get("relationship_name")
    if relationship_name:
        return relationship_name

    return ""
def create_edge_type_datapoints(edges_data) -> list[EdgeType]:
    """Build one ``EdgeType`` datapoint per distinct edge text.

    Every dict element of every edge that carries a ``relationship_name`` key
    contributes its edge text (resolved via ``_get_edge_text``). Identical
    texts are tallied together. Each resulting datapoint gets an id derived
    from the text through ``generate_edge_id`` and records how many edges
    shared that text. Datapoints are returned in first-seen order.
    """
    text_counts: Counter = Counter()
    for edge in edges_data:
        for element in edge:
            # Only property dictionaries that name a relationship are counted;
            # plain node ids / relationship strings inside the edge are skipped.
            if not isinstance(element, dict):
                continue
            if "relationship_name" not in element:
                continue
            text_counts[_get_edge_text(element)] += 1

    datapoints = []
    for edge_text, occurrences in text_counts.items():
        datapoints.append(
            EdgeType(
                id=generate_edge_id(edge_id=edge_text),
                relationship_name=edge_text,
                number_of_edges=occurrences,
            )
        )
    return datapoints
async def index_graph_edges(
    edges_data: Optional[
        Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]]
    ] = None,
):
    """
    Index graph edges by embedding their edge texts as ``EdgeType`` datapoints.

    Edge texts (``edge_text``, falling back to ``relationship_name``) are
    tallied across all edges. Each distinct text becomes an ``EdgeType``
    datapoint carrying its occurrence count, and the datapoints are passed to
    ``index_data_points``.

    Steps:
        1. If no edges are given, fetch them from the graph engine (deprecated).
        2. Transform edge data into EdgeType datapoints.
        3. Index the EdgeType datapoints using the standard indexing function.

    Args:
        edges_data: Edges to index, either ``EdgeData`` items or
            ``(source, target, relationship, properties)`` tuples. When
            ``None``, all edges are read from the graph engine and a
            deprecation warning is logged.

    Raises:
        RuntimeError: If obtaining the graph engine or reading its graph data
            fails (only on the ``edges_data is None`` path).

    Returns:
        None
    """
    if edges_data is None:
        # Keep the try narrow: only engine acquisition and the graph read are
        # "initialization"; later indexing errors propagate unchanged.
        try:
            graph_engine = await get_graph_engine()
            _, edges_data = await graph_engine.get_graph_data()
        except Exception as e:
            logger.error("Failed to initialize engines: %s", e)
            raise RuntimeError("Initialization error") from e
        logger.warning(
            "Your graph edge embedding is deprecated, please pass edges to the index_graph_edges directly."
        )

    edge_type_datapoints = create_edge_type_datapoints(edges_data)
    await index_data_points(edge_type_datapoints)
    return None