Merge branch 'dev' into issue-1767

This commit is contained in:
Rajeev Rajeshuni 2025-11-28 17:24:55 +05:30 committed by GitHub
commit 57195fb5a1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 1482 additions and 70 deletions

View file

@ -31,6 +31,8 @@ async def search(
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[List[SearchResult], CombinedSearchResult]:
"""
Search and query the knowledge graph for insights, information, and connections.
@ -200,6 +202,8 @@ async def search(
only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
return filtered_search_results

View file

@ -0,0 +1,29 @@
FROM python:3.11-slim
# Set environment variables
ENV PIP_NO_CACHE_DIR=true
ENV PATH="${PATH}:/root/.poetry/bin"
ENV PYTHONPATH=/app
ENV SKIP_MIGRATIONS=true
# System dependencies
RUN apt-get update && apt-get install -y \
gcc \
libpq-dev \
git \
curl \
build-essential \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
COPY pyproject.toml poetry.lock README.md /app/
RUN pip install poetry
RUN poetry config virtualenvs.create false
RUN poetry install --extras distributed --extras evals --extras deepeval --no-root
COPY cognee/ /app/cognee
COPY distributed/ /app/distributed

View file

@ -35,6 +35,16 @@ class AnswerGeneratorExecutor:
retrieval_context = await retriever.get_context(query_text)
search_results = await retriever.get_completion(query_text, retrieval_context)
############
#:TODO This is a quick fix until we don't structure retriever results properly but lets not leave it like this...this is needed now due to the changed combined retriever structure..
if isinstance(retrieval_context, list):
retrieval_context = await retriever.convert_retrieved_objects_to_context(
triplets=retrieval_context
)
if isinstance(search_results, str):
search_results = [search_results]
#############
answer = {
"question": query_text,
"answer": search_results[0],

View file

@ -35,7 +35,7 @@ async def create_and_insert_answers_table(questions_payload):
async def run_question_answering(
params: dict, system_prompt="answer_simple_question.txt", top_k: Optional[int] = None
params: dict, system_prompt="answer_simple_question_benchmark.txt", top_k: Optional[int] = None
) -> List[dict]:
if params.get("answering_questions"):
logger.info("Question answering started...")

View file

@ -14,7 +14,7 @@ class EvalConfig(BaseSettings):
# Question answering params
answering_questions: bool = True
qa_engine: str = "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
qa_engine: str = "cognee_graph_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
# Evaluation params
evaluating_answers: bool = True
@ -25,7 +25,7 @@ class EvalConfig(BaseSettings):
"EM",
"f1",
] # Use only 'correctness' for DirectLLM
deepeval_model: str = "gpt-5-mini"
deepeval_model: str = "gpt-4o-mini"
# Metrics params
calculate_metrics: bool = True

View file

@ -2,7 +2,6 @@ import modal
import os
import asyncio
import datetime
import hashlib
import json
from cognee.shared.logging_utils import get_logger
from cognee.eval_framework.eval_config import EvalConfig
@ -10,6 +9,9 @@ from cognee.eval_framework.corpus_builder.run_corpus_builder import run_corpus_b
from cognee.eval_framework.answer_generation.run_question_answering_module import (
run_question_answering,
)
import pathlib
from os import path
from modal import Image
from cognee.eval_framework.evaluation.run_evaluation_module import run_evaluation
from cognee.eval_framework.metrics_dashboard import create_dashboard
@ -38,22 +40,19 @@ def read_and_combine_metrics(eval_params: dict) -> dict:
app = modal.App("modal-run-eval")
image = (
modal.Image.from_dockerfile(path="Dockerfile_modal", force_build=False)
.copy_local_file("pyproject.toml", "pyproject.toml")
.copy_local_file("poetry.lock", "poetry.lock")
.env(
{
"ENV": os.getenv("ENV"),
"LLM_API_KEY": os.getenv("LLM_API_KEY"),
"OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
}
)
.pip_install("protobuf", "h2", "deepeval", "gdown", "plotly")
image = Image.from_dockerfile(
path=pathlib.Path(path.join(path.dirname(__file__), "Dockerfile")).resolve(),
force_build=False,
).add_local_python_source("cognee")
@app.function(
image=image,
max_containers=10,
timeout=86400,
volumes={"/data": vol},
secrets=[modal.Secret.from_name("eval_secrets")],
)
@app.function(image=image, concurrency_limit=10, timeout=86400, volumes={"/data": vol})
async def modal_run_eval(eval_params=None):
"""Runs evaluation pipeline and returns combined metrics results."""
if eval_params is None:
@ -105,18 +104,7 @@ async def main():
configs = [
EvalConfig(
task_getter_type="Default",
number_of_samples_in_corpus=10,
benchmark="HotPotQA",
qa_engine="cognee_graph_completion",
building_corpus_from_scratch=True,
answering_questions=True,
evaluating_answers=True,
calculate_metrics=True,
dashboard=True,
),
EvalConfig(
task_getter_type="Default",
number_of_samples_in_corpus=10,
number_of_samples_in_corpus=25,
benchmark="TwoWikiMultiHop",
qa_engine="cognee_graph_completion",
building_corpus_from_scratch=True,
@ -127,7 +115,7 @@ async def main():
),
EvalConfig(
task_getter_type="Default",
number_of_samples_in_corpus=10,
number_of_samples_in_corpus=25,
benchmark="Musique",
qa_engine="cognee_graph_completion",
building_corpus_from_scratch=True,

View file

@ -398,3 +398,18 @@ class GraphDBInterface(ABC):
- node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections.
"""
raise NotImplementedError
@abstractmethod
async def get_filtered_graph_data(
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
) -> Tuple[List[Node], List[EdgeData]]:
"""
Retrieve nodes and edges filtered by the provided attribute criteria.
Parameters:
-----------
- attribute_filters: A list of dictionaries where keys are attribute names and values
are lists of attribute values to filter by.
"""
raise NotImplementedError

View file

@ -12,6 +12,7 @@ from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, List, Union, Optional, Tuple, Type
from cognee.exceptions import CogneeValidationError
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.run_sync import run_sync
from cognee.infrastructure.files.storage import get_file_storage
@ -1186,6 +1187,11 @@ class KuzuAdapter(GraphDBInterface):
A tuple with two elements: a list of tuples of (node_id, properties) and a list of
tuples of (source_id, target_id, relationship_name, properties).
"""
import time
start_time = time.time()
try:
nodes_query = """
MATCH (n:Node)
@ -1249,6 +1255,11 @@ class KuzuAdapter(GraphDBInterface):
},
)
)
retrieval_time = time.time() - start_time
logger.info(
f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds"
)
return formatted_nodes, formatted_edges
except Exception as e:
logger.error(f"Failed to get graph data: {e}")
@ -1417,6 +1428,92 @@ class KuzuAdapter(GraphDBInterface):
formatted_edges.append((source_id, target_id, rel_type, props))
return formatted_nodes, formatted_edges
async def get_id_filtered_graph_data(self, target_ids: list[str]):
"""
Retrieve graph data filtered by specific node IDs, including their direct neighbors
and only edges where one endpoint matches those IDs.
Returns:
nodes: List[dict] -> Each dict includes "id" and all node properties
edges: List[dict] -> Each dict includes "source", "target", "type", "properties"
"""
import time
start_time = time.time()
try:
if not target_ids:
logger.warning("No target IDs provided for ID-filtered graph retrieval.")
return [], []
if not all(isinstance(x, str) for x in target_ids):
raise CogneeValidationError("target_ids must be a list of strings")
query = """
MATCH (n:Node)-[r]->(m:Node)
WHERE n.id IN $target_ids OR m.id IN $target_ids
RETURN n.id, {
name: n.name,
type: n.type,
properties: n.properties
}, m.id, {
name: m.name,
type: m.type,
properties: m.properties
}, r.relationship_name, r.properties
"""
result = await self.query(query, {"target_ids": target_ids})
if not result:
logger.info("No data returned for the supplied IDs")
return [], []
nodes_dict = {}
edges = []
for n_id, n_props, m_id, m_props, r_type, r_props_raw in result:
if n_props.get("properties"):
try:
additional_props = json.loads(n_props["properties"])
n_props.update(additional_props)
del n_props["properties"]
except json.JSONDecodeError:
logger.warning(f"Failed to parse properties JSON for node {n_id}")
if m_props.get("properties"):
try:
additional_props = json.loads(m_props["properties"])
m_props.update(additional_props)
del m_props["properties"]
except json.JSONDecodeError:
logger.warning(f"Failed to parse properties JSON for node {m_id}")
nodes_dict[n_id] = (n_id, n_props)
nodes_dict[m_id] = (m_id, m_props)
edge_props = {}
if r_props_raw:
try:
edge_props = json.loads(r_props_raw)
except (json.JSONDecodeError, TypeError):
logger.warning(f"Failed to parse edge properties for {n_id}->{m_id}")
source_id = edge_props.get("source_node_id", n_id)
target_id = edge_props.get("target_node_id", m_id)
edges.append((source_id, target_id, r_type, edge_props))
retrieval_time = time.time() - start_time
logger.info(
f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
)
return list(nodes_dict.values()), edges
except Exception as e:
logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
raise
async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
"""
Get metrics on graph structure and connectivity.

View file

@ -964,6 +964,63 @@ class Neo4jAdapter(GraphDBInterface):
logger.error(f"Error during graph data retrieval: {str(e)}")
raise
async def get_id_filtered_graph_data(self, target_ids: list[str]):
"""
Retrieve graph data filtered by specific node IDs, including their direct neighbors
and only edges where one endpoint matches those IDs.
This version uses a single Cypher query for efficiency.
"""
import time
start_time = time.time()
try:
if not target_ids:
logger.warning("No target IDs provided for ID-filtered graph retrieval.")
return [], []
query = """
MATCH ()-[r]-()
WHERE startNode(r).id IN $target_ids
OR endNode(r).id IN $target_ids
WITH DISTINCT r, startNode(r) AS a, endNode(r) AS b
RETURN
properties(a) AS n_properties,
properties(b) AS m_properties,
type(r) AS type,
properties(r) AS properties
"""
result = await self.query(query, {"target_ids": target_ids})
nodes_dict = {}
edges = []
for record in result:
n_props = record["n_properties"]
m_props = record["m_properties"]
r_props = record["properties"]
r_type = record["type"]
nodes_dict[n_props["id"]] = (n_props["id"], n_props)
nodes_dict[m_props["id"]] = (m_props["id"], m_props)
source_id = r_props.get("source_node_id", n_props["id"])
target_id = r_props.get("target_node_id", m_props["id"])
edges.append((source_id, target_id, r_type, r_props))
retrieval_time = time.time() - start_time
logger.info(
f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
)
return list(nodes_dict.values()), edges
except Exception as e:
logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
raise
async def get_nodeset_subgraph(
self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:

View file

@ -56,6 +56,68 @@ class CogneeGraph(CogneeAbstractGraph):
def get_edges(self) -> List[Edge]:
return self.edges
async def _get_nodeset_subgraph(
self,
adapter,
node_type,
node_name,
):
"""Retrieve subgraph based on node type and name."""
logger.info("Retrieving graph filtered by node type and node name (NodeSet).")
nodes_data, edges_data = await adapter.get_nodeset_subgraph(
node_type=node_type, node_name=node_name
)
if not nodes_data or not edges_data:
raise EntityNotFoundError(
message="Nodeset does not exist, or empty nodeset projected from the database."
)
return nodes_data, edges_data
async def _get_full_or_id_filtered_graph(
self,
adapter,
relevant_ids_to_filter,
):
"""Retrieve full or ID-filtered graph with fallback."""
if relevant_ids_to_filter is None:
logger.info("Retrieving full graph.")
nodes_data, edges_data = await adapter.get_graph_data()
if not nodes_data or not edges_data:
raise EntityNotFoundError(message="Empty graph projected from the database.")
return nodes_data, edges_data
get_graph_data_fn = getattr(adapter, "get_id_filtered_graph_data", adapter.get_graph_data)
if getattr(adapter.__class__, "get_id_filtered_graph_data", None):
logger.info("Retrieving ID-filtered graph from database.")
nodes_data, edges_data = await get_graph_data_fn(target_ids=relevant_ids_to_filter)
else:
logger.info("Retrieving full graph from database.")
nodes_data, edges_data = await get_graph_data_fn()
if hasattr(adapter, "get_id_filtered_graph_data") and (not nodes_data or not edges_data):
logger.warning(
"Id filtered graph returned empty, falling back to full graph retrieval."
)
logger.info("Retrieving full graph")
nodes_data, edges_data = await adapter.get_graph_data()
if not nodes_data or not edges_data:
raise EntityNotFoundError("Empty graph projected from the database.")
return nodes_data, edges_data
async def _get_filtered_graph(
self,
adapter,
memory_fragment_filter,
):
"""Retrieve graph filtered by attributes."""
logger.info("Retrieving graph filtered by memory fragment")
nodes_data, edges_data = await adapter.get_filtered_graph_data(
attribute_filters=memory_fragment_filter
)
if not nodes_data or not edges_data:
raise EntityNotFoundError(message="Empty filtered graph projected from the database.")
return nodes_data, edges_data
async def project_graph_from_db(
self,
adapter: Union[GraphDBInterface],
@ -67,40 +129,39 @@ class CogneeGraph(CogneeAbstractGraph):
memory_fragment_filter=[],
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
relevant_ids_to_filter: Optional[List[str]] = None,
triplet_distance_penalty: float = 3.5,
) -> None:
if node_dimension < 1 or edge_dimension < 1:
raise InvalidDimensionsError()
try:
if node_type is not None and node_name not in [None, [], ""]:
nodes_data, edges_data = await self._get_nodeset_subgraph(
adapter, node_type, node_name
)
elif len(memory_fragment_filter) == 0:
nodes_data, edges_data = await self._get_full_or_id_filtered_graph(
adapter, relevant_ids_to_filter
)
else:
nodes_data, edges_data = await self._get_filtered_graph(
adapter, memory_fragment_filter
)
import time
start_time = time.time()
# Determine projection strategy
if node_type is not None and node_name not in [None, [], ""]:
nodes_data, edges_data = await adapter.get_nodeset_subgraph(
node_type=node_type, node_name=node_name
)
if not nodes_data or not edges_data:
raise EntityNotFoundError(
message="Nodeset does not exist, or empty nodetes projected from the database."
)
elif len(memory_fragment_filter) == 0:
nodes_data, edges_data = await adapter.get_graph_data()
if not nodes_data or not edges_data:
raise EntityNotFoundError(message="Empty graph projected from the database.")
else:
nodes_data, edges_data = await adapter.get_filtered_graph_data(
attribute_filters=memory_fragment_filter
)
if not nodes_data or not edges_data:
raise EntityNotFoundError(
message="Empty filtered graph projected from the database."
)
# Process nodes
for node_id, properties in nodes_data:
node_attributes = {key: properties.get(key) for key in node_properties_to_project}
self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))
self.add_node(
Node(
str(node_id),
node_attributes,
dimension=node_dimension,
node_penalty=triplet_distance_penalty,
)
)
# Process edges
for source_id, target_id, relationship_type, properties in edges_data:
@ -118,6 +179,7 @@ class CogneeGraph(CogneeAbstractGraph):
attributes=edge_attributes,
directed=directed,
dimension=edge_dimension,
edge_penalty=triplet_distance_penalty,
)
self.add_edge(edge)

View file

@ -20,13 +20,17 @@ class Node:
status: np.ndarray
def __init__(
self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1
self,
node_id: str,
attributes: Optional[Dict[str, Any]] = None,
dimension: int = 1,
node_penalty: float = 3.5,
):
if dimension <= 0:
raise InvalidDimensionsError()
self.id = node_id
self.attributes = attributes if attributes is not None else {}
self.attributes["vector_distance"] = float("inf")
self.attributes["vector_distance"] = node_penalty
self.skeleton_neighbours = []
self.skeleton_edges = []
self.status = np.ones(dimension, dtype=int)
@ -105,13 +109,14 @@ class Edge:
attributes: Optional[Dict[str, Any]] = None,
directed: bool = True,
dimension: int = 1,
edge_penalty: float = 3.5,
):
if dimension <= 0:
raise InvalidDimensionsError()
self.node1 = node1
self.node2 = node2
self.attributes = attributes if attributes is not None else {}
self.attributes["vector_distance"] = float("inf")
self.attributes["vector_distance"] = edge_penalty
self.directed = directed
self.status = np.ones(dimension, dtype=int)

View file

@ -39,6 +39,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
):
super().__init__(
user_prompt_path=user_prompt_path,
@ -48,6 +50,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
async def get_completion(

View file

@ -65,6 +65,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
):
super().__init__(
user_prompt_path=user_prompt_path,
@ -74,6 +76,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
self.validation_system_prompt_path = validation_system_prompt_path
self.validation_user_prompt_path = validation_user_prompt_path

View file

@ -47,6 +47,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
):
"""Initialize retriever with prompt paths and search parameters."""
self.save_interaction = save_interaction
@ -54,8 +56,10 @@ class GraphCompletionRetriever(BaseGraphRetriever):
self.system_prompt_path = system_prompt_path
self.system_prompt = system_prompt
self.top_k = top_k if top_k is not None else 5
self.wide_search_top_k = wide_search_top_k
self.node_type = node_type
self.node_name = node_name
self.triplet_distance_penalty = triplet_distance_penalty
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
"""
@ -105,6 +109,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
collections=vector_index_collections or None,
node_type=self.node_type,
node_name=self.node_name,
wide_search_top_k=self.wide_search_top_k,
triplet_distance_penalty=self.triplet_distance_penalty,
)
return found_triplets
@ -141,6 +147,10 @@ class GraphCompletionRetriever(BaseGraphRetriever):
return triplets
async def convert_retrieved_objects_to_context(self, triplets: List[Edge]):
context = await self.resolve_edges_to_text(triplets)
return context
async def get_completion(
self,
query: str,

View file

@ -26,6 +26,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
):
"""Initialize retriever with default prompt paths and search parameters."""
super().__init__(
@ -36,6 +38,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
self.summarize_prompt_path = summarize_prompt_path

View file

@ -47,6 +47,8 @@ class TemporalRetriever(GraphCompletionRetriever):
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
):
super().__init__(
user_prompt_path=user_prompt_path,
@ -54,6 +56,8 @@ class TemporalRetriever(GraphCompletionRetriever):
top_k=top_k,
node_type=node_type,
node_name=node_name,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path

View file

@ -58,6 +58,8 @@ async def get_memory_fragment(
properties_to_project: Optional[List[str]] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
relevant_ids_to_filter: Optional[List[str]] = None,
triplet_distance_penalty: Optional[float] = 3.5,
) -> CogneeGraph:
"""Creates and initializes a CogneeGraph memory fragment with optional property projections."""
if properties_to_project is None:
@ -74,6 +76,8 @@ async def get_memory_fragment(
edge_properties_to_project=["relationship_name", "edge_text"],
node_type=node_type,
node_name=node_name,
relevant_ids_to_filter=relevant_ids_to_filter,
triplet_distance_penalty=triplet_distance_penalty,
)
except EntityNotFoundError:
@ -95,6 +99,8 @@ async def brute_force_triplet_search(
memory_fragment: Optional[CogneeGraph] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> List[Edge]:
"""
Performs a brute force search to retrieve the top triplets from the graph.
@ -107,6 +113,8 @@ async def brute_force_triplet_search(
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
node_type: node type to filter
node_name: node name to filter
wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections
triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection
Returns:
list: The top triplet results.
@ -116,10 +124,10 @@ async def brute_force_triplet_search(
if top_k <= 0:
raise ValueError("top_k must be a positive integer.")
if memory_fragment is None:
memory_fragment = await get_memory_fragment(
properties_to_project, node_type=node_type, node_name=node_name
)
# Setting wide search limit based on the parameters
non_global_search = node_name is None
wide_search_limit = wide_search_top_k if non_global_search else None
if collections is None:
collections = [
@ -140,7 +148,7 @@ async def brute_force_triplet_search(
async def search_in_collection(collection_name: str):
try:
return await vector_engine.search(
collection_name=collection_name, query_vector=query_vector, limit=None
collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit
)
except CollectionNotFoundError:
return []
@ -156,15 +164,38 @@ async def brute_force_triplet_search(
return []
# Final statistics
projection_time = time.time() - start_time
vector_collection_search_time = time.time() - start_time
logger.info(
f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s"
f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s"
)
node_distances = {collection: result for collection, result in zip(collections, results)}
edge_distances = node_distances.get("EdgeType_relationship_name", None)
if wide_search_limit is not None:
relevant_ids_to_filter = list(
{
str(getattr(scored_node, "id"))
for collection_name, score_collection in node_distances.items()
if collection_name != "EdgeType_relationship_name"
and isinstance(score_collection, (list, tuple))
for scored_node in score_collection
if getattr(scored_node, "id", None)
}
)
else:
relevant_ids_to_filter = None
if memory_fragment is None:
memory_fragment = await get_memory_fragment(
properties_to_project=properties_to_project,
node_type=node_type,
node_name=node_name,
relevant_ids_to_filter=relevant_ids_to_filter,
triplet_distance_penalty=triplet_distance_penalty,
)
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
await memory_fragment.map_vector_distances_to_graph_edges(
vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances

View file

@ -37,6 +37,8 @@ async def get_search_type_tools(
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
last_k: Optional[int] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> list:
search_tasks: dict[SearchType, List[Callable]] = {
SearchType.SUMMARIES: [
@ -67,6 +69,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
GraphCompletionRetriever(
system_prompt_path=system_prompt_path,
@ -75,6 +79,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.GRAPH_COMPLETION_COT: [
@ -85,6 +91,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
GraphCompletionCotRetriever(
system_prompt_path=system_prompt_path,
@ -93,6 +101,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: [
@ -103,6 +113,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
GraphCompletionContextExtensionRetriever(
system_prompt_path=system_prompt_path,
@ -111,6 +123,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.GRAPH_SUMMARY_COMPLETION: [
@ -121,6 +135,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path,
@ -129,6 +145,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.CODE: [
@ -145,8 +163,16 @@ async def get_search_type_tools(
],
SearchType.FEEDBACK: [UserQAFeedback(last_k=last_k).add_feedback],
SearchType.TEMPORAL: [
TemporalRetriever(top_k=top_k).get_completion,
TemporalRetriever(top_k=top_k).get_context,
TemporalRetriever(
top_k=top_k,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
TemporalRetriever(
top_k=top_k,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.CHUNKS_LEXICAL: (
lambda _r=JaccardChunksRetriever(top_k=top_k): [

View file

@ -24,6 +24,8 @@ async def no_access_control_search(
last_k: Optional[int] = None,
only_context: bool = False,
session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
search_tools = await get_search_type_tools(
query_type=query_type,
@ -35,6 +37,8 @@ async def no_access_control_search(
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
graph_engine = await get_graph_engine()
is_empty = await graph_engine.is_empty()

View file

@ -47,6 +47,8 @@ async def search(
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[CombinedSearchResult, List[SearchResult]]:
"""
@ -90,6 +92,8 @@ async def search(
only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
else:
search_results = [
@ -105,6 +109,8 @@ async def search(
last_k=last_k,
only_context=only_context,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
]
@ -219,6 +225,8 @@ async def authorized_search(
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[
Tuple[Any, Union[List[Edge], str], List[Dataset]],
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
@ -246,6 +254,8 @@ async def authorized_search(
last_k=last_k,
only_context=True,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
context = {}
@ -267,6 +277,8 @@ async def authorized_search(
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
search_tools = specific_search_tools
if len(search_tools) == 2:
@ -306,6 +318,7 @@ async def authorized_search(
last_k=last_k,
only_context=only_context,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
)
return search_results
@ -325,6 +338,8 @@ async def search_in_datasets_context(
only_context: bool = False,
context: Optional[Any] = None,
session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
"""
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
@ -345,6 +360,8 @@ async def search_in_datasets_context(
only_context: bool = False,
context: Optional[Any] = None,
session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
# Set database configuration in async context for each dataset user has access for
await set_database_global_context_variables(dataset.id, dataset.owner_id)
@ -378,6 +395,8 @@ async def search_in_datasets_context(
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
search_tools = specific_search_tools
if len(search_tools) == 2:
@ -413,6 +432,8 @@ async def search_in_datasets_context(
only_context=only_context,
context=context,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
)

View file

@ -9,7 +9,7 @@ def test_node_initialization():
"""Test that a Node is initialized correctly."""
node = Node("node1", {"attr1": "value1"}, dimension=2)
assert node.id == "node1"
assert node.attributes == {"attr1": "value1", "vector_distance": np.inf}
assert node.attributes == {"attr1": "value1", "vector_distance": 3.5}
assert len(node.status) == 2
assert np.all(node.status == 1)
@ -96,7 +96,7 @@ def test_edge_initialization():
edge = Edge(node1, node2, {"weight": 10}, directed=False, dimension=2)
assert edge.node1 == node1
assert edge.node2 == node2
assert edge.attributes == {"vector_distance": np.inf, "weight": 10}
assert edge.attributes == {"vector_distance": 3.5, "weight": 10}
assert edge.directed is False
assert len(edge.status) == 2
assert np.all(edge.status == 1)

View file

@ -1,4 +1,5 @@
import pytest
from unittest.mock import AsyncMock
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
@ -11,6 +12,30 @@ def setup_graph():
return CogneeGraph()
@pytest.fixture
def mock_adapter():
"""Fixture to create a mock adapter for database operations."""
adapter = AsyncMock()
return adapter
@pytest.fixture
def mock_vector_engine():
"""Fixture to create a mock vector engine."""
engine = AsyncMock()
engine.search = AsyncMock()
return engine
class MockScoredResult:
"""Mock class for vector search results."""
def __init__(self, id, score, payload=None):
self.id = id
self.score = score
self.payload = payload or {}
def test_add_node_success(setup_graph):
"""Test successful addition of a node."""
graph = setup_graph
@ -73,3 +98,433 @@ def test_get_edges_nonexistent_node(setup_graph):
graph = setup_graph
with pytest.raises(EntityNotFoundError, match="Node with id nonexistent does not exist."):
graph.get_edges_from_node("nonexistent")
@pytest.mark.asyncio
async def test_project_graph_from_db_full_graph(setup_graph, mock_adapter):
"""Test projecting a full graph from database."""
graph = setup_graph
nodes_data = [
("1", {"name": "Node1", "description": "First node"}),
("2", {"name": "Node2", "description": "Second node"}),
]
edges_data = [
("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}),
]
mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
await graph.project_graph_from_db(
adapter=mock_adapter,
node_properties_to_project=["name", "description"],
edge_properties_to_project=["relationship_name"],
)
assert len(graph.nodes) == 2
assert len(graph.edges) == 1
assert graph.get_node("1") is not None
assert graph.get_node("2") is not None
assert graph.edges[0].node1.id == "1"
assert graph.edges[0].node2.id == "2"
@pytest.mark.asyncio
async def test_project_graph_from_db_id_filtered(setup_graph, mock_adapter):
"""Test projecting an ID-filtered graph from database."""
graph = setup_graph
nodes_data = [
("1", {"name": "Node1"}),
("2", {"name": "Node2"}),
]
edges_data = [
("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}),
]
mock_adapter.get_id_filtered_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
await graph.project_graph_from_db(
adapter=mock_adapter,
node_properties_to_project=["name"],
edge_properties_to_project=["relationship_name"],
relevant_ids_to_filter=["1", "2"],
)
assert len(graph.nodes) == 2
assert len(graph.edges) == 1
mock_adapter.get_id_filtered_graph_data.assert_called_once()
@pytest.mark.asyncio
async def test_project_graph_from_db_nodeset_subgraph(setup_graph, mock_adapter):
"""Test projecting a nodeset subgraph filtered by node type and name."""
graph = setup_graph
nodes_data = [
("1", {"name": "Alice", "type": "Person"}),
("2", {"name": "Bob", "type": "Person"}),
]
edges_data = [
("1", "2", "KNOWS", {"relationship_name": "knows"}),
]
mock_adapter.get_nodeset_subgraph = AsyncMock(return_value=(nodes_data, edges_data))
await graph.project_graph_from_db(
adapter=mock_adapter,
node_properties_to_project=["name", "type"],
edge_properties_to_project=["relationship_name"],
node_type="Person",
node_name=["Alice"],
)
assert len(graph.nodes) == 2
assert graph.get_node("1") is not None
assert len(graph.edges) == 1
mock_adapter.get_nodeset_subgraph.assert_called_once()
@pytest.mark.asyncio
async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter):
"""Test projecting empty graph raises EntityNotFoundError."""
graph = setup_graph
mock_adapter.get_graph_data = AsyncMock(return_value=([], []))
with pytest.raises(EntityNotFoundError, match="Empty graph projected from the database."):
await graph.project_graph_from_db(
adapter=mock_adapter,
node_properties_to_project=["name"],
edge_properties_to_project=[],
)
@pytest.mark.asyncio
async def test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter):
"""Test that edges referencing missing nodes raise error."""
graph = setup_graph
nodes_data = [
("1", {"name": "Node1"}),
]
edges_data = [
("1", "999", "CONNECTS_TO", {"relationship_name": "connects"}),
]
mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
with pytest.raises(EntityNotFoundError, match="Edge references nonexistent nodes"):
await graph.project_graph_from_db(
adapter=mock_adapter,
node_properties_to_project=["name"],
edge_properties_to_project=["relationship_name"],
)
@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_nodes(setup_graph):
"""Test mapping vector distances to graph nodes."""
graph = setup_graph
node1 = Node("1", {"name": "Node1"})
node2 = Node("2", {"name": "Node2"})
graph.add_node(node1)
graph.add_node(node2)
node_distances = {
"Entity_name": [
MockScoredResult("1", 0.95),
MockScoredResult("2", 0.87),
]
}
await graph.map_vector_distances_to_graph_nodes(node_distances)
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
@pytest.mark.asyncio
async def test_map_vector_distances_partial_node_coverage(setup_graph):
"""Test mapping vector distances when only some nodes have results."""
graph = setup_graph
node1 = Node("1", {"name": "Node1"})
node2 = Node("2", {"name": "Node2"})
node3 = Node("3", {"name": "Node3"})
graph.add_node(node1)
graph.add_node(node2)
graph.add_node(node3)
node_distances = {
"Entity_name": [
MockScoredResult("1", 0.95),
MockScoredResult("2", 0.87),
]
}
await graph.map_vector_distances_to_graph_nodes(node_distances)
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
assert graph.get_node("3").attributes.get("vector_distance") == 3.5
@pytest.mark.asyncio
async def test_map_vector_distances_multiple_categories(setup_graph):
"""Test mapping vector distances from multiple collection categories."""
graph = setup_graph
# Create nodes
node1 = Node("1")
node2 = Node("2")
node3 = Node("3")
node4 = Node("4")
graph.add_node(node1)
graph.add_node(node2)
graph.add_node(node3)
graph.add_node(node4)
node_distances = {
"Entity_name": [
MockScoredResult("1", 0.95),
MockScoredResult("2", 0.87),
],
"TextSummary_text": [
MockScoredResult("3", 0.92),
],
}
await graph.map_vector_distances_to_graph_nodes(node_distances)
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
assert graph.get_node("3").attributes.get("vector_distance") == 0.92
assert graph.get_node("4").attributes.get("vector_distance") == 3.5
@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, mock_vector_engine):
"""Test mapping vector distances to edges when edge_distances provided."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
graph.add_node(node1)
graph.add_node(node2)
edge = Edge(
node1,
node2,
attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
)
graph.add_edge(edge)
edge_distances = [
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
]
await graph.map_vector_distances_to_graph_edges(
vector_engine=mock_vector_engine,
query_vector=[0.1, 0.2, 0.3],
edge_distances=edge_distances,
)
assert graph.edges[0].attributes.get("vector_distance") == 0.92
@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_edges_search(setup_graph, mock_vector_engine):
"""Test mapping edge distances when searching for them."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
graph.add_node(node1)
graph.add_node(node2)
edge = Edge(
node1,
node2,
attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
)
graph.add_edge(edge)
mock_vector_engine.search.return_value = [
MockScoredResult("e1", 0.88, payload={"text": "CONNECTS_TO"}),
]
await graph.map_vector_distances_to_graph_edges(
vector_engine=mock_vector_engine,
query_vector=[0.1, 0.2, 0.3],
edge_distances=None,
)
mock_vector_engine.search.assert_called_once()
assert graph.edges[0].attributes.get("vector_distance") == 0.88
@pytest.mark.asyncio
async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vector_engine):
"""Test mapping edge distances when only some edges have results."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
node3 = Node("3")
graph.add_node(node1)
graph.add_node(node2)
graph.add_node(node3)
edge1 = Edge(node1, node2, attributes={"edge_text": "CONNECTS_TO"})
edge2 = Edge(node2, node3, attributes={"edge_text": "DEPENDS_ON"})
graph.add_edge(edge1)
graph.add_edge(edge2)
edge_distances = [
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
]
await graph.map_vector_distances_to_graph_edges(
vector_engine=mock_vector_engine,
query_vector=[0.1, 0.2, 0.3],
edge_distances=edge_distances,
)
assert graph.edges[0].attributes.get("vector_distance") == 0.92
assert graph.edges[1].attributes.get("vector_distance") == 3.5
@pytest.mark.asyncio
async def test_map_vector_distances_edges_fallback_to_relationship_type(
setup_graph, mock_vector_engine
):
"""Test that edge mapping falls back to relationship_type when edge_text is missing."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
graph.add_node(node1)
graph.add_node(node2)
edge = Edge(
node1,
node2,
attributes={"relationship_type": "KNOWS"},
)
graph.add_edge(edge)
edge_distances = [
MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}),
]
await graph.map_vector_distances_to_graph_edges(
vector_engine=mock_vector_engine,
query_vector=[0.1, 0.2, 0.3],
edge_distances=edge_distances,
)
assert graph.edges[0].attributes.get("vector_distance") == 0.85
@pytest.mark.asyncio
async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_engine):
"""Test edge mapping when no edges match the distance results."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
graph.add_node(node1)
graph.add_node(node2)
edge = Edge(
node1,
node2,
attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
)
graph.add_edge(edge)
edge_distances = [
MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}),
]
await graph.map_vector_distances_to_graph_edges(
vector_engine=mock_vector_engine,
query_vector=[0.1, 0.2, 0.3],
edge_distances=edge_distances,
)
assert graph.edges[0].attributes.get("vector_distance") == 3.5
@pytest.mark.asyncio
async def test_map_vector_distances_invalid_query_vector(setup_graph, mock_vector_engine):
"""Test that invalid query vector raises error."""
graph = setup_graph
with pytest.raises(ValueError, match="Failed to generate query embedding"):
await graph.map_vector_distances_to_graph_edges(
vector_engine=mock_vector_engine,
query_vector=[],
edge_distances=None,
)
@pytest.mark.asyncio
async def test_calculate_top_triplet_importances(setup_graph):
"""Test calculating top triplet importances by score."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
node3 = Node("3")
node4 = Node("4")
node1.add_attribute("vector_distance", 0.9)
node2.add_attribute("vector_distance", 0.8)
node3.add_attribute("vector_distance", 0.7)
node4.add_attribute("vector_distance", 0.6)
graph.add_node(node1)
graph.add_node(node2)
graph.add_node(node3)
graph.add_node(node4)
edge1 = Edge(node1, node2)
edge2 = Edge(node2, node3)
edge3 = Edge(node3, node4)
edge1.add_attribute("vector_distance", 0.85)
edge2.add_attribute("vector_distance", 0.75)
edge3.add_attribute("vector_distance", 0.65)
graph.add_edge(edge1)
graph.add_edge(edge2)
graph.add_edge(edge3)
top_triplets = await graph.calculate_top_triplet_importances(k=2)
assert len(top_triplets) == 2
assert top_triplets[0] == edge3
assert top_triplets[1] == edge2
@pytest.mark.asyncio
async def test_calculate_top_triplet_importances_default_distances(setup_graph):
"""Test calculating importances when nodes/edges have no vector distances."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
graph.add_node(node1)
graph.add_node(node2)
edge = Edge(node1, node2)
graph.add_edge(edge)
top_triplets = await graph.calculate_top_triplet_importances(k=1)
assert len(top_triplets) == 1
assert top_triplets[0] == edge

View file

@ -0,0 +1,582 @@
import pytest
from unittest.mock import AsyncMock, patch
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
brute_force_triplet_search,
get_memory_fragment,
)
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
class MockScoredResult:
"""Mock class for vector search results."""
def __init__(self, id, score, payload=None):
self.id = id
self.score = score
self.payload = payload or {}
@pytest.mark.asyncio
async def test_brute_force_triplet_search_empty_query():
"""Test that empty query raises ValueError."""
with pytest.raises(ValueError, match="The query must be a non-empty string."):
await brute_force_triplet_search(query="")
@pytest.mark.asyncio
async def test_brute_force_triplet_search_none_query():
"""Test that None query raises ValueError."""
with pytest.raises(ValueError, match="The query must be a non-empty string."):
await brute_force_triplet_search(query=None)
@pytest.mark.asyncio
async def test_brute_force_triplet_search_negative_top_k():
"""Test that negative top_k raises ValueError."""
with pytest.raises(ValueError, match="top_k must be a positive integer."):
await brute_force_triplet_search(query="test query", top_k=-1)
@pytest.mark.asyncio
async def test_brute_force_triplet_search_zero_top_k():
"""Test that zero top_k raises ValueError."""
with pytest.raises(ValueError, match="top_k must be a positive integer."):
await brute_force_triplet_search(query="test query", top_k=0)
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_limit_global_search():
"""Test that wide_search_limit is applied for global search (node_name=None)."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(
query="test",
node_name=None, # Global search
wide_search_top_k=75,
)
for call in mock_vector_engine.search.call_args_list:
assert call[1]["limit"] == 75
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_limit_filtered_search():
"""Test that wide_search_limit is None for filtered search (node_name provided)."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(
query="test",
node_name=["Node1"],
wide_search_top_k=50,
)
for call in mock_vector_engine.search.call_args_list:
assert call[1]["limit"] is None
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_default():
"""Test that wide_search_top_k defaults to 100."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(query="test", node_name=None)
for call in mock_vector_engine.search.call_args_list:
assert call[1]["limit"] == 100
@pytest.mark.asyncio
async def test_brute_force_triplet_search_default_collections():
"""Test that default collections are used when none provided."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(query="test")
expected_collections = [
"Entity_name",
"TextSummary_text",
"EntityType_name",
"DocumentChunk_text",
]
call_collections = [
call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
]
assert call_collections == expected_collections
@pytest.mark.asyncio
async def test_brute_force_triplet_search_custom_collections():
"""Test that custom collections are used when provided."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
custom_collections = ["CustomCol1", "CustomCol2"]
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(query="test", collections=custom_collections)
call_collections = [
call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
]
assert call_collections == custom_collections
@pytest.mark.asyncio
async def test_brute_force_triplet_search_all_collections_empty():
"""Test that empty list is returned when all collections return no results."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
results = await brute_force_triplet_search(query="test")
assert results == []
# Tests for query embedding
@pytest.mark.asyncio
async def test_brute_force_triplet_search_embeds_query():
"""Test that query is embedded before searching."""
query_text = "test query"
expected_vector = [0.1, 0.2, 0.3]
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector])
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(query=query_text)
mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text])
for call in mock_vector_engine.search.call_args_list:
assert call[1]["query_vector"] == expected_vector
@pytest.mark.asyncio
async def test_brute_force_triplet_search_extracts_node_ids_global_search():
"""Test that node IDs are extracted from search results for global search."""
scored_results = [
MockScoredResult("node1", 0.95),
MockScoredResult("node2", 0.87),
MockScoredResult("node3", 0.92),
]
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=scored_results)
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment_fn,
):
await brute_force_triplet_search(query="test", node_name=None)
call_kwargs = mock_get_fragment_fn.call_args[1]
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
@pytest.mark.asyncio
async def test_brute_force_triplet_search_reuses_provided_fragment():
"""Test that provided memory fragment is reused instead of creating new one."""
provided_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment"
) as mock_get_fragment,
):
await brute_force_triplet_search(
query="test",
memory_fragment=provided_fragment,
node_name=["node"],
)
mock_get_fragment.assert_not_called()
@pytest.mark.asyncio
async def test_brute_force_triplet_search_creates_fragment_when_not_provided():
"""Test that memory fragment is created when not provided."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment,
):
await brute_force_triplet_search(query="test", node_name=["node"])
mock_get_fragment.assert_called_once()
@pytest.mark.asyncio
async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation():
"""Test that custom top_k is passed to importance calculation."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
),
):
custom_top_k = 15
await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"])
mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k)
@pytest.mark.asyncio
async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found():
"""Test that get_memory_fragment returns empty graph when entity not found."""
mock_graph_engine = AsyncMock()
mock_graph_engine.project_graph_from_db = AsyncMock(
side_effect=EntityNotFoundError("Entity not found")
)
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
return_value=mock_graph_engine,
):
fragment = await get_memory_fragment()
assert isinstance(fragment, CogneeGraph)
assert len(fragment.nodes) == 0
@pytest.mark.asyncio
async def test_get_memory_fragment_returns_empty_graph_on_error():
"""Test that get_memory_fragment returns empty graph on generic error."""
mock_graph_engine = AsyncMock()
mock_graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error"))
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
return_value=mock_graph_engine,
):
fragment = await get_memory_fragment()
assert isinstance(fragment, CogneeGraph)
assert len(fragment.nodes) == 0
@pytest.mark.asyncio
async def test_brute_force_triplet_search_deduplicates_node_ids():
"""Test that duplicate node IDs across collections are deduplicated."""
def search_side_effect(*args, **kwargs):
collection_name = kwargs.get("collection_name")
if collection_name == "Entity_name":
return [
MockScoredResult("node1", 0.95),
MockScoredResult("node2", 0.87),
]
elif collection_name == "TextSummary_text":
return [
MockScoredResult("node1", 0.90),
MockScoredResult("node3", 0.92),
]
else:
return []
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment_fn,
):
await brute_force_triplet_search(query="test", node_name=None)
call_kwargs = mock_get_fragment_fn.call_args[1]
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
assert len(call_kwargs["relevant_ids_to_filter"]) == 3
@pytest.mark.asyncio
async def test_brute_force_triplet_search_excludes_edge_collection():
"""Test that EdgeType_relationship_name collection is excluded from ID extraction."""
def search_side_effect(*args, **kwargs):
collection_name = kwargs.get("collection_name")
if collection_name == "Entity_name":
return [MockScoredResult("node1", 0.95)]
elif collection_name == "EdgeType_relationship_name":
return [MockScoredResult("edge1", 0.88)]
else:
return []
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment_fn,
):
await brute_force_triplet_search(
query="test",
node_name=None,
collections=["Entity_name", "EdgeType_relationship_name"],
)
call_kwargs = mock_get_fragment_fn.call_args[1]
assert call_kwargs["relevant_ids_to_filter"] == ["node1"]
@pytest.mark.asyncio
async def test_brute_force_triplet_search_skips_nodes_without_ids():
"""Test that nodes without ID attribute are skipped."""
class ScoredResultNoId:
"""Mock result without id attribute."""
def __init__(self, score):
self.score = score
def search_side_effect(*args, **kwargs):
collection_name = kwargs.get("collection_name")
if collection_name == "Entity_name":
return [
MockScoredResult("node1", 0.95),
ScoredResultNoId(0.90),
MockScoredResult("node2", 0.87),
]
else:
return []
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment_fn,
):
await brute_force_triplet_search(query="test", node_name=None)
call_kwargs = mock_get_fragment_fn.call_args[1]
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
@pytest.mark.asyncio
async def test_brute_force_triplet_search_handles_tuple_results():
"""Test that both list and tuple results are handled correctly."""
def search_side_effect(*args, **kwargs):
collection_name = kwargs.get("collection_name")
if collection_name == "Entity_name":
return (
MockScoredResult("node1", 0.95),
MockScoredResult("node2", 0.87),
)
else:
return []
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment_fn,
):
await brute_force_triplet_search(query="test", node_name=None)
call_kwargs = mock_get_fragment_fn.call_args[1]
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
@pytest.mark.asyncio
async def test_brute_force_triplet_search_mixed_empty_collections():
"""Test ID extraction with mixed empty and non-empty collections."""
def search_side_effect(*args, **kwargs):
collection_name = kwargs.get("collection_name")
if collection_name == "Entity_name":
return [MockScoredResult("node1", 0.95)]
elif collection_name == "TextSummary_text":
return []
elif collection_name == "EntityType_name":
return [MockScoredResult("node2", 0.92)]
else:
return []
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment_fn,
):
await brute_force_triplet_search(query="test", node_name=None)
call_kwargs = mock_get_fragment_fn.call_args[1]
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}