Merge branch 'dev' into multi-tenant-neo4j

Igor Ilic 2025-12-03 14:37:13 +01:00 committed by GitHub
commit 45f32f8bfd
19 changed files with 7521 additions and 5522 deletions


@@ -2005,3 +2005,134 @@ class KuzuAdapter(GraphDBInterface):
time_ids_list = [item[0] for item in time_nodes]
return ", ".join(f"'{uid}'" for uid in time_ids_list)
async def get_triplets_batch(self, offset: int, limit: int) -> list[dict[str, Any]]:
"""
Retrieve a batch of triplets (start_node, relationship, end_node) from the graph.
Parameters:
-----------
- offset (int): Number of triplets to skip before returning results.
- limit (int): Maximum number of triplets to return.
Returns:
--------
- list[dict[str, Any]]: A list of triplets, where each triplet is a dictionary
with keys: 'start_node', 'relationship_properties', 'end_node'.
Raises:
-------
- ValueError: If offset or limit are negative.
- Exception: Re-raises any exceptions from query execution.
"""
if offset < 0:
raise ValueError(f"Offset must be non-negative, got {offset}")
if limit < 0:
raise ValueError(f"Limit must be non-negative, got {limit}")
query = """
MATCH (start_node:Node)-[relationship:EDGE]->(end_node:Node)
RETURN {
start_node: {
id: start_node.id,
name: start_node.name,
type: start_node.type,
properties: start_node.properties
},
relationship_properties: {
relationship_name: relationship.relationship_name,
properties: relationship.properties
},
end_node: {
id: end_node.id,
name: end_node.name,
type: end_node.type,
properties: end_node.properties
}
} AS triplet
SKIP $offset LIMIT $limit
"""
try:
results = await self.query(query, {"offset": offset, "limit": limit})
except Exception as e:
logger.error(f"Failed to execute triplet query: {str(e)}")
logger.error(f"Query: {query}")
logger.error(f"Parameters: offset={offset}, limit={limit}")
raise
triplets = []
for idx, row in enumerate(results):
try:
if not row or len(row) == 0:
logger.warning(f"Skipping empty row at index {idx} in triplet batch")
continue
if not isinstance(row[0], dict):
logger.warning(
f"Skipping invalid row at index {idx}: expected dict, got {type(row[0])}"
)
continue
triplet = row[0]
if "start_node" not in triplet:
logger.warning(f"Skipping triplet at index {idx}: missing 'start_node' key")
continue
if not isinstance(triplet["start_node"], dict):
logger.warning(f"Skipping triplet at index {idx}: 'start_node' is not a dict")
continue
triplet["start_node"] = self._parse_node_properties(triplet["start_node"].copy())
if "relationship_properties" not in triplet:
logger.warning(
f"Skipping triplet at index {idx}: missing 'relationship_properties' key"
)
continue
if not isinstance(triplet["relationship_properties"], dict):
logger.warning(
f"Skipping triplet at index {idx}: 'relationship_properties' is not a dict"
)
continue
rel_props = triplet["relationship_properties"].copy()
relationship_name = rel_props.get("relationship_name") or ""
if rel_props.get("properties"):
try:
parsed_props = json.loads(rel_props["properties"])
if isinstance(parsed_props, dict):
rel_props.update(parsed_props)
del rel_props["properties"]
else:
logger.warning(
f"Parsed relationship properties is not a dict for triplet at index {idx}"
)
except (json.JSONDecodeError, TypeError) as e:
logger.warning(
f"Failed to parse relationship properties JSON for triplet at index {idx}: {e}"
)
rel_props["relationship_name"] = relationship_name
triplet["relationship_properties"] = rel_props
if "end_node" not in triplet:
logger.warning(f"Skipping triplet at index {idx}: missing 'end_node' key")
continue
if not isinstance(triplet["end_node"], dict):
logger.warning(f"Skipping triplet at index {idx}: 'end_node' is not a dict")
continue
triplet["end_node"] = self._parse_node_properties(triplet["end_node"].copy())
triplets.append(triplet)
except Exception as e:
logger.error(f"Error processing triplet at index {idx}: {e}", exc_info=True)
continue
return triplets
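For orientation, a minimal paging sketch for the new method, assuming an already-initialized KuzuAdapter instance is passed in (the helper name is illustrative, not part of this diff):

# Illustrative sketch only; `adapter` is assumed to be an initialized KuzuAdapter.
async def iterate_all_triplets(adapter, page_size: int = 100):
    """Yield every triplet from the graph, paging through get_triplets_batch."""
    offset = 0
    while True:
        batch = await adapter.get_triplets_batch(offset=offset, limit=page_size)
        if not batch:
            break
        for triplet in batch:
            # Each item is a dict with 'start_node', 'relationship_properties', 'end_node'.
            yield triplet
        offset += len(batch)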


@@ -8,7 +8,7 @@ from neo4j import AsyncSession
from neo4j import AsyncGraphDatabase
from neo4j.exceptions import Neo4jError
from contextlib import asynccontextmanager
from typing import Optional, Any, List, Dict, Type, Tuple
from typing import Optional, Any, List, Dict, Type, Tuple, Coroutine
from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
@@ -1527,3 +1527,25 @@ class Neo4jAdapter(GraphDBInterface):
time_ids_list = [item["id"] for item in time_nodes if "id" in item]
return ", ".join(f"'{uid}'" for uid in time_ids_list)
async def get_triplets_batch(self, offset: int, limit: int) -> list[dict[str, Any]]:
"""
Retrieve a batch of triplets (start_node, relationship, end_node) from the graph.
Parameters:
-----------
- offset (int): Number of triplets to skip before returning results.
- limit (int): Maximum number of triplets to return.
Returns:
--------
- list[dict[str, Any]]: A list of triplets.
"""
query = f"""
MATCH (start_node:`{BASE_LABEL}`)-[relationship]->(end_node:`{BASE_LABEL}`)
RETURN start_node, properties(relationship) AS relationship_properties, end_node
SKIP $offset LIMIT $limit
"""
results = await self.query(query, {"offset": offset, "limit": limit})
return results
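For context, downstream consumers (get_triplet_datapoints later in this diff) expect each returned row to look roughly like the sketch below; this shape is an assumption drawn from that consumer, not a guaranteed wire format:

# Rough shape of one returned row, as assumed by get_triplet_datapoints:
# {
#     "start_node": {"id": "...", "type": "Entity", "name": "...", ...},
#     "relationship_properties": {"relationship_name": "...", ...},
#     "end_node": {"id": "...", "type": "Entity", "name": "...", ...},
# }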


@@ -0,0 +1,53 @@
from typing import Any
from cognee import memify
from cognee.context_global_variables import (
set_database_global_context_variables,
)
from cognee.exceptions import CogneeValidationError
from cognee.modules.data.methods import get_authorized_existing_datasets
from cognee.shared.logging_utils import get_logger
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.users.models import User
from cognee.tasks.memify.get_triplet_datapoints import get_triplet_datapoints
from cognee.tasks.storage import index_data_points
logger = get_logger("create_triplet_embeddings")
async def create_triplet_embeddings(
user: User,
dataset: str = "main_dataset",
run_in_background: bool = False,
triplets_batch_size: int = 100,
) -> dict[str, Any]:
dataset_to_write = await get_authorized_existing_datasets(
user=user, datasets=[dataset], permission_type="write"
)
if not dataset_to_write:
raise CogneeValidationError(
message=f"User does not have write access to dataset: {dataset}",
log=False,
)
await set_database_global_context_variables(
dataset_to_write[0].id, dataset_to_write[0].owner_id
)
extraction_tasks = [Task(get_triplet_datapoints, triplets_batch_size=triplets_batch_size)]
enrichment_tasks = [
Task(index_data_points, task_config={"batch_size": triplets_batch_size}),
]
result = await memify(
extraction_tasks=extraction_tasks,
enrichment_tasks=enrichment_tasks,
dataset=dataset_to_write[0].id,
data=[{}],
user=user,
run_in_background=run_in_background,
)
return result
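A hedged invocation sketch, mirroring the example script at the end of this diff; the dataset name and batch size are illustrative:

# Sketch: build triplet embeddings for a dataset the calling user can write to.
from cognee.modules.users.methods import get_default_user
from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings

async def build_triplet_index():
    user = await get_default_user()
    return await create_triplet_embeddings(
        user=user, dataset="main_dataset", triplets_batch_size=100
    )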


@@ -0,0 +1,9 @@
from cognee.infrastructure.engine import DataPoint
class Triplet(DataPoint):
text: str
from_node_id: str
to_node_id: str
metadata: dict = {"index_fields": ["text"]}
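Since index_fields lists only "text", that field is what gets embedded into the Triplet_text collection. A small illustrative instance (all values fabricated):

# Illustrative only; node IDs and text are made up.
example_triplet = Triplet(
    from_node_id="node1",
    to_node_id="node2",
    text="Alice-knows-Bob",  # built as start_node_text-relationship_text-end_node_text
)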


@@ -7,3 +7,4 @@ from .ColumnValue import ColumnValue
from .Timestamp import Timestamp
from .Interval import Interval
from .Event import Event
from .Triplet import Triplet


@@ -0,0 +1,182 @@
import asyncio
from typing import Any, Optional, Type, List
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
from cognee.modules.retrieval.utils.session_cache import (
save_conversation_history,
get_conversation_history,
)
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.context_global_variables import session_user
from cognee.infrastructure.databases.cache.config import CacheConfig
logger = get_logger("TripletRetriever")
class TripletRetriever(BaseRetriever):
"""
Retriever for handling LLM-based completion searches using triplets.
Public methods:
- get_context(query: str) -> str
- get_completion(query: str, context: Optional[Any] = None) -> Any
"""
def __init__(
self,
user_prompt_path: str = "context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: Optional[int] = 5,
):
"""Initialize retriever with optional custom prompt paths."""
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
self.top_k = top_k if top_k is not None else 1
self.system_prompt = system_prompt
async def get_context(self, query: str) -> str:
"""
Retrieves relevant triplets as context.
Fetches triplets based on a query from a vector engine and combines their text.
Returns empty string if no triplets are found. Raises NoDataError if the collection is not
found.
Parameters:
-----------
- query (str): The query string used to search for relevant triplets.
Returns:
--------
- str: A string containing the combined text of the retrieved triplets, or an
empty string if none are found.
"""
vector_engine = get_vector_engine()
try:
if not await vector_engine.has_collection(collection_name="Triplet_text"):
logger.error("Triplet_text collection not found")
raise NoDataError(
"In order to use TRIPLET_COMPLETION first use the create_triplet_embeddings memify pipeline. "
)
found_triplets = await vector_engine.search("Triplet_text", query, limit=self.top_k)
if len(found_triplets) == 0:
return ""
triplets_payload = [found_triplet.payload["text"] for found_triplet in found_triplets]
combined_context = "\n".join(triplets_payload)
return combined_context
except CollectionNotFoundError as error:
logger.error("Triplet_text collection not found")
raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion(
self,
query: str,
context: Optional[Any] = None,
session_id: Optional[str] = None,
response_model: Type = str,
) -> List[Any]:
"""
Generates an LLM completion using the context.
Retrieves context if not provided and generates a completion based on the query and
context using an external completion generator.
Parameters:
-----------
- query (str): The query string to be used for generating a completion.
- context (Optional[Any]): Optional pre-fetched context to use for generating the
completion; if None, it retrieves the context for the query. (default None)
- session_id (Optional[str]): Optional session identifier for caching. If None,
defaults to 'default_session'. (default None)
- response_model (Type): The Pydantic model type for structured output. (default str)
Returns:
--------
- List[Any]: A single-element list containing the generated completion for the provided query and context.
"""
if context is None:
context = await self.get_context(query)
cache_config = CacheConfig()
user = session_user.get()
user_id = getattr(user, "id", None)
session_save = user_id and cache_config.caching
if session_save:
completion = await self._get_completion_with_session(
query=query,
context=context,
session_id=session_id,
response_model=response_model,
)
else:
completion = await self._get_completion_without_session(
query=query,
context=context,
response_model=response_model,
)
return [completion]
async def _get_completion_with_session(
self,
query: str,
context: str,
session_id: Optional[str],
response_model: Type,
) -> Any:
"""Generate completion with session history and caching."""
conversation_history = await get_conversation_history(session_id=session_id)
context_summary, completion = await asyncio.gather(
summarize_text(context),
generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
conversation_history=conversation_history,
response_model=response_model,
),
)
await save_conversation_history(
query=query,
context_summary=context_summary,
answer=completion,
session_id=session_id,
)
return completion
async def _get_completion_without_session(
self,
query: str,
context: str,
response_model: Type,
) -> Any:
"""Generate completion without session history."""
completion = await generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
response_model=response_model,
)
return completion
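A minimal usage sketch, assuming the Triplet_text collection already exists (i.e. the create_triplet_embeddings pipeline has run); the question text is illustrative:

# Sketch: query triplet context and a completion directly through the retriever.
from cognee.modules.retrieval.triplet_retriever import TripletRetriever

async def ask_with_triplets(question: str):
    retriever = TripletRetriever(top_k=5)
    context = await retriever.get_context(question)    # newline-joined triplet texts, or ""
    answers = await retriever.get_completion(question)  # one-element list with the completion
    return context, answers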


@@ -2,6 +2,7 @@ import os
from typing import Callable, List, Optional, Type
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
from cognee.modules.search.types import SearchType
from cognee.modules.search.operations import select_search_type
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
@@ -61,6 +62,18 @@ async def get_search_type_tools(
system_prompt=system_prompt,
).get_context,
],
SearchType.TRIPLET_COMPLETION: [
TripletRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
system_prompt=system_prompt,
).get_completion,
TripletRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
system_prompt=system_prompt,
).get_context,
],
SearchType.GRAPH_COMPLETION: [
GraphCompletionRetriever(
system_prompt_path=system_prompt_path,


@@ -5,6 +5,7 @@ class SearchType(Enum):
SUMMARIES = "SUMMARIES"
CHUNKS = "CHUNKS"
RAG_COMPLETION = "RAG_COMPLETION"
TRIPLET_COMPLETION = "TRIPLET_COMPLETION"
GRAPH_COMPLETION = "GRAPH_COMPLETION"
GRAPH_SUMMARY_COMPLETION = "GRAPH_SUMMARY_COMPLETION"
CODE = "CODE"
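The new enum value is exercised through the public search API, as in the example script at the end of this diff; a hedged sketch:

# Sketch: run a triplet-based completion search (mirrors the example script below).
import cognee
from cognee.modules.search.types import SearchType

async def triplet_search(question: str):
    return await cognee.search(
        query_type=SearchType.TRIPLET_COMPLETION,
        query_text=question,
    )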


@@ -0,0 +1,283 @@
from typing import AsyncGenerator, Dict, Any, List, Optional
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.shared.logging_utils import get_logger
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.models import Triplet
from cognee.tasks.storage import index_data_points
logger = get_logger("get_triplet_datapoints")
def _build_datapoint_type_index_mapping() -> Dict[str, List[str]]:
"""
Build a mapping of DataPoint type names to their index_fields.
Returns:
--------
- Dict[str, List[str]]: Mapping of type name to list of index field names
"""
logger.debug("Building DataPoint type to index_fields mapping")
subclasses = get_all_subclasses(DataPoint)
datapoint_type_index_property = {}
for subclass in subclasses:
if "metadata" in subclass.model_fields:
metadata_field = subclass.model_fields["metadata"]
default = getattr(metadata_field, "default", None)
if isinstance(default, dict):
index_fields = default.get("index_fields", [])
if index_fields:
datapoint_type_index_property[subclass.__name__] = index_fields
logger.debug(
f"Registered {subclass.__name__} with index_fields: {index_fields}"
)
logger.info(
f"Found {len(datapoint_type_index_property)} DataPoint types with index_fields: "
f"{list(datapoint_type_index_property.keys())}"
)
return datapoint_type_index_property
def _extract_embeddable_text(node_or_edge: Dict[str, Any], index_fields: List[str]) -> str:
"""
Extract and concatenate embeddable properties from a node or edge dictionary.
Parameters:
-----------
- node_or_edge (Dict[str, Any]): Dictionary containing node or edge properties.
- index_fields (List[str]): List of field names to extract and concatenate.
Returns:
--------
- str: Concatenated string of all embeddable property values, or empty string if none found.
"""
if not node_or_edge or not index_fields:
return ""
embeddable_values = []
for field_name in index_fields:
field_value = node_or_edge.get(field_name)
if field_value is not None:
field_value = str(field_value).strip()
if field_value:
embeddable_values.append(field_value)
return " ".join(embeddable_values) if embeddable_values else ""
def _extract_relationship_text(
relationship: Dict[str, Any], datapoint_type_index_property: Dict[str, List[str]]
) -> str:
"""
Extract relationship text from edge properties.
Parameters:
-----------
- relationship (Dict[str, Any]): Dictionary containing relationship properties
- datapoint_type_index_property (Dict[str, List[str]]): Mapping of type to index fields
Returns:
--------
- str: Extracted relationship text or empty string
"""
if not relationship:
return ""
edge_text = relationship.get("edge_text")
if edge_text and isinstance(edge_text, str) and edge_text.strip():
return edge_text.strip()
# Fallback to extracting from EdgeType index_fields
edge_type_index_fields = datapoint_type_index_property.get("EdgeType", [])
return _extract_embeddable_text(relationship, edge_type_index_fields)
def _process_single_triplet(
triplet_datapoint: Dict[str, Any],
datapoint_type_index_property: Dict[str, List[str]],
offset: int,
idx: int,
) -> tuple[Optional[Triplet], Optional[str]]:
"""
Process a single triplet and create a Triplet object.
Parameters:
-----------
- triplet_datapoint (Dict[str, Any]): Raw triplet data from graph engine
- datapoint_type_index_property (Dict[str, List[str]]): Type to index fields mapping
- offset (int): Current batch offset
- idx (int): Index within current batch
Returns:
--------
- tuple[Optional[Triplet], Optional[str]]: (Triplet object, error message if skipped)
"""
start_node = triplet_datapoint.get("start_node", {})
end_node = triplet_datapoint.get("end_node", {})
relationship = triplet_datapoint.get("relationship_properties", {})
start_node_type = start_node.get("type")
end_node_type = end_node.get("type")
start_index_fields = datapoint_type_index_property.get(start_node_type, [])
end_index_fields = datapoint_type_index_property.get(end_node_type, [])
if not start_index_fields:
logger.debug(
f"No index_fields found for start_node type '{start_node_type}' in triplet {offset + idx}"
)
if not end_index_fields:
logger.debug(
f"No index_fields found for end_node type '{end_node_type}' in triplet {offset + idx}"
)
start_node_id = start_node.get("id", "")
end_node_id = end_node.get("id", "")
if not start_node_id or not end_node_id:
return None, (
f"Skipping triplet at offset {offset + idx}: missing node IDs "
f"(start: {start_node_id}, end: {end_node_id})"
)
relationship_text = _extract_relationship_text(relationship, datapoint_type_index_property)
start_node_text = _extract_embeddable_text(start_node, start_index_fields)
end_node_text = _extract_embeddable_text(end_node, end_index_fields)
if not start_node_text and not end_node_text and not relationship_text:
return None, (
f"Skipping triplet at offset {offset + idx}: empty embeddable text "
f"(start_node_id: {start_node_id}, end_node_id: {end_node_id})"
)
embeddable_text = f"{start_node_text}-{relationship_text}-{end_node_text}".strip()
triplet_obj = Triplet(from_node_id=start_node_id, to_node_id=end_node_id, text=embeddable_text)
return triplet_obj, None
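# Worked example (fabricated values, assuming Entity index_fields include "name" and
# EdgeType index_fields include "relationship_name"):
#   start_node = {"id": "node1", "type": "Entity", "name": "Alice"}
#   relationship_properties = {"relationship_name": "knows"}
#   end_node = {"id": "node2", "type": "Entity", "name": "Bob"}
# yields Triplet(from_node_id="node1", to_node_id="node2", text="Alice-knows-Bob").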
async def get_triplet_datapoints(
data,
triplets_batch_size: int = 100,
) -> AsyncGenerator[Triplet, None]:
"""
Async generator that yields Triplet datapoints with embeddable text extracted.
Triplets are fetched from the graph engine in batches; each yielded Triplet carries:
- from_node_id and to_node_id taken from the original graph triplet (start_node, relationship_properties, end_node)
- text built from the index_fields of the start node, relationship, and end node
Parameters:
-----------
- data: Pipeline input placeholder (typically [{}]); the graph itself is the data source.
- triplets_batch_size (int): Number of triplets to retrieve per batch. Default is 100.
Yields:
-------
- Triplet: A triplet datapoint enriched with embeddable text.
"""
if not data or data == [{}]:
logger.info("Fetching graph data for current user")
logger.info(f"Starting triplet datapoints extraction with batch size: {triplets_batch_size}")
graph_engine = await get_graph_engine()
graph_engine_type = type(graph_engine).__name__
logger.debug(f"Using graph engine: {graph_engine_type}")
if not hasattr(graph_engine, "get_triplets_batch"):
error_msg = f"Graph adapter {graph_engine_type} does not support get_triplets_batch method"
logger.error(error_msg)
raise NotImplementedError(error_msg)
datapoint_type_index_property = _build_datapoint_type_index_mapping()
offset = 0
total_triplets_processed = 0
batch_number = 0
while True:
try:
batch_number += 1
logger.debug(
f"Fetching triplet batch {batch_number} (offset: {offset}, limit: {triplets_batch_size})"
)
triplets_batch = await graph_engine.get_triplets_batch(
offset=offset, limit=triplets_batch_size
)
if not triplets_batch:
logger.info(f"No more triplets found at offset {offset}. Processing complete.")
break
logger.debug(f"Retrieved {len(triplets_batch)} triplets in batch {batch_number}")
triplet_datapoints = []
skipped_count = 0
for idx, triplet_datapoint in enumerate(triplets_batch):
try:
triplet_obj, error_msg = _process_single_triplet(
triplet_datapoint, datapoint_type_index_property, offset, idx
)
if error_msg:
logger.warning(error_msg)
skipped_count += 1
continue
if triplet_obj:
triplet_datapoints.append(triplet_obj)
yield triplet_obj
except Exception as e:
logger.warning(
f"Error processing triplet at offset {offset + idx}: {e}. "
f"Skipping this triplet and continuing."
)
skipped_count += 1
continue
if skipped_count > 0:
logger.warning(
f"Skipped {skipped_count} out of {len(triplets_batch)} triplets in batch {batch_number}"
)
if not triplet_datapoints:
logger.warning(
f"No valid triplet datapoints in batch {batch_number} after processing"
)
offset += len(triplets_batch)
if len(triplets_batch) < triplets_batch_size:
break
continue
total_triplets_processed += len(triplet_datapoints)
logger.info(
f"Batch {batch_number} complete: processed {len(triplet_datapoints)} triplets "
f"(total processed: {total_triplets_processed})"
)
offset += len(triplets_batch)
if len(triplets_batch) < triplets_batch_size:
logger.info(
f"Last batch retrieved (got {len(triplets_batch)} < {triplets_batch_size} triplets). "
f"Processing complete."
)
break
except Exception as e:
logger.error(
f"Error retrieving triplet batch {batch_number} at offset {offset}: {e}",
exc_info=True,
)
raise
logger.info(
f"Triplet datapoints extraction complete. "
f"Processed {total_triplets_processed} triplets across {batch_number} batch(es)."
)


@@ -0,0 +1,84 @@
import os
import pytest
import pathlib
import pytest_asyncio
import cognee
from cognee.low_level import setup
from cognee.tasks.storage import add_data_points
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
from cognee.modules.engine.models import Triplet
@pytest_asyncio.fixture
async def setup_test_environment_with_triplets():
"""Set up a clean test environment with triplets."""
base_dir = pathlib.Path(__file__).parent.parent.parent.parent
system_directory_path = str(base_dir / ".cognee_system/test_triplet_retriever_context_simple")
data_directory_path = str(base_dir / ".data_storage/test_triplet_retriever_context_simple")
cognee.config.system_root_directory(system_directory_path)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
triplet1 = Triplet(
from_node_id="node1",
to_node_id="node2",
text="Alice knows Bob",
)
triplet2 = Triplet(
from_node_id="node2",
to_node_id="node3",
text="Bob works at Tech Corp",
)
triplets = [triplet1, triplet2]
await add_data_points(triplets)
yield
try:
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
except Exception:
pass
@pytest_asyncio.fixture
async def setup_test_environment_empty():
"""Set up a clean test environment without triplets."""
base_dir = pathlib.Path(__file__).parent.parent.parent.parent
system_directory_path = str(
base_dir / ".cognee_system/test_triplet_retriever_context_empty_collection"
)
data_directory_path = str(
base_dir / ".data_storage/test_triplet_retriever_context_empty_collection"
)
cognee.config.system_root_directory(system_directory_path)
cognee.config.data_root_directory(data_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
yield
try:
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
except Exception:
pass
@pytest.mark.asyncio
async def test_triplet_retriever_context_simple(setup_test_environment_with_triplets):
"""Integration test: verify TripletRetriever can retrieve triplet context."""
retriever = TripletRetriever(top_k=5)
context = await retriever.get_context("Alice")
assert "Alice knows Bob" in context, "Failed to get Alice triplet"


@@ -0,0 +1,69 @@
import os
import pathlib
import pytest
import pytest_asyncio
from unittest.mock import AsyncMock, patch
import cognee
from cognee.tasks.memify.get_triplet_datapoints import get_triplet_datapoints
from cognee.modules.engine.models import Triplet
@pytest_asyncio.fixture
async def setup_test_environment():
"""Set up a clean test environment with a simple graph."""
base_dir = pathlib.Path(__file__).parent.parent.parent.parent
data_directory_path = str(base_dir / ".data_storage/test_get_triplet_datapoints_integration")
cognee_directory_path = str(base_dir / ".cognee_system/test_get_triplet_datapoints_integration")
cognee.config.data_root_directory(data_directory_path)
cognee.config.system_root_directory(cognee_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
dataset_name = "test_triplets"
text = "Volkswagen is a german car manufacturer from Wolfsburg. They produce different models such as Golf, Polo and Touareg."
await cognee.add(text, dataset_name)
await cognee.cognify([dataset_name])
yield dataset_name
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
@pytest.mark.asyncio
async def test_get_triplet_datapoints_integration(setup_test_environment):
"""Integration test: verify get_triplet_datapoints works with real graph data."""
from cognee.infrastructure.databases.graph import get_graph_engine
graph_engine = await get_graph_engine()
if not hasattr(graph_engine, "get_triplets_batch"):
pytest.skip("Graph engine does not support get_triplets_batch")
triplets = []
with patch(
"cognee.tasks.memify.get_triplet_datapoints.index_data_points", new_callable=AsyncMock
):
async for triplet in get_triplet_datapoints([{}], triplets_batch_size=10):
triplets.append(triplet)
nodes, edges = await graph_engine.get_graph_data()
if len(edges) > 0 and len(triplets) == 0:
test_triplets = await graph_engine.get_triplets_batch(offset=0, limit=10)
if len(test_triplets) == 0:
pytest.fail(
f"Edges exist in graph ({len(edges)} edges) but get_triplets_batch found none. "
f"This indicates the query pattern may not match the graph structure."
)
for triplet in triplets:
assert isinstance(triplet, Triplet), "Each item should be a Triplet instance"
assert triplet.from_node_id, "Triplet should have from_node_id"
assert triplet.to_node_id, "Triplet should have to_node_id"
assert triplet.text, "Triplet should have embeddable text"


@@ -8,10 +8,10 @@ Tests all retrievers that save conversation history to Redis cache:
4. GRAPH_COMPLETION_CONTEXT_EXTENSION
5. GRAPH_SUMMARY_COMPLETION
6. TEMPORAL
7. TRIPLET_COMPLETION
"""
import os
import shutil
import cognee
import pathlib
@@ -63,6 +63,10 @@ async def main():
user = await get_default_user()
from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings
await create_triplet_embeddings(user=user, dataset=dataset_name)
cache_engine = get_cache_engine()
assert cache_engine is not None, "Cache engine should be available for testing"
@@ -216,6 +220,24 @@ async def main():
]
assert len(our_qa_temporal) == 1, "Should find Temporal question in history"
session_id_triplet = "test_session_triplet"
result_triplet = await cognee.search(
query_type=SearchType.TRIPLET_COMPLETION,
query_text="What companies are mentioned?",
session_id=session_id_triplet,
)
assert isinstance(result_triplet, list) and len(result_triplet) > 0, (
f"TRIPLET_COMPLETION should return non-empty list, got: {result_triplet!r}"
)
history_triplet = await cache_engine.get_latest_qa(str(user.id), session_id_triplet, last_n=10)
our_qa_triplet = [
h for h in history_triplet if h["question"] == "What companies are mentioned?"
]
assert len(our_qa_triplet) == 1, "Should find Triplet question in history"
from cognee.modules.retrieval.utils.session_cache import (
get_conversation_history,
)


@@ -2,6 +2,7 @@ import pathlib
import os
import cognee
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@@ -12,8 +13,10 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
)
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import get_default_user
from collections import Counter
logger = get_logger()
@@ -37,6 +40,23 @@ async def main():
await cognee.cognify([dataset_name])
user = await get_default_user()
from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings
await create_triplet_embeddings(user=user, dataset=dataset_name, triplets_batch_size=5)
graph_engine = await get_graph_engine()
nodes, edges = await graph_engine.get_graph_data()
vector_engine = get_vector_engine()
collection = await vector_engine.search(
query_text="Test", limit=None, collection_name="Triplet_text"
)
assert len(edges) == len(collection), (
f"Expected {len(edges)} edges but got {len(collection)} in Triplet_text collection"
)
context_gk = await GraphCompletionRetriever().get_context(
query="Next to which country is Germany located?"
)
@@ -49,6 +69,9 @@ async def main():
context_gk_sum = await GraphSummaryCompletionRetriever().get_context(
query="Next to which country is Germany located?"
)
context_triplet = await TripletRetriever().get_context(
query="Next to which country is Germany located?"
)
for name, context in [
("GraphCompletionRetriever", context_gk),
@@ -65,6 +88,13 @@ async def main():
f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
)
assert isinstance(context_triplet, str), "TripletRetriever: Context should be a string"
assert len(context_triplet) > 0, "TripletRetriever: Context should not be empty"
lower_triplet = context_triplet.lower()
assert "germany" in lower_triplet or "netherlands" in lower_triplet, (
f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}"
)
triplets_gk = await GraphCompletionRetriever().get_triplets(
query="Next to which country is Germany located?"
)
@@ -129,6 +159,11 @@ async def main():
query_text="Next to which country is Germany located?",
save_interaction=True,
)
completion_triplet = await cognee.search(
query_type=SearchType.TRIPLET_COMPLETION,
query_text="Next to which country is Germany located?",
save_interaction=True,
)
await cognee.search(
query_type=SearchType.FEEDBACK,
@@ -141,6 +176,7 @@ async def main():
("GRAPH_COMPLETION_COT", completion_cot),
("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext),
("GRAPH_SUMMARY_COMPLETION", completion_sum),
("TRIPLET_COMPLETION", completion_triplet),
]:
assert isinstance(search_results, list), f"{name}: should return a list"
assert len(search_results) == 1, (
@@ -168,7 +204,7 @@ async def main():
# Assert there are exactly 4 CogneeUserInteraction nodes.
assert type_counts.get("CogneeUserInteraction", 0) == 4, (
f"Expected exactly four DCogneeUserInteraction nodes, but found {type_counts.get('CogneeUserInteraction', 0)}"
f"Expected exactly four CogneeUserInteraction nodes, but found {type_counts.get('CogneeUserInteraction', 0)}"
)
# Assert there is exactly two CogneeUserFeedback nodes.


@@ -0,0 +1,214 @@
import sys
import pytest
from unittest.mock import AsyncMock, patch
from cognee.tasks.memify.get_triplet_datapoints import get_triplet_datapoints
from cognee.modules.engine.models import Triplet
from cognee.modules.engine.models.Entity import Entity
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.models.EdgeType import EdgeType
get_triplet_datapoints_module = sys.modules["cognee.tasks.memify.get_triplet_datapoints"]
@pytest.fixture
def mock_graph_engine():
"""Create a mock graph engine with get_triplets_batch method."""
engine = AsyncMock()
engine.get_triplets_batch = AsyncMock()
return engine
@pytest.mark.asyncio
async def test_get_triplet_datapoints_success(mock_graph_engine):
"""Test successful extraction of triplet datapoints."""
mock_triplets_batch = [
{
"start_node": {
"id": "node1",
"type": "Entity",
"name": "Alice",
"description": "A person",
},
"end_node": {
"id": "node2",
"type": "Entity",
"name": "Bob",
"description": "Another person",
},
"relationship_properties": {
"relationship_name": "knows",
},
}
]
mock_graph_engine.get_triplets_batch.return_value = mock_triplets_batch
with (
patch.object(
get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
),
patch.object(get_triplet_datapoints_module, "get_all_subclasses") as mock_get_subclasses,
):
mock_get_subclasses.return_value = [Triplet, EdgeType, Entity]
triplets = []
async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
triplets.append(triplet)
assert len(triplets) == 1
assert isinstance(triplets[0], Triplet)
assert triplets[0].from_node_id == "node1"
assert triplets[0].to_node_id == "node2"
assert "Alice" in triplets[0].text
assert "knows" in triplets[0].text
assert "Bob" in triplets[0].text
@pytest.mark.asyncio
async def test_get_triplet_datapoints_edge_text_priority_and_fallback(mock_graph_engine):
"""Test that edge_text is prioritized over relationship_name, and fallback works."""
class MockEntity(DataPoint):
name: str
metadata: dict = {"index_fields": ["name"]}
mock_triplets_batch = [
{
"start_node": {"id": "node1", "type": "Entity", "name": "Alice"},
"end_node": {"id": "node2", "type": "Entity", "name": "Bob"},
"relationship_properties": {
"relationship_name": "knows",
"edge_text": "has a close friendship with",
},
},
{
"start_node": {"id": "node3", "type": "Entity", "name": "Charlie"},
"end_node": {"id": "node4", "type": "Entity", "name": "Diana"},
"relationship_properties": {
"relationship_name": "works_with",
},
},
]
mock_graph_engine.get_triplets_batch.return_value = mock_triplets_batch
with (
patch.object(
get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
),
patch.object(get_triplet_datapoints_module, "get_all_subclasses") as mock_get_subclasses,
):
mock_get_subclasses.return_value = [Triplet, EdgeType, MockEntity]
triplets = []
async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
triplets.append(triplet)
assert len(triplets) == 2
assert "has a close friendship with" in triplets[0].text
assert "knows" not in triplets[0].text
assert "works_with" in triplets[1].text
@pytest.mark.asyncio
async def test_get_triplet_datapoints_skips_missing_node_ids(mock_graph_engine):
"""Test that triplets with missing node IDs are skipped."""
class MockEntity(DataPoint):
name: str
metadata: dict = {"index_fields": ["name"]}
mock_triplets_batch = [
{
"start_node": {"id": "", "type": "Entity", "name": "Alice"},
"end_node": {"id": "node2", "type": "Entity", "name": "Bob"},
"relationship_properties": {"relationship_name": "knows"},
},
{
"start_node": {"id": "node3", "type": "Entity", "name": "Charlie"},
"end_node": {"id": "node4", "type": "Entity", "name": "Diana"},
"relationship_properties": {"relationship_name": "works_with"},
},
]
mock_graph_engine.get_triplets_batch.return_value = mock_triplets_batch
with (
patch.object(
get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
),
patch.object(get_triplet_datapoints_module, "get_all_subclasses") as mock_get_subclasses,
):
mock_get_subclasses.return_value = [Triplet, EdgeType, MockEntity]
triplets = []
async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
triplets.append(triplet)
assert len(triplets) == 1
assert triplets[0].from_node_id == "node3"
@pytest.mark.asyncio
async def test_get_triplet_datapoints_error_handling(mock_graph_engine):
"""Test that errors are handled correctly - invalid data is skipped, query errors propagate."""
class MockEntity(DataPoint):
name: str
metadata: dict = {"index_fields": ["name"]}
mock_triplets_batch = [
{
"start_node": {"id": "node1", "type": "Entity", "name": "Alice"},
"end_node": {"id": "node2", "type": "Entity", "name": "Bob"},
"relationship_properties": {"relationship_name": "knows"},
},
{
"start_node": None,
"end_node": {"id": "node4", "type": "Entity", "name": "Diana"},
"relationship_properties": {"relationship_name": "works_with"},
},
]
mock_graph_engine.get_triplets_batch.return_value = mock_triplets_batch
with (
patch.object(
get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
),
patch.object(get_triplet_datapoints_module, "get_all_subclasses") as mock_get_subclasses,
):
mock_get_subclasses.return_value = [Triplet, EdgeType, MockEntity]
triplets = []
async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
triplets.append(triplet)
assert len(triplets) == 1
assert triplets[0].from_node_id == "node1"
mock_graph_engine.get_triplets_batch.side_effect = Exception("Database connection error")
with patch.object(
get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
):
triplets = []
with pytest.raises(Exception, match="Database connection error"):
async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
triplets.append(triplet)
@pytest.mark.asyncio
async def test_get_triplet_datapoints_no_get_triplets_batch_method(mock_graph_engine):
"""Test that NotImplementedError is raised when graph engine lacks get_triplets_batch."""
del mock_graph_engine.get_triplets_batch
with patch.object(
get_triplet_datapoints_module, "get_graph_engine", return_value=mock_graph_engine
):
triplets = []
with pytest.raises(NotImplementedError, match="does not support get_triplets_batch"):
async for triplet in get_triplet_datapoints([{}], triplets_batch_size=100):
triplets.append(triplet)


@@ -0,0 +1,83 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
@pytest.fixture
def mock_vector_engine():
"""Create a mock vector engine."""
engine = AsyncMock()
engine.has_collection = AsyncMock(return_value=True)
engine.search = AsyncMock()
return engine
@pytest.mark.asyncio
async def test_get_context_success(mock_vector_engine):
"""Test successful retrieval of triplet context."""
mock_result1 = MagicMock()
mock_result1.payload = {"text": "Alice knows Bob"}
mock_result2 = MagicMock()
mock_result2.payload = {"text": "Bob works at Tech Corp"}
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
retriever = TripletRetriever(top_k=5)
with patch(
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
assert context == "Alice knows Bob\nBob works at Tech Corp"
mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5)
@pytest.mark.asyncio
async def test_get_context_no_collection(mock_vector_engine):
"""Test that NoDataError is raised when Triplet_text collection doesn't exist."""
mock_vector_engine.has_collection.return_value = False
retriever = TripletRetriever()
with patch(
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
with pytest.raises(NoDataError, match="create_triplet_embeddings"):
await retriever.get_context("test query")
@pytest.mark.asyncio
async def test_get_context_empty_results(mock_vector_engine):
"""Test that empty string is returned when no triplets are found."""
mock_vector_engine.search.return_value = []
retriever = TripletRetriever()
with patch(
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
context = await retriever.get_context("test query")
assert context == ""
@pytest.mark.asyncio
async def test_get_context_collection_not_found_error(mock_vector_engine):
"""Test that CollectionNotFoundError is converted to NoDataError."""
mock_vector_engine.has_collection.side_effect = CollectionNotFoundError("Collection not found")
retriever = TripletRetriever()
with patch(
"cognee.modules.retrieval.triplet_retriever.get_vector_engine",
return_value=mock_vector_engine,
):
with pytest.raises(NoDataError, match="No data found"):
await retriever.get_context("test query")


@@ -0,0 +1,79 @@
import asyncio
import cognee
from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings
from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import setup_logging, INFO
from cognee.modules.engine.operations.setup import setup
text_1 = """
1. Audi
Audi is known for its modern designs and advanced technology. Founded in the early 1900s, the brand has earned a reputation for precision engineering and innovation. With features like the Quattro all-wheel-drive system, Audi offers a range of vehicles from stylish sedans to high-performance sports cars.
2. BMW
BMW, short for Bayerische Motoren Werke, is celebrated for its focus on performance and driving pleasure. The company's vehicles are designed to provide a dynamic and engaging driving experience, and their slogan, "The Ultimate Driving Machine," reflects that commitment. BMW produces a variety of cars that combine luxury with sporty performance.
3. Mercedes-Benz
Mercedes-Benz is synonymous with luxury and quality. With a history dating back to the early 20th century, the brand is known for its elegant designs, innovative safety features, and high-quality engineering. Mercedes-Benz manufactures not only luxury sedans but also SUVs, sports cars, and commercial vehicles, catering to a wide range of needs.
4. Porsche
Porsche is a name that stands for high-performance sports cars. Founded in 1931, the brand has become famous for models like the iconic Porsche 911. Porsche cars are celebrated for their speed, precision, and distinctive design, appealing to car enthusiasts who value both performance and style.
5. Volkswagen
Volkswagen, which means "people's car" in German, was established with the idea of making affordable and reliable vehicles accessible to everyone. Over the years, Volkswagen has produced several iconic models, such as the Beetle and the Golf. Today, it remains one of the largest car manufacturers in the world, offering a wide range of vehicles that balance practicality with quality.
Each of these car manufacturers contributes to Germany's reputation as a leader in the global automotive industry, showcasing a blend of innovation, performance, and design excellence.
"""
text_2 = """
1. Apple
Apple is renowned for its innovative consumer electronics and software. Its product lineup includes the iPhone, iPad, Mac computers, and wearables like the Apple Watch. Known for its emphasis on sleek design and user-friendly interfaces, Apple has built a loyal customer base and created a seamless ecosystem that integrates hardware, software, and services.
2. Google
Founded in 1998, Google started as a search engine and quickly became the go-to resource for finding information online. Over the years, the company has diversified its offerings to include digital advertising, cloud computing, mobile operating systems (Android), and various web services like Gmail and Google Maps. Google's innovations have played a major role in shaping the internet landscape.
3. Microsoft
Microsoft Corporation has been a dominant force in software for decades. Its Windows operating system and Microsoft Office suite are staples in both business and personal computing. In recent years, Microsoft has expanded into cloud computing with Azure, gaming with the Xbox platform, and even hardware through products like the Surface line. This evolution has helped the company maintain its relevance in a rapidly changing tech world.
4. Amazon
What began as an online bookstore has grown into one of the largest e-commerce platforms globally. Amazon is known for its vast online marketplace, but its influence extends far beyond retail. With Amazon Web Services (AWS), the company has become a leader in cloud computing, offering robust solutions that power websites, applications, and businesses around the world. Amazon's constant drive for innovation continues to reshape both retail and technology sectors.
5. Meta
Meta, originally known as Facebook, revolutionized social media by connecting billions of people worldwide. Beyond its core social networking service, Meta is investing in the next generation of digital experiences through virtual and augmented reality technologies, with projects like Oculus. The company's efforts signal a commitment to evolving digital interaction and building the metaverse—a shared virtual space where users can connect and collaborate.
Each of these companies has significantly impacted the technology landscape, driving innovation and transforming everyday life through their groundbreaking products and services.
"""
async def main():
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
await cognee.add([text_1, text_2])
await cognee.cognify()
default_user = await get_default_user()
await create_triplet_embeddings(
user=default_user,
triplets_batch_size=100,
)
search_results = await cognee.search(
query_type=SearchType.TRIPLET_COMPLETION,
query_text="What are the models produced by Volkswagen based on the context?",
)
print(search_results)
if __name__ == "__main__":
logger = setup_logging(log_level=INFO)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
finally:
loop.run_until_complete(loop.shutdown_asyncgens())

poetry.lock (generated, 8047 changed lines)
File diff suppressed because it is too large


@@ -22,7 +22,7 @@ classifiers = [
dependencies = [
"openai>=1.80.1",
"python-dotenv>=1.0.1,<2.0.0",
"pydantic>=2.10.5,<3.0.0",
"pydantic>=2.10.5,<2.12.0",
"pydantic-settings>=2.2.1,<3",
"typing_extensions>=4.12.2,<5.0.0",
"numpy>=1.26.4, <=4.0.0",
@@ -33,7 +33,7 @@ dependencies = [
"instructor>=1.9.1,<2.0.0",
"filetype>=1.2.0,<2.0.0",
"aiohttp>=3.11.14,<4.0.0",
"aiofiles>=23.2.1,<24.0.0",
"aiofiles>=23.2.1",
"rdflib>=7.1.4,<7.2.0",
"pypdf>=4.1.0,<7.0.0",
"jinja2>=3.1.3,<4",
@@ -199,8 +199,3 @@ exclude = [
[tool.ruff.lint]
ignore = ["F401"]
[dependency-groups]
dev = [
"pytest-timeout>=2.4.0",
]

uv.lock (generated, 3699 changed lines)
File diff suppressed because it is too large