<!-- .github/pull_request_template.md -->
## Description
This pull request introduces edge-centered payloads to the ingestion process. Payloads are stored in the `Triplet_text` collection, which is compatible with the `triplet_embedding` memify pipeline.
### Changes in This PR
- Refactored custom edge handling: custom edges can now be passed to the `add_data_points` method, so ingestion is centralized in one place.
- Added private methods to handle edge-centered payload creation inside `add_data_points.py`.
- Added unit tests to cover the new functionality.
- Added integration tests.
- Added e2e tests.
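The custom edges mentioned above follow the 4-tuple shape documented in `add_data_points.py`; a minimal sketch (the identifiers and property values here are illustrative, not from the codebase):

```python
# Illustrative custom edge in the 4-tuple format expected by add_data_points:
# (source_node_id, target_node_id, relationship_name, properties_dict).
custom_edges = [
    ("chunk_1", "entity_7", "mentions", {"edge_text": "chunk_1 mentions entity_7"}),
]

# In the real pipeline these would be passed alongside data points, e.g.:
#   await add_data_points(data_points, custom_edges=custom_edges)
source_id, target_id, relationship, props = custom_edges[0]
print(relationship)  # prints mentions
```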
### Acceptance Criteria and Testing
**Scenario 1:**
- Set the `TRIPLET_EMBEDDING` env var to `True`.
- Run prune, add, and cognify.
- Verify the vector DB contains a non-empty `Triplet_text` collection and that the number of triplets matches the number of edges in the graph database.
- Use the new `triplet_completion` search type and confirm it works correctly.
**Scenario 2:**
- Set the `TRIPLET_EMBEDDING` env var to `False`.
- Run prune, add, and cognify.
- Verify the vector DB does not contain the `Triplet_text` collection.
- You should receive an error indicating that the `Triplet_text` collection is not available.
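A boolean env var like `TRIPLET_EMBEDDING` is typically parsed along these lines (a sketch only; the exact parsing cognee uses may differ):

```python
import os

def triplet_embedding_enabled(default: bool = False) -> bool:
    # Interpret the TRIPLET_EMBEDDING env var as a boolean toggle.
    raw = os.environ.get("TRIPLET_EMBEDDING")
    if raw is None:
        return default
    return raw.strip().lower() in ("true", "1", "yes")

os.environ["TRIPLET_EMBEDDING"] = "True"
print(triplet_embedding_enabled())  # prints True
```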
## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):
## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->
## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
  * Triplet embeddings supported: embeddings are created from graph edges plus connected node text
  * Ability to supply custom edges when adding data points
  * New configuration toggle to enable/disable triplet embedding
* **Tests**
  * Added comprehensive unit and end-to-end tests for edge-centered payloads and triplet embedding
  * New CI job to run the edge-centered payload e2e test
* **Bug Fixes**
  * Adjusted server start behavior to surface process output in parent logs
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Co-authored-by: Pavel Zorin <pazonec@yandex.ru>
```python
# add_data_points.py
import asyncio

from typing import List, Dict, Optional

from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
from .index_data_points import index_data_points
from .index_graph_edges import index_graph_edges
from cognee.modules.engine.models import Triplet
from cognee.shared.logging_utils import get_logger
from cognee.tasks.storage.exceptions import (
    InvalidDataPointsInAddDataPointsError,
)
from ...modules.engine.utils import generate_node_id

logger = get_logger("add_data_points")


async def add_data_points(
    data_points: List[DataPoint], custom_edges: Optional[List] = None, embed_triplets: bool = False
) -> List[DataPoint]:
    """
    Add a batch of data points to the graph database by extracting nodes and edges,
    deduplicating them, and indexing them for retrieval.

    This function parallelizes the graph extraction for each data point,
    merges the resulting nodes and edges, and ensures uniqueness before
    committing them to the underlying graph engine. It also updates the
    associated retrieval indices for nodes and (optionally) edges.

    Args:
        data_points (List[DataPoint]):
            A list of data points to process and insert into the graph.
        custom_edges (Optional[List[tuple]]):
            Custom edges between data points.
        embed_triplets (bool):
            If True, creates and indexes triplet embeddings from the graph structure.
            Defaults to False.

    Returns:
        List[DataPoint]:
            The original list of data points after processing and insertion.

    Side Effects:
        - Calls `get_graph_from_model` concurrently for each data point.
        - Deduplicates nodes and edges across all results.
        - Updates the node index via `index_data_points`.
        - Inserts nodes and edges into the graph engine.
        - Optionally updates the edge index via `index_graph_edges`.
        - Optionally creates and indexes triplet embeddings if embed_triplets is True.
    """

    if not isinstance(data_points, list):
        raise InvalidDataPointsInAddDataPointsError("data_points must be a list.")
    if not all(isinstance(dp, DataPoint) for dp in data_points):
        raise InvalidDataPointsInAddDataPointsError("data_points: each item must be a DataPoint.")

    nodes = []
    edges = []

    added_nodes = {}
    added_edges = {}
    visited_properties = {}

    results = await asyncio.gather(
        *[
            get_graph_from_model(
                data_point,
                added_nodes=added_nodes,
                added_edges=added_edges,
                visited_properties=visited_properties,
            )
            for data_point in data_points
        ]
    )

    for result_nodes, result_edges in results:
        nodes.extend(result_nodes)
        edges.extend(result_edges)

    nodes, edges = deduplicate_nodes_and_edges(nodes, edges)

    graph_engine = await get_graph_engine()

    await graph_engine.add_nodes(nodes)
    await index_data_points(nodes)

    await graph_engine.add_edges(edges)
    await index_graph_edges(edges)

    if isinstance(custom_edges, list) and custom_edges:
        # Custom edges must be handled separately from data point edges;
        # a Linear task (COG-3488) tracks digging deeper into this.
        await graph_engine.add_edges(custom_edges)
        await index_graph_edges(custom_edges)
        edges.extend(custom_edges)

    if embed_triplets:
        triplets = _create_triplets_from_graph(nodes, edges)
        if triplets:
            await index_data_points(triplets)
            logger.info(f"Created and indexed {len(triplets)} triplets from graph structure")

    return data_points


def _extract_embeddable_text_from_datapoint(data_point: DataPoint) -> str:
    """
    Extract embeddable text from a DataPoint using its index_fields metadata.
    Uses the same approach as index_data_points.

    Parameters:
    -----------
        - data_point (DataPoint): The data point to extract text from.

    Returns:
    --------
        - str: Concatenated string of all embeddable property values, or empty string if none found.
    """
    if not data_point or not hasattr(data_point, "metadata"):
        return ""

    index_fields = data_point.metadata.get("index_fields", [])
    if not index_fields:
        return ""

    embeddable_values = []
    for field_name in index_fields:
        field_value = getattr(data_point, field_name, None)
        if field_value is not None:
            field_value = str(field_value).strip()

            if field_value:
                embeddable_values.append(field_value)

    return " ".join(embeddable_values) if embeddable_values else ""


def _create_triplets_from_graph(nodes: List[DataPoint], edges: List[tuple]) -> List[Triplet]:
    """
    Create Triplet objects from graph nodes and edges.

    This function processes graph edges and their corresponding nodes to create
    triplet data points with embeddable text, similar to the triplet embeddings pipeline.

    Parameters:
    -----------
        - nodes (List[DataPoint]): List of graph nodes extracted from data points.
        - edges (List[tuple]): List of edge tuples in the format
          (source_node_id, target_node_id, relationship_name, properties_dict).
          Note: All edges, including those from DocumentChunk.contains, are already
          extracted by get_graph_from_model and included in this list.

    Returns:
    --------
        - List[Triplet]: List of Triplet objects ready for indexing.
    """
    node_map: Dict[str, DataPoint] = {}
    for node in nodes:
        if hasattr(node, "id"):
            node_id = str(node.id)
            if node_id not in node_map:
                node_map[node_id] = node

    triplets = []
    skipped_count = 0
    seen_ids = set()

    for edge_tuple in edges:
        if len(edge_tuple) < 4:
            continue

        source_node_id, target_node_id, relationship_name, edge_properties = (
            edge_tuple[0],
            edge_tuple[1],
            edge_tuple[2],
            edge_tuple[3],
        )

        source_node = node_map.get(str(source_node_id))
        target_node = node_map.get(str(target_node_id))

        if not source_node or not target_node or relationship_name is None:
            skipped_count += 1
            continue

        source_node_text = _extract_embeddable_text_from_datapoint(source_node)
        target_node_text = _extract_embeddable_text_from_datapoint(target_node)

        relationship_text = ""
        if isinstance(edge_properties, dict):
            edge_text = edge_properties.get("edge_text")
            if edge_text and isinstance(edge_text, str) and edge_text.strip():
                relationship_text = edge_text.strip()

        if not relationship_text and relationship_name:
            relationship_text = relationship_name

        if not source_node_text and not relationship_text and not relationship_name:
            skipped_count += 1
            continue

        embeddable_text = f"{source_node_text} -> {relationship_text} -> {target_node_text}".strip()

        triplet_id = generate_node_id(str(source_node_id) + relationship_name + str(target_node_id))

        if triplet_id in seen_ids:
            continue
        seen_ids.add(triplet_id)

        triplets.append(
            Triplet(
                id=triplet_id,
                from_node_id=str(source_node_id),
                to_node_id=str(target_node_id),
                text=embeddable_text,
            )
        )

    if skipped_count:
        logger.debug(f"Skipped {skipped_count} edges while creating triplets")

    return triplets
```
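The triplet-text construction in `_create_triplets_from_graph` can be sketched standalone with plain tuples in place of DataPoints (node names and the `->` separator here are illustrative; edges whose endpoints are missing from the node map are skipped, and `edge_text` takes precedence over the relationship name):

```python
from typing import Dict, List, Tuple

# Minimal stand-in nodes: (id, embeddable_text) pairs instead of DataPoints.
nodes = [("n1", "Alice"), ("n2", "Bob"), ("n3", "")]
edges: List[Tuple[str, str, str, dict]] = [
    ("n1", "n2", "knows", {"edge_text": "Alice knows Bob"}),
    ("n2", "n3", "works_with", {}),    # falls back to relationship_name
    ("n1", "missing", "ignores", {}),  # skipped: target not in node_map
]

node_map: Dict[str, str] = {node_id: text for node_id, text in nodes}

triplet_texts = []
for src, dst, rel, props in edges:
    if src not in node_map or dst not in node_map:
        continue  # skip edges whose endpoints were not extracted
    rel_text = (props.get("edge_text") or "").strip() or rel
    triplet_texts.append(f"{node_map[src]} -> {rel_text} -> {node_map[dst]}".strip())

print(triplet_texts)  # prints ['Alice -> Alice knows Bob -> Bob', 'Bob -> works_with ->']
```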