<!-- .github/pull_request_template.md -->
## Description
This PR introduces triplet embeddings via a new
create_triplet_embeddings memify pipeline.
The pipeline reads the graph in batches, extracts properties from graph
elements based on their datapoint types, and generates combined triplet
embeddings. These embeddings are stored in the vector database as a new
collection.
Changes in This PR:
- Added a new `create_triplet_embeddings` memify pipeline.
- Added a new `get_triplet_datapoints` memify task.
- Introduced a new `triplet_completion` search type.
- Added full test coverage:
  - Unit tests: memify task, pipeline, and retriever
  - Integration tests: memify task, pipeline, and retriever
  - End-to-end tests: updated session history tests and multi-DB search
    tests; added tests for `triplet_completion` and memify pipeline execution
Acceptance Criteria and Testing

Scenario 1:
- Run the default add and cognify pipelines.
- Run the create triplet embeddings memify pipeline.
- Verify that the vector DB contains a non-empty `Triplet_text` collection.
- Use the new `triplet_completion` search type and confirm it works
  correctly.

Scenario 2:
- Run the default add and cognify pipelines.
- Do not run the triplet embeddings memify pipeline.
- Attempt to use the `triplet_completion` search type.
- You should receive an error indicating that the triplet embeddings
  memify pipeline must be executed first.
## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):
## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->
## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Triplet-based search with LLM-powered completions (TRIPLET_COMPLETION)
* Batch triplet retrieval and a triplet embeddings pipeline for
extraction, indexing, and optional background processing
* Context retrieval from triplet embeddings with optional caching and
conversation-history support
* New Triplet data type exposed for indexing and search
* **Examples**
* End-to-end example demonstrating triplet embeddings extraction and
TRIPLET_COMPLETION search
* **Tests**
* Unit and integration tests covering triplet extraction, retrieval,
embedding pipeline, and completion flows
<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Co-authored-by: Pavel Zorin <pazonec@yandex.ru>
283 lines
10 KiB
Python
283 lines
10 KiB
Python
from typing import AsyncGenerator, Dict, Any, List, Optional
|
||
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
||
from cognee.shared.logging_utils import get_logger
|
||
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
||
from cognee.infrastructure.engine import DataPoint
|
||
from cognee.modules.engine.models import Triplet
|
||
from cognee.tasks.storage import index_data_points
|
||
|
||
# Module-level logger used by every function in this file.
logger = get_logger("get_triplet_datapoints")
|
||
|
||
|
||
def _build_datapoint_type_index_mapping() -> Dict[str, List[str]]:
    """
    Collect the ``index_fields`` declared by every DataPoint subclass.

    A subclass contributes an entry only when its ``metadata`` model field has a
    dict default that holds a non-empty ``index_fields`` value.

    Returns:
    --------

        - Dict[str, List[str]]: Subclass name mapped to its list of index field names.
    """
    logger.debug("Building DataPoint type to index_fields mapping")
    type_to_index_fields = {}

    for datapoint_cls in get_all_subclasses(DataPoint):
        # Guard clauses: skip any subclass that does not declare usable index_fields.
        if "metadata" not in datapoint_cls.model_fields:
            continue

        metadata_default = getattr(datapoint_cls.model_fields["metadata"], "default", None)
        if not isinstance(metadata_default, dict):
            continue

        declared_fields = metadata_default.get("index_fields", [])
        if not declared_fields:
            continue

        type_to_index_fields[datapoint_cls.__name__] = declared_fields
        logger.debug(
            f"Registered {datapoint_cls.__name__} with index_fields: {declared_fields}"
        )

    logger.info(
        f"Found {len(type_to_index_fields)} DataPoint types with index_fields: "
        f"{list(type_to_index_fields.keys())}"
    )
    return type_to_index_fields
|
||
|
||
|
||
def _extract_embeddable_text(node_or_edge: Dict[str, Any], index_fields: List[str]) -> str:
|
||
"""
|
||
Extract and concatenate embeddable properties from a node or edge dictionary.
|
||
|
||
Parameters:
|
||
-----------
|
||
- node_or_edge (Dict[str, Any]): Dictionary containing node or edge properties.
|
||
- index_fields (List[str]): List of field names to extract and concatenate.
|
||
|
||
Returns:
|
||
--------
|
||
- str: Concatenated string of all embeddable property values, or empty string if none found.
|
||
"""
|
||
if not node_or_edge or not index_fields:
|
||
return ""
|
||
|
||
embeddable_values = []
|
||
for field_name in index_fields:
|
||
field_value = node_or_edge.get(field_name)
|
||
if field_value is not None:
|
||
field_value = str(field_value).strip()
|
||
|
||
if field_value:
|
||
embeddable_values.append(field_value)
|
||
|
||
return " ".join(embeddable_values) if embeddable_values else ""
|
||
|
||
|
||
def _extract_relationship_text(
|
||
relationship: Dict[str, Any], datapoint_type_index_property: Dict[str, List[str]]
|
||
) -> str:
|
||
"""
|
||
Extract relationship text from edge properties.
|
||
|
||
Parameters:
|
||
-----------
|
||
- relationship (Dict[str, Any]): Dictionary containing relationship properties
|
||
- datapoint_type_index_property (Dict[str, List[str]]): Mapping of type to index fields
|
||
|
||
Returns:
|
||
--------
|
||
- str: Extracted relationship text or empty string
|
||
"""
|
||
if not relationship:
|
||
return ""
|
||
|
||
edge_text = relationship.get("edge_text")
|
||
if edge_text and isinstance(edge_text, str) and edge_text.strip():
|
||
return edge_text.strip()
|
||
|
||
# Fallback to extracting from EdgeType index_fields
|
||
edge_type_index_fields = datapoint_type_index_property.get("EdgeType", [])
|
||
return _extract_embeddable_text(relationship, edge_type_index_fields)
|
||
|
||
|
||
def _process_single_triplet(
    triplet_datapoint: Dict[str, Any],
    datapoint_type_index_property: Dict[str, List[str]],
    offset: int,
    idx: int,
) -> tuple[Optional[Triplet], Optional[str]]:
    """
    Turn one raw triplet into a Triplet object, or explain why it was skipped.

    The embeddable text of the start node, the relationship and the end node is
    combined as ``<start>-›<relationship>-›<end>`` and stored on the Triplet together
    with both node IDs.

    Parameters:
    -----------

        - triplet_datapoint (Dict[str, Any]): Raw triplet with the keys ``start_node``,
          ``end_node`` and ``relationship_properties``.
        - datapoint_type_index_property (Dict[str, List[str]]): Type name mapped to
          its index field names.
        - offset (int): Offset of the current batch; used only in log/skip messages.
        - idx (int): Position inside the current batch; used only in log/skip messages.

    Returns:
    --------

        - tuple[Optional[Triplet], Optional[str]]: ``(Triplet, None)`` on success, or
          ``(None, message)`` when the triplet is skipped because a node ID is missing
          or no embeddable text could be extracted from any of the three elements.
    """
    # `or {}` also covers keys that are present but hold None (or another empty value),
    # so such triplets are reported through the "missing node IDs" skip message below
    # instead of failing with AttributeError on `.get`.
    start_node = triplet_datapoint.get("start_node") or {}
    end_node = triplet_datapoint.get("end_node") or {}
    relationship = triplet_datapoint.get("relationship_properties") or {}

    start_node_type = start_node.get("type")
    end_node_type = end_node.get("type")

    start_index_fields = datapoint_type_index_property.get(start_node_type, [])
    end_index_fields = datapoint_type_index_property.get(end_node_type, [])

    # A node type without index_fields is not fatal: that node simply contributes no text.
    if not start_index_fields:
        logger.debug(
            f"No index_fields found for start_node type '{start_node_type}' in triplet {offset + idx}"
        )
    if not end_index_fields:
        logger.debug(
            f"No index_fields found for end_node type '{end_node_type}' in triplet {offset + idx}"
        )

    start_node_id = start_node.get("id", "")
    end_node_id = end_node.get("id", "")

    if not start_node_id or not end_node_id:
        return None, (
            f"Skipping triplet at offset {offset + idx}: missing node IDs "
            f"(start: {start_node_id}, end: {end_node_id})"
        )

    relationship_text = _extract_relationship_text(relationship, datapoint_type_index_property)
    start_node_text = _extract_embeddable_text(start_node, start_index_fields)
    end_node_text = _extract_embeddable_text(end_node, end_index_fields)

    # Skip only when all three parts are empty; a partially empty triplet is still embedded.
    if not start_node_text and not end_node_text and not relationship_text:
        return None, (
            f"Skipping triplet at offset {offset + idx}: empty embeddable text "
            f"(start_node_id: {start_node_id}, end_node_id: {end_node_id})"
        )

    embeddable_text = f"{start_node_text}-›{relationship_text}-›{end_node_text}".strip()

    triplet_obj = Triplet(from_node_id=start_node_id, to_node_id=end_node_id, text=embeddable_text)

    return triplet_obj, None
|
||
|
||
|
||
async def get_triplet_datapoints(
    data,
    triplets_batch_size: int = 100,
) -> AsyncGenerator[Triplet, None]:
    """
    Async generator that yields Triplet datapoints built from the graph, one at a time.

    The graph is read page by page through the graph engine's
    ``get_triplets_batch(offset=..., limit=...)``. Every raw triplet of a page is
    converted with ``_process_single_triplet``; triplets that cannot be converted are
    logged and skipped. Reading stops at the first empty page or the first page that
    is shorter than ``triplets_batch_size``.

    Parameters:
    -----------

        - data: Only inspected to emit an informational log line when it is empty or
          equal to ``[{}]``; triplets are always read from the graph engine.
        - triplets_batch_size (int): Number of triplets to retrieve per page.
          Default is 100.

    Yields:
    -------

        - Triplet: One Triplet per successfully processed graph triplet (individual
          objects, not batches).

    Raises:
    -------

        - NotImplementedError: If the graph engine has no ``get_triplets_batch`` method.
        - Exception: Any error raised while retrieving a page is logged and re-raised.
    """
    if not data or data == [{}]:
        logger.info("Fetching graph data for current user")

    logger.info(f"Starting triplet datapoints extraction with batch size: {triplets_batch_size}")

    graph_engine = await get_graph_engine()
    graph_engine_type = type(graph_engine).__name__
    logger.debug(f"Using graph engine: {graph_engine_type}")

    if not hasattr(graph_engine, "get_triplets_batch"):
        error_msg = f"Graph adapter {graph_engine_type} does not support get_triplets_batch method"
        logger.error(error_msg)
        raise NotImplementedError(error_msg)

    datapoint_type_index_property = _build_datapoint_type_index_mapping()

    offset = 0
    total_triplets_processed = 0
    batch_number = 0

    while True:
        batch_number += 1
        logger.debug(
            f"Fetching triplet batch {batch_number} (offset: {offset}, limit: {triplets_batch_size})"
        )

        # Only the retrieval itself is guarded, so the "Error retrieving" message is
        # never attached to failures that happen elsewhere in the loop.
        try:
            triplets_batch = await graph_engine.get_triplets_batch(
                offset=offset, limit=triplets_batch_size
            )
        except Exception as e:
            logger.error(
                f"Error retrieving triplet batch {batch_number} at offset {offset}: {e}",
                exc_info=True,
            )
            raise

        if not triplets_batch:
            logger.info(f"No more triplets found at offset {offset}. Processing complete.")
            break

        logger.debug(f"Retrieved {len(triplets_batch)} triplets in batch {batch_number}")

        processed_count = 0
        skipped_count = 0

        for idx, triplet_datapoint in enumerate(triplets_batch):
            # Best-effort: one malformed triplet must not abort the whole extraction.
            try:
                triplet_obj, error_msg = _process_single_triplet(
                    triplet_datapoint, datapoint_type_index_property, offset, idx
                )
            except Exception as e:
                logger.warning(
                    f"Error processing triplet at offset {offset + idx}: {e}. "
                    f"Skipping this triplet and continuing."
                )
                skipped_count += 1
                continue

            if error_msg:
                logger.warning(error_msg)
                skipped_count += 1
                continue

            if triplet_obj:
                processed_count += 1
                # The yield is deliberately outside the try block above, so an exception
                # thrown into the generator at this point is not mistaken for a
                # processing error and silently swallowed.
                yield triplet_obj

        if skipped_count > 0:
            logger.warning(
                f"Skipped {skipped_count} out of {len(triplets_batch)} triplets in batch {batch_number}"
            )

        # A page shorter than the requested size means the graph is exhausted.
        is_last_batch = len(triplets_batch) < triplets_batch_size
        offset += len(triplets_batch)

        if processed_count:
            total_triplets_processed += processed_count
            logger.info(
                f"Batch {batch_number} complete: processed {processed_count} triplets "
                f"(total processed: {total_triplets_processed})"
            )
            if is_last_batch:
                logger.info(
                    f"Last batch retrieved (got {len(triplets_batch)} < {triplets_batch_size} triplets). "
                    f"Processing complete."
                )
        else:
            logger.warning(
                f"No valid triplet datapoints in batch {batch_number} after processing"
            )

        if is_last_batch:
            break

    logger.info(
        f"Triplet datapoints extraction complete. "
        f"Processed {total_triplets_processed} triplets across {batch_number} batch(es)."
    )
|