feat: optimize repeated entity extraction (#1682)

<!-- .github/pull_request_template.md -->

## Description
<!--
Please provide a clear, human-generated description of the changes in
this PR.
DO NOT use AI-generated descriptions. We want to understand your thought
process and reasoning.
-->

- Added an `edge_text` field to edges that auto-fills from
`relationship_type` if not provided.
- `contains` edges now store descriptions, for better embeddings
- Updated and refactored indexing so that edge_text gets embedded and
exposed
- Updated retrieval to use the new embeddings 
- Added a test to verify edge_text exists in the graph with the correct
format.

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [x] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
lxobr 2025-10-30 13:56:06 +01:00 committed by GitHub
parent 116579a8e1
commit 6223ecf05b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 236 additions and 162 deletions

View file

@ -1,4 +1,4 @@
from pydantic import BaseModel
from pydantic import BaseModel, field_validator
from typing import Optional, Any, Dict
@ -18,9 +18,21 @@ class Edge(BaseModel):
# Mixed usage
has_items: (Edge(weight=0.5, weights={"confidence": 0.9}), list[Item])
# With edge_text for rich embedding representation
contains: (Edge(relationship_type="contains", edge_text="relationship_name: contains; entity_description: Alice"), Entity)
"""
weight: Optional[float] = None
weights: Optional[Dict[str, float]] = None
relationship_type: Optional[str] = None
properties: Optional[Dict[str, Any]] = None
edge_text: Optional[str] = None
@field_validator("edge_text", mode="before")
@classmethod
def ensure_edge_text(cls, v, info):
    """Default edge_text to relationship_type when no explicit value was given."""
    # An explicitly supplied value always wins.
    if v is not None:
        return v
    # Fall back to relationship_type; an empty/missing one leaves edge_text unset.
    fallback = info.data.get("relationship_type")
    return fallback or None

View file

@ -1,6 +1,7 @@
from typing import List, Union
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.Edge import Edge
from cognee.modules.data.processing.document_types import Document
from cognee.modules.engine.models import Entity
from cognee.tasks.temporal_graph.models import Event
@ -31,6 +32,6 @@ class DocumentChunk(DataPoint):
chunk_index: int
cut_type: str
is_part_of: Document
contains: List[Union[Entity, Event]] = None
contains: List[Union[Entity, Event, tuple[Edge, Entity]]] = None
metadata: dict = {"index_fields": ["text"]}

View file

@ -171,8 +171,10 @@ class CogneeGraph(CogneeAbstractGraph):
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
for edge in self.edges:
relationship_type = edge.attributes.get("relationship_type")
distance = embedding_map.get(relationship_type, None)
edge_key = edge.attributes.get("edge_text") or edge.attributes.get(
"relationship_type"
)
distance = embedding_map.get(edge_key, None)
if distance is not None:
edge.attributes["vector_distance"] = distance

View file

@ -1,5 +1,6 @@
from typing import Optional
from cognee.infrastructure.engine.models.Edge import Edge
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.engine.utils import (
@ -243,10 +244,26 @@ def _process_graph_nodes(
ontology_relationships,
)
# Add entity to data chunk
if data_chunk.contains is None:
data_chunk.contains = []
data_chunk.contains.append(entity_node)
edge_text = "; ".join(
[
"relationship_name: contains",
f"entity_name: {entity_node.name}",
f"entity_description: {entity_node.description}",
]
)
data_chunk.contains.append(
(
Edge(
relationship_type="contains",
edge_text=edge_text,
),
entity_node,
)
)
def _process_graph_edges(

View file

@ -1,71 +1,70 @@
import string
from typing import List
from collections import Counter
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
def _get_top_n_frequent_words(
text: str, stop_words: set = None, top_n: int = 3, separator: str = ", "
) -> str:
"""Concatenates the top N frequent words in text."""
if stop_words is None:
stop_words = DEFAULT_STOP_WORDS
words = [word.lower().strip(string.punctuation) for word in text.split()]
words = [word for word in words if word and word not in stop_words]
top_words = [word for word, freq in Counter(words).most_common(top_n)]
return separator.join(top_words)
def _create_title_from_text(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
    """Creates a title by combining first words with most frequent words from the text."""
    leading = " ".join(text.split()[:first_n_words])
    frequent = _get_top_n_frequent_words(text, top_n=top_n_words)
    return f"{leading}... [{frequent}]"
def _extract_nodes_from_edges(retrieved_edges: List[Edge]) -> dict:
    """Creates a dictionary of nodes with their names and content."""
    nodes = {}
    for edge in retrieved_edges:
        for node in (edge.node1, edge.node2):
            # Each node is recorded once, keyed by its id.
            if node.id in nodes:
                continue
            text = node.attributes.get("text")
            if text:
                # Text-bearing nodes get a synthesized title; the raw text is the content.
                nodes[node.id] = {
                    "node": node,
                    "name": _create_title_from_text(text),
                    "content": text,
                }
            else:
                # Otherwise fall back to the node's name/description attributes.
                fallback_name = node.attributes.get("name", "Unnamed Node")
                nodes[node.id] = {
                    "node": node,
                    "name": fallback_name,
                    "content": node.attributes.get("description", fallback_name),
                }
    return nodes
async def resolve_edges_to_text(retrieved_edges: List[Edge]) -> str:
    """Converts retrieved graph edges into a human-readable string format.

    Parameters:
    -----------
    - retrieved_edges (list): A list of edges retrieved from the graph.

    Returns:
    --------
    - str: A formatted string representation of the nodes and their connections.
    """
    nodes = _extract_nodes_from_edges(retrieved_edges)
    node_section = "\n".join(
        f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
        for info in nodes.values()
    )
    connections = []
    for edge in retrieved_edges:
        source_name = nodes[edge.node1.id]["name"]
        target_name = nodes[edge.node2.id]["name"]
        # Prefer the rich edge_text label; fall back to the bare relationship type.
        edge_label = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
        connections.append(f"{source_name} --[{edge_label}]--> {target_name}")
    connection_section = "\n".join(connections)
    return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"

View file

@ -71,7 +71,7 @@ async def get_memory_fragment(
await memory_fragment.project_graph_from_db(
graph_engine,
node_properties_to_project=properties_to_project,
edge_properties_to_project=["relationship_name"],
edge_properties_to_project=["relationship_name", "edge_text"],
node_type=node_type,
node_name=node_name,
)

View file

async def index_data_points(data_points: list[DataPoint]):
    """Index data points in the vector engine by creating embeddings for specified fields.

    Process:
        1. Groups data points into a nested dict: {type_name: {field_name: [points]}}
        2. Creates vector indexes for each (type, field) combination on first encounter
        3. Batches points per (type, field) and creates async indexing tasks
        4. Executes all indexing tasks in parallel for efficient embedding generation

    Args:
        data_points: List of DataPoint objects to index. Each DataPoint's metadata must
            contain an 'index_fields' list specifying which fields to embed.

    Returns:
        The original data_points list.
    """
    data_points_by_type = {}
    vector_engine = get_vector_engine()

    for data_point in data_points:
        type_name = type(data_point).__name__
        for field_name in data_point.metadata["index_fields"]:
            # Skip fields that are unset on this instance.
            if getattr(data_point, field_name, None) is None:
                continue
            if type_name not in data_points_by_type:
                data_points_by_type[type_name] = {}
            if field_name not in data_points_by_type[type_name]:
                # First encounter of this (type, field): create its vector index.
                await vector_engine.create_vector_index(type_name, field_name)
                data_points_by_type[type_name][field_name] = []
            # Copy so each indexed point carries exactly one index field.
            indexed_data_point = data_point.model_copy()
            indexed_data_point.metadata["index_fields"] = [field_name]
            data_points_by_type[type_name][field_name].append(indexed_data_point)

    batch_size = vector_engine.embedding_engine.get_batch_size()
    batches = (
        (type_name, field_name, points[i : i + batch_size])
        for type_name, fields in data_points_by_type.items()
        for field_name, points in fields.items()
        for i in range(0, len(points), batch_size)
    )
    # Create embedding requests per batch to run in parallel
    tasks = [
        asyncio.create_task(vector_engine.index_data_points(type_name, field_name, batch_points))
        for type_name, field_name, batch_points in batches
    ]
    # Run all embedding requests in parallel
    await asyncio.gather(*tasks)

    return data_points

View file

@ -1,17 +1,44 @@
import asyncio
from collections import Counter
from typing import Optional, Dict, Any, List, Tuple, Union
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
from cognee.shared.logging_utils import get_logger
from collections import Counter
from typing import Optional, Dict, Any, List, Tuple, Union
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.models.EdgeType import EdgeType
from cognee.infrastructure.databases.graph.graph_db_interface import EdgeData
from cognee.tasks.storage.index_data_points import index_data_points
logger = get_logger()
def _get_edge_text(item: dict) -> str:
"""Extract edge text for embedding - prefers edge_text field with fallback."""
if "edge_text" in item:
return item["edge_text"]
if "relationship_name" in item:
return item["relationship_name"]
return ""
def create_edge_type_datapoints(edges_data) -> list[EdgeType]:
    """Transform raw edge data into EdgeType datapoints.

    Counts occurrences of each edge label (via _get_edge_text, which prefers
    edge_text and falls back to relationship_name) and emits one EdgeType per
    distinct label with its occurrence count.
    """
    # NOTE(review): the filter admits only dicts containing "relationship_name",
    # yet _get_edge_text prefers "edge_text" — an edge dict carrying edge_text
    # but no relationship_name would be dropped here. Confirm this is intended.
    edge_texts = [
        _get_edge_text(item)
        for edge in edges_data
        for item in edge
        if isinstance(item, dict) and "relationship_name" in item
    ]
    # Aggregate duplicate labels so each EdgeType records how many edges share it.
    edge_types = Counter(edge_texts)
    return [
        EdgeType(id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count)
        for text, count in edge_types.items()
    ]
async def index_graph_edges(
edges_data: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]] = None,
):
@ -23,24 +50,17 @@ async def index_graph_edges(
the `relationship_name` field.
Steps:
1. Initialize the vector engine and graph engine.
2. Retrieve graph edge data and count relationship types (`relationship_name`).
3. Create vector indexes for `relationship_name` if they don't exist.
4. Transform the counted relationships into `EdgeType` objects.
5. Index the transformed data points in the vector engine.
1. Initialize the graph engine if needed and retrieve edge data.
2. Transform edge data into EdgeType datapoints.
3. Index the EdgeType datapoints using the standard indexing function.
Raises:
RuntimeError: If initialization of the vector engine or graph engine fails.
RuntimeError: If initialization of the graph engine fails.
Returns:
None
"""
try:
created_indexes = {}
index_points = {}
vector_engine = get_vector_engine()
if edges_data is None:
graph_engine = await get_graph_engine()
_, edges_data = await graph_engine.get_graph_data()
@ -51,47 +71,7 @@ async def index_graph_edges(
logger.error("Failed to initialize engines: %s", e)
raise RuntimeError("Initialization error") from e
edge_types = Counter(
item.get("relationship_name")
for edge in edges_data
for item in edge
if isinstance(item, dict) and "relationship_name" in item
)
for text, count in edge_types.items():
edge = EdgeType(
id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count
)
data_point_type = type(edge)
for field_name in edge.metadata["index_fields"]:
index_name = f"{data_point_type.__name__}.{field_name}"
if index_name not in created_indexes:
await vector_engine.create_vector_index(data_point_type.__name__, field_name)
created_indexes[index_name] = True
if index_name not in index_points:
index_points[index_name] = []
indexed_data_point = edge.model_copy()
indexed_data_point.metadata["index_fields"] = [field_name]
index_points[index_name].append(indexed_data_point)
# Get maximum batch size for embedding model
batch_size = vector_engine.embedding_engine.get_batch_size()
tasks: list[asyncio.Task] = []
for index_name, indexable_points in index_points.items():
index_name, field_name = index_name.split(".")
# Create embedding tasks to run in parallel later
for start in range(0, len(indexable_points), batch_size):
batch = indexable_points[start : start + batch_size]
tasks.append(vector_engine.index_data_points(index_name, field_name, batch))
# Start all embedding tasks and wait for completion
await asyncio.gather(*tasks)
edge_type_datapoints = create_edge_type_datapoints(edges_data)
await index_data_points(edge_type_datapoints)
return None

View file

@ -52,6 +52,33 @@ async def test_edge_ingestion():
edge_type_counts = Counter(edge_type[2] for edge_type in graph[1])
"Tests edge_text presence and format"
contains_edges = [edge for edge in graph[1] if edge[2] == "contains"]
assert len(contains_edges) > 0, "Expected at least one contains edge for edge_text verification"
edge_properties = contains_edges[0][3]
assert "edge_text" in edge_properties, "Expected edge_text in edge properties"
edge_text = edge_properties["edge_text"]
assert "relationship_name: contains" in edge_text, (
f"Expected 'relationship_name: contains' in edge_text, got: {edge_text}"
)
assert "entity_name:" in edge_text, f"Expected 'entity_name:' in edge_text, got: {edge_text}"
assert "entity_description:" in edge_text, (
f"Expected 'entity_description:' in edge_text, got: {edge_text}"
)
all_edge_texts = [
edge[3].get("edge_text", "") for edge in contains_edges if "edge_text" in edge[3]
]
expected_entities = ["dave", "ana", "bob", "dexter", "apples", "cognee"]
found_entity = any(
any(entity in text.lower() for entity in expected_entities) for text in all_edge_texts
)
assert found_entity, (
f"Expected to find at least one entity name in edge_text: {all_edge_texts[:3]}"
)
"Tests the presence of basic nested edges"
for basic_nested_edge in basic_nested_edges:
assert edge_type_counts.get(basic_nested_edge, 0) >= 1, (

View file

@ -0,0 +1,27 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from cognee.tasks.storage.index_data_points import index_data_points
from cognee.infrastructure.engine import DataPoint
class TestDataPoint(DataPoint):
    # Minimal DataPoint subclass used to exercise index_data_points:
    # a single indexed field ("name") keeps the expected call counts simple.
    # NOTE(review): pytest may warn about collecting classes named Test*; consider
    # setting __test__ = False or renaming if that warning appears.
    name: str
    metadata: dict = {"index_fields": ["name"]}
@pytest.mark.asyncio
async def test_index_data_points_calls_vector_engine():
    """Test that index_data_points creates vector index and indexes data."""
    data_points = [TestDataPoint(name="test1")]
    mock_vector_engine = AsyncMock()
    # get_batch_size is called synchronously, so it must be a MagicMock
    # (an AsyncMock attribute would return a coroutine instead of an int).
    mock_vector_engine.embedding_engine.get_batch_size = MagicMock(return_value=100)
    # Replace the engine factory in the function's own globals so no real
    # vector database is needed.
    with patch.dict(
        index_data_points.__globals__,
        {"get_vector_engine": lambda: mock_vector_engine},
    ):
        await index_data_points(data_points)
    # One (type, field) pair was seen, so both the index creation and the
    # batched indexing call must have been awaited at least once.
    assert mock_vector_engine.create_vector_index.await_count >= 1
    assert mock_vector_engine.index_data_points.await_count >= 1

View file

@ -5,8 +5,7 @@ from cognee.tasks.storage.index_graph_edges import index_graph_edges
@pytest.mark.asyncio
async def test_index_graph_edges_success():
"""Test that index_graph_edges uses the index datapoints and creates vector index."""
# Create the mocks for the graph and vector engines.
"""Test that index_graph_edges retrieves edges and delegates to index_data_points."""
mock_graph_engine = AsyncMock()
mock_graph_engine.get_graph_data.return_value = (
None,
@ -15,26 +14,23 @@ async def test_index_graph_edges_success():
[{"relationship_name": "rel2"}],
],
)
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine.get_batch_size = MagicMock(return_value=100)
mock_index_data_points = AsyncMock()
# Patch the globals of the function so that when it does:
# vector_engine = get_vector_engine()
# graph_engine = await get_graph_engine()
# it uses the mocked versions.
with patch.dict(
index_graph_edges.__globals__,
{
"get_graph_engine": AsyncMock(return_value=mock_graph_engine),
"get_vector_engine": lambda: mock_vector_engine,
"index_data_points": mock_index_data_points,
},
):
await index_graph_edges()
# Assertions on the mock calls.
mock_graph_engine.get_graph_data.assert_awaited_once()
assert mock_vector_engine.create_vector_index.await_count == 1
assert mock_vector_engine.index_data_points.await_count == 1
mock_index_data_points.assert_awaited_once()
call_args = mock_index_data_points.call_args[0][0]
assert len(call_args) == 2
assert all(hasattr(item, "relationship_name") for item in call_args)
@pytest.mark.asyncio
async def test_index_graph_edges_no_relationships():
    """Test that index_graph_edges handles empty relationships correctly."""
    mock_graph_engine = AsyncMock()
    # No edges at all: get_graph_data returns an empty edge list.
    mock_graph_engine.get_graph_data.return_value = (None, [])
    mock_index_data_points = AsyncMock()
    # Patch the function's globals so it uses the mocked graph engine and the
    # mocked index_data_points instead of real infrastructure.
    with patch.dict(
        index_graph_edges.__globals__,
        {
            "get_graph_engine": AsyncMock(return_value=mock_graph_engine),
            "index_data_points": mock_index_data_points,
        },
    ):
        await index_graph_edges()
    mock_graph_engine.get_graph_data.assert_awaited_once()
    # Delegation still happens, but with an empty list of EdgeType datapoints.
    mock_index_data_points.assert_awaited_once()
    call_args = mock_index_data_points.call_args[0][0]
    assert len(call_args) == 0
@pytest.mark.asyncio