feat: Adds edge centered payload and embedding structure during ingestion (#1853)

<!-- .github/pull_request_template.md -->

## Description
This pull request introduces edge‑centered payloads to the ingestion
process. Payloads are stored in the Triplet_text collection, which is
compatible with the triplet_embedding memify pipeline.

Changes in This PR:

- Refactored custom edge handling: custom edges can now be passed to the
add_data_points method, so ingestion is centralized in one place.
- Added private methods in add_data_points.py that build the
edge-centered payloads.
- Added unit tests covering the new functionality.
- Added integration tests.
- Added e2e tests.
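Custom edges are supplied to add_data_points as 4-tuples of (source_node_id, target_node_id, relationship_name, properties). A minimal sketch of that shape; the validation helper below is illustrative only and not part of this PR:

```python
# Hypothetical validator illustrating the 4-tuple edge format that
# add_data_points now accepts via its custom_edges parameter.
def is_valid_custom_edge(edge) -> bool:
    return (
        isinstance(edge, tuple)
        and len(edge) >= 4
        and all(isinstance(part, str) for part in edge[:3])
        and isinstance(edge[3], dict)
    )

# The optional "edge_text" property is preferred over the relationship
# name when building the triplet's embeddable text.
edge = ("node-1", "node-2", "knows", {"edge_text": "is friends with"})
valid = is_valid_custom_edge(edge)
```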

Acceptance Criteria and Testing

Scenario 1:
- Set the TRIPLET_EMBEDDING env var to True
- Run prune, add, cognify
- Verify the vector DB contains a non-empty Triplet_text collection and
that the number of triplets matches the number of edges in the graph
database
- Use the new triplet_completion search type and confirm it works
correctly

Scenario 2:
- Set the TRIPLET_EMBEDDING env var to False
- Run prune, add, cognify
- Verify the vector DB does not contain the Triplet_text collection
- The triplet_completion search type should return an error indicating
that the Triplet_text collection is not available
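The toggle in both scenarios is an environment variable read into CognifyConfig (a pydantic BaseSettings field). A stdlib-only sketch of the kind of truthy-string parsing involved; the helper name and accepted values are assumptions for illustration, not cognee's actual parsing:

```python
import os

def triplet_embedding_enabled(env=None):
    """Hypothetical helper: interpret the TRIPLET_EMBEDDING env var as a bool."""
    env = os.environ if env is None else env
    raw = str(env.get("TRIPLET_EMBEDDING", "False"))
    return raw.strip().lower() in {"1", "true", "yes"}

enabled = triplet_embedding_enabled({"TRIPLET_EMBEDDING": "True"})
```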


## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
  * Triplet embeddings supported: embeddings created from graph edges plus connected node text
  * Ability to supply custom edges when adding data points
  * New configuration toggle to enable/disable triplet embedding

* **Tests**
  * Added comprehensive unit and end-to-end tests for edge-centered payloads and triplet embedding
  * New CI job to run the edge-centered payload e2e test

* **Bug Fixes**
  * Adjusted server start behavior to surface process output in parent logs

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Pavel Zorin <pazonec@yandex.ru>
Committed by hajdul88 on 2025-12-10 17:10:06 +01:00 via GitHub
(commit 001fbe699e, parent 49f7c5188c; GPG key ID B5690EEEBB952194,
no known key found for this signature in database).
9 changed files with 786 additions and 14 deletions.


@@ -412,6 +412,35 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_feedback_enrichment.py
test-edge-centered-payload:
name: Test Cognify - Edge Centered Payload
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
- name: Run Edge Centered Payload Test
env:
ENV: 'dev'
TRIPLET_EMBEDDING: True
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_edge_centered_payload.py
run_conversation_sessions_test_redis:
name: Conversation sessions test (Redis)
runs-on: ubuntu-latest


@@ -3,6 +3,7 @@ from pydantic import BaseModel
from typing import Union, Optional
from uuid import UUID
from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
from cognee.shared.logging_utils import get_logger
from cognee.shared.data_models import KnowledgeGraph
@@ -272,6 +273,9 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
if chunks_per_batch is None:
chunks_per_batch = 100
cognify_config = get_cognify_config()
embed_triplets = cognify_config.triplet_embedding
default_tasks = [
Task(classify_documents),
Task(check_permissions_on_dataset, user=user, permissions=["write"]),
@@ -291,7 +295,11 @@
summarize_text,
task_config={"batch_size": chunks_per_batch},
),
Task(add_data_points, task_config={"batch_size": chunks_per_batch}),
Task(
add_data_points,
embed_triplets=embed_triplets,
task_config={"batch_size": chunks_per_batch},
),
]
return default_tasks


@@ -8,12 +8,14 @@ import os
class CognifyConfig(BaseSettings):
classification_model: object = DefaultContentPrediction
summarization_model: object = SummarizedContent
triplet_embedding: bool = False
model_config = SettingsConfigDict(env_file=".env", extra="allow")
def to_dict(self) -> dict:
return {
"classification_model": self.classification_model,
"summarization_model": self.summarization_model,
"triplet_embedding": self.triplet_embedding,
}


@@ -2,9 +2,7 @@ import asyncio
from typing import Type, List, Optional
from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.ontology.ontology_env_config import get_ontology_env_config
from cognee.tasks.storage import index_graph_edges
from cognee.tasks.storage.add_data_points import add_data_points
from cognee.modules.ontology.ontology_config import Config
from cognee.modules.ontology.get_default_ontology_resolver import (
@@ -25,6 +23,7 @@ from cognee.tasks.graph.exceptions import (
InvalidChunkGraphInputError,
InvalidOntologyAdapterError,
)
from cognee.modules.cognify.config import get_cognify_config
async def integrate_chunk_graphs(
@@ -67,8 +66,6 @@ async def integrate_chunk_graphs(
type(ontology_resolver).__name__ if ontology_resolver else "None"
)
graph_engine = await get_graph_engine()
if graph_model is not KnowledgeGraph:
for chunk_index, chunk_graph in enumerate(chunk_graphs):
data_chunks[chunk_index].contains = chunk_graph
@@ -84,12 +81,13 @@
data_chunks, chunk_graphs, ontology_resolver, existing_edges_map
)
if len(graph_nodes) > 0:
await add_data_points(graph_nodes)
cognify_config = get_cognify_config()
embed_triplets = cognify_config.triplet_embedding
if len(graph_edges) > 0:
await graph_engine.add_edges(graph_edges)
await index_graph_edges(graph_edges)
if len(graph_nodes) > 0:
await add_data_points(
data_points=graph_nodes, custom_edges=graph_edges, embed_triplets=embed_triplets
)
return data_chunks


@@ -1,16 +1,23 @@
import asyncio
from typing import List
from typing import List, Dict, Optional
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
from .index_data_points import index_data_points
from .index_graph_edges import index_graph_edges
from cognee.modules.engine.models import Triplet
from cognee.shared.logging_utils import get_logger
from cognee.tasks.storage.exceptions import (
InvalidDataPointsInAddDataPointsError,
)
from ...modules.engine.utils import generate_node_id
logger = get_logger("add_data_points")
async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
async def add_data_points(
data_points: List[DataPoint], custom_edges: Optional[List] = None, embed_triplets: bool = False
) -> List[DataPoint]:
"""
Add a batch of data points to the graph database by extracting nodes and edges,
deduplicating them, and indexing them for retrieval.
@@ -23,6 +30,10 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
Args:
data_points (List[DataPoint]):
A list of data points to process and insert into the graph.
custom_edges (List[tuple]): Custom edges between datapoints.
embed_triplets (bool):
If True, creates and indexes triplet embeddings from the graph structure.
Defaults to False.
Returns:
List[DataPoint]:
@@ -34,6 +45,7 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
- Updates the node index via `index_data_points`.
- Inserts nodes and edges into the graph engine.
- Optionally updates the edge index via `index_graph_edges`.
- Optionally creates and indexes triplet embeddings if embed_triplets is True.
"""
if not isinstance(data_points, list):
@@ -74,4 +86,132 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
await graph_engine.add_edges(edges)
await index_graph_edges(edges)
if isinstance(custom_edges, list) and custom_edges:
# Custom edges must be handled separately from datapoint edges; a follow-up task to dig deeper is tracked in Linear (COG-3488)
await graph_engine.add_edges(custom_edges)
await index_graph_edges(custom_edges)
edges.extend(custom_edges)
if embed_triplets:
triplets = _create_triplets_from_graph(nodes, edges)
if triplets:
await index_data_points(triplets)
logger.info(f"Created and indexed {len(triplets)} triplets from graph structure")
return data_points
def _extract_embeddable_text_from_datapoint(data_point: DataPoint) -> str:
"""
Extract embeddable text from a DataPoint using its index_fields metadata.
Uses the same approach as index_data_points.
Parameters:
-----------
- data_point (DataPoint): The data point to extract text from.
Returns:
--------
- str: Concatenated string of all embeddable property values, or empty string if none found.
"""
if not data_point or not hasattr(data_point, "metadata"):
return ""
index_fields = data_point.metadata.get("index_fields", [])
if not index_fields:
return ""
embeddable_values = []
for field_name in index_fields:
field_value = getattr(data_point, field_name, None)
if field_value is not None:
field_value = str(field_value).strip()
if field_value:
embeddable_values.append(field_value)
return " ".join(embeddable_values) if embeddable_values else ""
def _create_triplets_from_graph(nodes: List[DataPoint], edges: List[tuple]) -> List[Triplet]:
"""
Create Triplet objects from graph nodes and edges.
This function processes graph edges and their corresponding nodes to create
triplet datapoints with embeddable text, similar to the triplet embeddings pipeline.
Parameters:
-----------
- nodes (List[DataPoint]): List of graph nodes extracted from data points
- edges (List[tuple]): List of edge tuples in format
(source_node_id, target_node_id, relationship_name, properties_dict)
Note: All edges including those from DocumentChunk.contains are already extracted
by get_graph_from_model and included in this list.
Returns:
--------
- List[Triplet]: List of Triplet objects ready for indexing
"""
node_map: Dict[str, DataPoint] = {}
for node in nodes:
if hasattr(node, "id"):
node_id = str(node.id)
if node_id not in node_map:
node_map[node_id] = node
triplets = []
skipped_count = 0
seen_ids = set()
for edge_tuple in edges:
if len(edge_tuple) < 4:
continue
source_node_id, target_node_id, relationship_name, edge_properties = (
edge_tuple[0],
edge_tuple[1],
edge_tuple[2],
edge_tuple[3],
)
source_node = node_map.get(str(source_node_id))
target_node = node_map.get(str(target_node_id))
if not source_node or not target_node or relationship_name is None:
skipped_count += 1
continue
source_node_text = _extract_embeddable_text_from_datapoint(source_node)
target_node_text = _extract_embeddable_text_from_datapoint(target_node)
relationship_text = ""
if isinstance(edge_properties, dict):
edge_text = edge_properties.get("edge_text")
if edge_text and isinstance(edge_text, str) and edge_text.strip():
relationship_text = edge_text.strip()
if not relationship_text and relationship_name:
relationship_text = relationship_name
if not source_node_text and not relationship_text and not relationship_name:
skipped_count += 1
continue
embeddable_text = f"{source_node_text} - {relationship_text}-{target_node_text}".strip()
triplet_id = generate_node_id(str(source_node_id) + relationship_name + str(target_node_id))
if triplet_id in seen_ids:
continue
seen_ids.add(triplet_id)
triplets.append(
Triplet(
id=triplet_id,
from_node_id=str(source_node_id),
to_node_id=str(target_node_id),
text=embeddable_text,
)
)
return triplets
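For readers who want to experiment with the payload-building logic outside cognee, here is a standalone sketch mirroring `_extract_embeddable_text_from_datapoint` and `_create_triplets_from_graph`. The `Node` class and dict payloads are illustrative stand-ins for cognee's DataPoint and Triplet models, and the separator spacing in the joined text is normalized here:

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    id: str
    text: str
    metadata: dict = field(default_factory=lambda: {"index_fields": ["text"]})

def embeddable_text(node):
    # Concatenate the non-empty values of the node's index_fields.
    values = []
    for name in node.metadata.get("index_fields", []):
        value = getattr(node, name, None)
        if value is not None and str(value).strip():
            values.append(str(value).strip())
    return " ".join(values)

def triplets_from_graph(nodes, edges):
    node_map = {str(n.id): n for n in nodes}
    seen, triplets = set(), []
    for edge in edges:
        if len(edge) < 4:
            continue  # skip malformed edge tuples
        src_id, dst_id, rel, props = edge[0], edge[1], edge[2], edge[3]
        src, dst = node_map.get(str(src_id)), node_map.get(str(dst_id))
        if src is None or dst is None or rel is None:
            continue  # skip edges whose endpoints are not in the node map
        # Prefer the edge's "edge_text" property, fall back to the name.
        rel_text = (props or {}).get("edge_text") or rel
        key = (str(src_id), rel, str(dst_id))
        if key in seen:
            continue  # deduplicate repeated edges
        seen.add(key)
        triplets.append({
            "from_node_id": str(src_id),
            "to_node_id": str(dst_id),
            "text": f"{embeddable_text(src)} - {rel_text} - {embeddable_text(dst)}",
        })
    return triplets

nodes = [Node("a", "Apple"), Node("b", "iPhone")]
edges = [
    ("a", "b", "produces", {"edge_text": "produces the"}),
    ("a", "b", "produces", {"edge_text": "produces the"}),  # duplicate, skipped
]
result = triplets_from_graph(nodes, edges)
# result holds a single payload with text "Apple - produces the - iPhone"
```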


@@ -0,0 +1,139 @@
import pathlib
import pytest
import pytest_asyncio
import cognee
from cognee.low_level import setup
from cognee.infrastructure.engine import DataPoint
from cognee.tasks.storage.add_data_points import add_data_points
from cognee.tasks.storage.exceptions import InvalidDataPointsInAddDataPointsError
from cognee.infrastructure.databases.graph import get_graph_engine
class Person(DataPoint):
name: str
age: int
metadata: dict = {"index_fields": ["name"]}
class Company(DataPoint):
name: str
industry: str
metadata: dict = {"index_fields": ["name", "industry"]}
@pytest_asyncio.fixture
async def clean_test_environment():
"""Set up a clean test environment for add_data_points tests."""
base_dir = pathlib.Path(__file__).parent.parent.parent.parent
system_directory_path = str(base_dir / ".cognee_system/test_add_data_points_integration")
data_directory_path = str(base_dir / ".data_storage/test_add_data_points_integration")
cognee.config.system_root_directory(system_directory_path)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
yield
try:
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
except Exception:
pass
@pytest.mark.asyncio
async def test_add_data_points_comprehensive(clean_test_environment):
"""Comprehensive integration test for add_data_points functionality."""
person1 = Person(name="Alice", age=30)
person2 = Person(name="Bob", age=25)
result = await add_data_points([person1, person2])
assert result == [person1, person2]
assert len(result) == 2
graph_engine = await get_graph_engine()
nodes, edges = await graph_engine.get_graph_data()
assert len(nodes) >= 2
result_empty = await add_data_points([])
assert result_empty == []
person3 = Person(name="Charlie", age=35)
person4 = Person(name="Diana", age=32)
custom_edge = (str(person3.id), str(person4.id), "knows", {"edge_text": "friends with"})
result_custom = await add_data_points([person3, person4], custom_edges=[custom_edge])
assert len(result_custom) == 2
nodes, edges = await graph_engine.get_graph_data()
assert len(edges) == 1
assert len(nodes) == 4
class Employee(DataPoint):
name: str
works_at: Company
metadata: dict = {"index_fields": ["name"]}
company = Company(name="TechCorp", industry="Technology")
employee = Employee(name="Eve", works_at=company)
result_rel = await add_data_points([employee])
assert len(result_rel) == 1
nodes, edges = await graph_engine.get_graph_data()
assert len(nodes) == 6
assert len(edges) == 2
person5 = Person(name="Frank", age=40)
person6 = Person(name="Grace", age=38)
triplet_edge = (str(person5.id), str(person6.id), "married_to", {"edge_text": "is married to"})
result_triplet = await add_data_points(
[person5, person6], custom_edges=[triplet_edge], embed_triplets=True
)
assert len(result_triplet) == 2
nodes, edges = await graph_engine.get_graph_data()
assert len(nodes) == 8
assert len(edges) == 3
batch1 = [Person(name="Leo", age=25), Person(name="Mia", age=30)]
batch2 = [Person(name="Noah", age=35), Person(name="Olivia", age=40)]
result_batch1 = await add_data_points(batch1)
result_batch2 = await add_data_points(batch2)
assert len(result_batch1) == 2
assert len(result_batch2) == 2
nodes, edges = await graph_engine.get_graph_data()
assert len(nodes) == 12
assert len(edges) == 3
person7 = Person(name="Paul", age=33)
person8 = Person(name="Quinn", age=31)
edge1 = (str(person7.id), str(person8.id), "colleague_of", {"edge_text": "works with"})
edge2 = (str(person8.id), str(person7.id), "colleague_of", {"edge_text": "works with"})
result_bi = await add_data_points([person7, person8], custom_edges=[edge1, edge2])
assert len(result_bi) == 2
nodes, edges = await graph_engine.get_graph_data()
assert len(nodes) == 14
assert len(edges) == 5
person_invalid = Person(name="Invalid", age=50)
with pytest.raises(InvalidDataPointsInAddDataPointsError, match="must be a list"):
await add_data_points(person_invalid)
with pytest.raises(InvalidDataPointsInAddDataPointsError, match="must be a DataPoint"):
await add_data_points(["not", "datapoints"])
final_nodes, final_edges = await graph_engine.get_graph_data()
assert len(final_nodes) == 14
assert len(final_edges) == 5


@@ -25,8 +25,6 @@ class TestCogneeServerStart(unittest.TestCase):
"--port",
"8000",
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
preexec_fn=os.setsid,
)
# Give the server some time to start
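The server-start change above drops the stdout/stderr PIPE arguments so the child process inherits the parent's streams and its output lands in the parent's logs (e.g. CI output). A minimal stdlib illustration of that behavior:

```python
import subprocess
import sys

# Without stdout=/stderr= redirection, the child inherits the parent's
# streams, so anything it prints shows up directly in the parent's output.
proc = subprocess.run([sys.executable, "-c", "print('child output')"])
```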


@@ -0,0 +1,170 @@
"""
End-to-end integration test for edge-centered payload and triplet embeddings.
"""
import os
import pathlib
import cognee
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.search.types import SearchType
from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver
from cognee.modules.ontology.ontology_config import Config
logger = get_logger()
text_data = """
Apple is a technology company that produces the iPhone, iPad, and Mac computers.
The company is known for its innovative products and ecosystem integration.
Microsoft develops the Windows operating system and Office productivity suite.
They are also major players in cloud computing with Azure.
Google created the Android operating system and provides search engine services.
The company is a leader in artificial intelligence and machine learning.
"""
ontology_content = """<?xml version="1.0"?>
<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
xmlns:owl="http://www.w3.org/2002/07/owl#"
xmlns:rdfs="http://www.w3.org/2000/01/rdf-schema#"
xmlns="http://example.org/tech#"
xml:base="http://example.org/tech">
<owl:Ontology rdf:about="http://example.org/tech"/>
<!-- Classes -->
<owl:Class rdf:ID="Company"/>
<owl:Class rdf:ID="TechnologyCompany"/>
<owl:Class rdf:ID="Product"/>
<owl:Class rdf:ID="Software"/>
<owl:Class rdf:ID="Hardware"/>
<owl:Class rdf:ID="Service"/>
<rdf:Description rdf:about="#TechnologyCompany">
<rdfs:subClassOf rdf:resource="#Company"/>
<rdfs:comment>A company operating in the technology sector.</rdfs:comment>
</rdf:Description>
<rdf:Description rdf:about="#Software">
<rdfs:subClassOf rdf:resource="#Product"/>
<rdfs:comment>Software products and applications.</rdfs:comment>
</rdf:Description>
<rdf:Description rdf:about="#Hardware">
<rdfs:subClassOf rdf:resource="#Product"/>
<rdfs:comment>Physical hardware products.</rdfs:comment>
</rdf:Description>
<!-- Individuals -->
<TechnologyCompany rdf:ID="apple">
<rdfs:label>Apple</rdfs:label>
</TechnologyCompany>
<TechnologyCompany rdf:ID="microsoft">
<rdfs:label>Microsoft</rdfs:label>
</TechnologyCompany>
<TechnologyCompany rdf:ID="google">
<rdfs:label>Google</rdfs:label>
</TechnologyCompany>
<Hardware rdf:ID="iphone">
<rdfs:label>iPhone</rdfs:label>
</Hardware>
<Software rdf:ID="windows">
<rdfs:label>Windows</rdfs:label>
</Software>
<Software rdf:ID="android">
<rdfs:label>Android</rdfs:label>
</Software>
</rdf:RDF>"""
async def main():
data_directory_path = str(
pathlib.Path(
os.path.join(
pathlib.Path(__file__).parent,
".data_storage/test_edge_centered_payload",
)
).resolve()
)
cognee_directory_path = str(
pathlib.Path(
os.path.join(
pathlib.Path(__file__).parent,
".cognee_system/test_edge_centered_payload",
)
).resolve()
)
cognee.config.data_root_directory(data_directory_path)
cognee.config.system_root_directory(cognee_directory_path)
dataset_name = "tech_companies"
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await cognee.add(data=text_data, dataset_name=dataset_name)
import tempfile
with tempfile.NamedTemporaryFile(mode="w", suffix=".owl", delete=False) as f:
f.write(ontology_content)
ontology_file_path = f.name
try:
logger.info(f"Loading ontology from: {ontology_file_path}")
config: Config = {
"ontology_config": {
"ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_file_path)
}
}
await cognee.cognify(datasets=[dataset_name], config=config)
graph_engine = await get_graph_engine()
nodes_phase2, edges_phase2 = await graph_engine.get_graph_data()
vector_engine = get_vector_engine()
triplets_phase2 = await vector_engine.search(
query_text="technology", limit=None, collection_name="Triplet_text"
)
assert len(triplets_phase2) == len(edges_phase2), (
f"Triplet embeddings and number of edges do not match. Vector db contains {len(triplets_phase2)} edge triplets while graph db contains {len(edges_phase2)} edges."
)
search_results_phase2 = await cognee.search(
query_type=SearchType.TRIPLET_COMPLETION,
query_text="What products does Apple make?",
)
assert search_results_phase2 is not None, (
"Search should return results for triplet embeddings in simple ontology use case."
)
finally:
if os.path.exists(ontology_file_path):
os.unlink(ontology_file_path)
if __name__ == "__main__":
import asyncio
from cognee.shared.logging_utils import setup_logging
setup_logging()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()


@@ -0,0 +1,288 @@
import pytest
from unittest.mock import AsyncMock, patch
import sys
from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.models import Triplet
from cognee.tasks.storage.add_data_points import (
add_data_points,
InvalidDataPointsInAddDataPointsError,
_extract_embeddable_text_from_datapoint,
_create_triplets_from_graph,
)
adp_module = sys.modules["cognee.tasks.storage.add_data_points"]
class SimplePoint(DataPoint):
text: str
metadata: dict = {"index_fields": ["text"]}
@pytest.mark.asyncio
@pytest.mark.parametrize("bad_input", [None, ["not_datapoint"]])
async def test_add_data_points_validates_inputs(bad_input):
with pytest.raises(InvalidDataPointsInAddDataPointsError):
await add_data_points(bad_input)
@pytest.mark.asyncio
@patch.object(adp_module, "index_graph_edges")
@patch.object(adp_module, "index_data_points")
@patch.object(adp_module, "get_graph_engine")
@patch.object(adp_module, "deduplicate_nodes_and_edges")
@patch.object(adp_module, "get_graph_from_model")
async def test_add_data_points_indexes_nodes_and_edges(
mock_get_graph, mock_dedup, mock_get_engine, mock_index_nodes, mock_index_edges
):
dp1 = SimplePoint(text="first")
dp2 = SimplePoint(text="second")
edge1 = (str(dp1.id), str(dp2.id), "related_to", {"edge_text": "connects"})
custom_edges = [(str(dp2.id), str(dp1.id), "custom_edge", {})]
mock_get_graph.side_effect = [([dp1], [edge1]), ([dp2], [])]
mock_dedup.side_effect = lambda n, e: (n, e)
graph_engine = AsyncMock()
mock_get_engine.return_value = graph_engine
result = await add_data_points([dp1, dp2], custom_edges=custom_edges)
assert result == [dp1, dp2]
graph_engine.add_nodes.assert_awaited_once()
mock_index_nodes.assert_awaited_once()
assert graph_engine.add_edges.await_count == 2
assert edge1 in graph_engine.add_edges.await_args_list[0].args[0]
assert graph_engine.add_edges.await_args_list[1].args[0] == custom_edges
assert mock_index_edges.await_count == 2
@pytest.mark.asyncio
@patch.object(adp_module, "index_graph_edges")
@patch.object(adp_module, "index_data_points")
@patch.object(adp_module, "get_graph_engine")
@patch.object(adp_module, "deduplicate_nodes_and_edges")
@patch.object(adp_module, "get_graph_from_model")
async def test_add_data_points_indexes_triplets_when_enabled(
mock_get_graph, mock_dedup, mock_get_engine, mock_index_nodes, mock_index_edges
):
dp1 = SimplePoint(text="source")
dp2 = SimplePoint(text="target")
edge1 = (str(dp1.id), str(dp2.id), "relates", {"edge_text": "describes"})
mock_get_graph.side_effect = [([dp1], [edge1]), ([dp2], [])]
mock_dedup.side_effect = lambda n, e: (n, e)
graph_engine = AsyncMock()
mock_get_engine.return_value = graph_engine
await add_data_points([dp1, dp2], embed_triplets=True)
assert mock_index_nodes.await_count == 2
nodes_arg = mock_index_nodes.await_args_list[0].args[0]
triplets_arg = mock_index_nodes.await_args_list[1].args[0]
assert nodes_arg == [dp1, dp2]
assert len(triplets_arg) == 1
assert isinstance(triplets_arg[0], Triplet)
mock_index_edges.assert_awaited_once()
@pytest.mark.asyncio
@patch.object(adp_module, "index_graph_edges")
@patch.object(adp_module, "index_data_points")
@patch.object(adp_module, "get_graph_engine")
@patch.object(adp_module, "deduplicate_nodes_and_edges")
@patch.object(adp_module, "get_graph_from_model")
async def test_add_data_points_with_empty_list(
mock_get_graph, mock_dedup, mock_get_engine, mock_index_nodes, mock_index_edges
):
mock_dedup.side_effect = lambda n, e: (n, e)
graph_engine = AsyncMock()
mock_get_engine.return_value = graph_engine
result = await add_data_points([])
assert result == []
mock_get_graph.assert_not_called()
graph_engine.add_nodes.assert_awaited_once_with([])
@pytest.mark.asyncio
@patch.object(adp_module, "index_graph_edges")
@patch.object(adp_module, "index_data_points")
@patch.object(adp_module, "get_graph_engine")
@patch.object(adp_module, "deduplicate_nodes_and_edges")
@patch.object(adp_module, "get_graph_from_model")
async def test_add_data_points_with_single_datapoint(
mock_get_graph, mock_dedup, mock_get_engine, mock_index_nodes, mock_index_edges
):
dp = SimplePoint(text="single")
mock_get_graph.side_effect = [([dp], [])]
mock_dedup.side_effect = lambda n, e: (n, e)
graph_engine = AsyncMock()
mock_get_engine.return_value = graph_engine
result = await add_data_points([dp])
assert result == [dp]
mock_get_graph.assert_called_once()
mock_index_nodes.assert_awaited_once()
def test_extract_embeddable_text_from_datapoint():
dp = SimplePoint(text="hello world")
text = _extract_embeddable_text_from_datapoint(dp)
assert text == "hello world"
def test_extract_embeddable_text_with_multiple_fields():
class MultiField(DataPoint):
title: str
description: str
metadata: dict = {"index_fields": ["title", "description"]}
dp = MultiField(title="Test", description="Description")
text = _extract_embeddable_text_from_datapoint(dp)
assert text == "Test Description"
def test_extract_embeddable_text_with_no_index_fields():
class NoIndex(DataPoint):
text: str
metadata: dict = {"index_fields": []}
dp = NoIndex(text="ignored")
text = _extract_embeddable_text_from_datapoint(dp)
assert text == ""
def test_create_triplets_from_graph():
dp1 = SimplePoint(text="source node")
dp2 = SimplePoint(text="target node")
edge = (str(dp1.id), str(dp2.id), "connects_to", {"edge_text": "links"})
triplets = _create_triplets_from_graph([dp1, dp2], [edge])
assert len(triplets) == 1
assert isinstance(triplets[0], Triplet)
assert triplets[0].from_node_id == str(dp1.id)
assert triplets[0].to_node_id == str(dp2.id)
assert "source node" in triplets[0].text
assert "target node" in triplets[0].text
def test_extract_embeddable_text_with_none_datapoint():
text = _extract_embeddable_text_from_datapoint(None)
assert text == ""
def test_extract_embeddable_text_without_metadata():
class NoMetadata(DataPoint):
text: str
dp = NoMetadata(text="test")
delattr(dp, "metadata")
text = _extract_embeddable_text_from_datapoint(dp)
assert text == ""
def test_extract_embeddable_text_with_whitespace_only():
class WhitespaceField(DataPoint):
text: str
metadata: dict = {"index_fields": ["text"]}
dp = WhitespaceField(text=" ")
text = _extract_embeddable_text_from_datapoint(dp)
assert text == ""
def test_create_triplets_skips_short_edge_tuples():
dp = SimplePoint(text="node")
incomplete_edge = (str(dp.id), str(dp.id))
triplets = _create_triplets_from_graph([dp], [incomplete_edge])
assert len(triplets) == 0
def test_create_triplets_skips_missing_source_node():
dp1 = SimplePoint(text="target")
edge = ("missing_id", str(dp1.id), "relates", {})
triplets = _create_triplets_from_graph([dp1], [edge])
assert len(triplets) == 0
def test_create_triplets_skips_missing_target_node():
dp1 = SimplePoint(text="source")
edge = (str(dp1.id), "missing_id", "relates", {})
triplets = _create_triplets_from_graph([dp1], [edge])
assert len(triplets) == 0
def test_create_triplets_skips_none_relationship():
dp1 = SimplePoint(text="source")
dp2 = SimplePoint(text="target")
edge = (str(dp1.id), str(dp2.id), None, {})
triplets = _create_triplets_from_graph([dp1, dp2], [edge])
assert len(triplets) == 0
def test_create_triplets_uses_relationship_name_when_no_edge_text():
dp1 = SimplePoint(text="source")
dp2 = SimplePoint(text="target")
edge = (str(dp1.id), str(dp2.id), "connects_to", {})
triplets = _create_triplets_from_graph([dp1, dp2], [edge])
assert len(triplets) == 1
assert "connects_to" in triplets[0].text
def test_create_triplets_prevents_duplicates():
dp1 = SimplePoint(text="source")
dp2 = SimplePoint(text="target")
edge = (str(dp1.id), str(dp2.id), "relates", {"edge_text": "links"})
triplets = _create_triplets_from_graph([dp1, dp2], [edge, edge])
assert len(triplets) == 1
def test_create_triplets_skips_nodes_without_id():
class NodeNoId:
pass
dp = SimplePoint(text="valid")
node_no_id = NodeNoId()
edge = (str(dp.id), "some_id", "relates", {})
triplets = _create_triplets_from_graph([dp, node_no_id], [edge])
assert len(triplets) == 0
@pytest.mark.asyncio
@patch.object(adp_module, "index_graph_edges")
@patch.object(adp_module, "index_data_points")
@patch.object(adp_module, "get_graph_engine")
@patch.object(adp_module, "deduplicate_nodes_and_edges")
@patch.object(adp_module, "get_graph_from_model")
async def test_add_data_points_with_empty_custom_edges(
mock_get_graph, mock_dedup, mock_get_engine, mock_index_nodes, mock_index_edges
):
dp = SimplePoint(text="test")
mock_get_graph.side_effect = [([dp], [])]
mock_dedup.side_effect = lambda n, e: (n, e)
graph_engine = AsyncMock()
mock_get_engine.return_value = graph_engine
result = await add_data_points([dp], custom_edges=[])
assert result == [dp]
assert graph_engine.add_edges.await_count == 1