diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index d4e5fbbe6..354331c57 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -31,6 +31,8 @@ async def search( only_context: bool = False, use_combined_context: bool = False, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Union[List[SearchResult], CombinedSearchResult]: """ Search and query the knowledge graph for insights, information, and connections. @@ -200,6 +202,8 @@ async def search( only_context=only_context, use_combined_context=use_combined_context, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) return filtered_search_results diff --git a/cognee/eval_framework/Dockerfile b/cognee/eval_framework/Dockerfile new file mode 100644 index 000000000..e83be3da4 --- /dev/null +++ b/cognee/eval_framework/Dockerfile @@ -0,0 +1,29 @@ +FROM python:3.11-slim + +# Set environment variables +ENV PIP_NO_CACHE_DIR=true +ENV PATH="${PATH}:/root/.poetry/bin" +ENV PYTHONPATH=/app +ENV SKIP_MIGRATIONS=true + +# System dependencies +RUN apt-get update && apt-get install -y \ + gcc \ + libpq-dev \ + git \ + curl \ + build-essential \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +COPY pyproject.toml poetry.lock README.md /app/ + +RUN pip install poetry + +RUN poetry config virtualenvs.create false + +RUN poetry install --extras distributed --extras evals --extras deepeval --no-root + +COPY cognee/ /app/cognee +COPY distributed/ /app/distributed diff --git a/cognee/eval_framework/answer_generation/answer_generation_executor.py b/cognee/eval_framework/answer_generation/answer_generation_executor.py index 6f166657e..29b3ede68 100644 --- a/cognee/eval_framework/answer_generation/answer_generation_executor.py +++ b/cognee/eval_framework/answer_generation/answer_generation_executor.py @@ -35,6 +35,16 @@ class AnswerGeneratorExecutor: retrieval_context = await retriever.get_context(query_text) search_results = await retriever.get_completion(query_text, retrieval_context) + ############ + #:TODO This is a quick fix until we don't structure retriever results properly but lets not leave it like this...this is needed now due to the changed combined retriever structure.. + if isinstance(retrieval_context, list): + retrieval_context = await retriever.convert_retrieved_objects_to_context( + triplets=retrieval_context + ) + + if isinstance(search_results, str): + search_results = [search_results] + ############# answer = { "question": query_text, "answer": search_results[0], diff --git a/cognee/eval_framework/answer_generation/run_question_answering_module.py b/cognee/eval_framework/answer_generation/run_question_answering_module.py index d0a2ebe1e..6b55d84b2 100644 --- a/cognee/eval_framework/answer_generation/run_question_answering_module.py +++ b/cognee/eval_framework/answer_generation/run_question_answering_module.py @@ -35,7 +35,7 @@ async def create_and_insert_answers_table(questions_payload): async def run_question_answering( - params: dict, system_prompt="answer_simple_question.txt", top_k: Optional[int] = None + params: dict, system_prompt="answer_simple_question_benchmark.txt", top_k: Optional[int] = None ) -> List[dict]: if params.get("answering_questions"): logger.info("Question answering started...") diff --git a/cognee/eval_framework/eval_config.py b/cognee/eval_framework/eval_config.py index 6edcc0454..9e6f26688 100644 --- a/cognee/eval_framework/eval_config.py +++ b/cognee/eval_framework/eval_config.py @@ -14,7 +14,7 @@ class EvalConfig(BaseSettings): # Question answering params answering_questions: bool = True - qa_engine: str = "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension' + qa_engine: str = "cognee_graph_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension' # Evaluation params evaluating_answers: bool = True @@ -25,7 +25,7 @@ class EvalConfig(BaseSettings): "EM", "f1", ] # Use only 'correctness' for DirectLLM - deepeval_model: str = "gpt-5-mini" + deepeval_model: str = "gpt-4o-mini" # Metrics params calculate_metrics: bool = True diff --git a/cognee/eval_framework/modal_run_eval.py b/cognee/eval_framework/modal_run_eval.py index aca2686a5..bc2ff77c5 100644 --- a/cognee/eval_framework/modal_run_eval.py +++ b/cognee/eval_framework/modal_run_eval.py @@ -2,7 +2,6 @@ import modal import os import asyncio import datetime -import hashlib import json from cognee.shared.logging_utils import get_logger from cognee.eval_framework.eval_config import EvalConfig @@ -10,6 +9,9 @@ from cognee.eval_framework.corpus_builder.run_corpus_builder import run_corpus_b from cognee.eval_framework.answer_generation.run_question_answering_module import ( run_question_answering, ) +import pathlib +from os import path +from modal import Image from cognee.eval_framework.evaluation.run_evaluation_module import run_evaluation from cognee.eval_framework.metrics_dashboard import create_dashboard @@ -38,22 +40,19 @@ def read_and_combine_metrics(eval_params: dict) -> dict: app = modal.App("modal-run-eval") -image = ( - modal.Image.from_dockerfile(path="Dockerfile_modal", force_build=False) - .copy_local_file("pyproject.toml", "pyproject.toml") - .copy_local_file("poetry.lock", "poetry.lock") - .env( - { - "ENV": os.getenv("ENV"), - "LLM_API_KEY": os.getenv("LLM_API_KEY"), - "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"), - } - ) - .pip_install("protobuf", "h2", "deepeval", "gdown", "plotly") +image = Image.from_dockerfile( + path=pathlib.Path(path.join(path.dirname(__file__), "Dockerfile")).resolve(), + force_build=False, +).add_local_python_source("cognee") + + +@app.function( + image=image, + max_containers=10, + timeout=86400, + volumes={"/data": vol}, + secrets=[modal.Secret.from_name("eval_secrets")], ) - - -@app.function(image=image, concurrency_limit=10, timeout=86400, volumes={"/data": vol}) async def modal_run_eval(eval_params=None): """Runs evaluation pipeline and returns combined metrics results.""" if eval_params is None: @@ -105,18 +104,7 @@ async def main(): configs = [ EvalConfig( task_getter_type="Default", - number_of_samples_in_corpus=10, - benchmark="HotPotQA", - qa_engine="cognee_graph_completion", - building_corpus_from_scratch=True, - answering_questions=True, - evaluating_answers=True, - calculate_metrics=True, - dashboard=True, - ), - EvalConfig( - task_getter_type="Default", - number_of_samples_in_corpus=10, + number_of_samples_in_corpus=25, benchmark="TwoWikiMultiHop", qa_engine="cognee_graph_completion", building_corpus_from_scratch=True, @@ -127,7 +115,7 @@ async def main(): ), EvalConfig( task_getter_type="Default", - number_of_samples_in_corpus=10, + number_of_samples_in_corpus=25, benchmark="Musique", qa_engine="cognee_graph_completion", building_corpus_from_scratch=True, diff --git a/cognee/infrastructure/databases/graph/graph_db_interface.py b/cognee/infrastructure/databases/graph/graph_db_interface.py index 67df1a27c..8f8c96e79 100644 --- a/cognee/infrastructure/databases/graph/graph_db_interface.py +++ b/cognee/infrastructure/databases/graph/graph_db_interface.py @@ -398,3 +398,18 @@ class GraphDBInterface(ABC): - node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections. """ raise NotImplementedError + + @abstractmethod + async def get_filtered_graph_data( + self, attribute_filters: List[Dict[str, List[Union[str, int]]]] + ) -> Tuple[List[Node], List[EdgeData]]: + """ + Retrieve nodes and edges filtered by the provided attribute criteria. + + Parameters: + ----------- + + - attribute_filters: A list of dictionaries where keys are attribute names and values + are lists of attribute values to filter by. + """ + raise NotImplementedError diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py index 8dd160665..9dbc9c1bc 100644 --- a/cognee/infrastructure/databases/graph/kuzu/adapter.py +++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py @@ -12,6 +12,7 @@ from contextlib import asynccontextmanager from concurrent.futures import ThreadPoolExecutor from typing import Dict, Any, List, Union, Optional, Tuple, Type +from cognee.exceptions import CogneeValidationError from cognee.shared.logging_utils import get_logger from cognee.infrastructure.utils.run_sync import run_sync from cognee.infrastructure.files.storage import get_file_storage @@ -1186,6 +1187,11 @@ class KuzuAdapter(GraphDBInterface): A tuple with two elements: a list of tuples of (node_id, properties) and a list of tuples of (source_id, target_id, relationship_name, properties). """ + + import time + + start_time = time.time() + try: nodes_query = """ MATCH (n:Node) @@ -1249,6 +1255,11 @@ class KuzuAdapter(GraphDBInterface): }, ) ) + + retrieval_time = time.time() - start_time + logger.info( + f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds" + ) return formatted_nodes, formatted_edges except Exception as e: logger.error(f"Failed to get graph data: {e}") @@ -1417,6 +1428,92 @@ class KuzuAdapter(GraphDBInterface): formatted_edges.append((source_id, target_id, rel_type, props)) return formatted_nodes, formatted_edges + async def get_id_filtered_graph_data(self, target_ids: list[str]): + """ + Retrieve graph data filtered by specific node IDs, including their direct neighbors + and only edges where one endpoint matches those IDs. + + Returns: + nodes: List[dict] -> Each dict includes "id" and all node properties + edges: List[dict] -> Each dict includes "source", "target", "type", "properties" + """ + import time + + start_time = time.time() + + try: + if not target_ids: + logger.warning("No target IDs provided for ID-filtered graph retrieval.") + return [], [] + + if not all(isinstance(x, str) for x in target_ids): + raise CogneeValidationError("target_ids must be a list of strings") + + query = """ + MATCH (n:Node)-[r]->(m:Node) + WHERE n.id IN $target_ids OR m.id IN $target_ids + RETURN n.id, { + name: n.name, + type: n.type, + properties: n.properties + }, m.id, { + name: m.name, + type: m.type, + properties: m.properties + }, r.relationship_name, r.properties + """ + + result = await self.query(query, {"target_ids": target_ids}) + + if not result: + logger.info("No data returned for the supplied IDs") + return [], [] + + nodes_dict = {} + edges = [] + + for n_id, n_props, m_id, m_props, r_type, r_props_raw in result: + if n_props.get("properties"): + try: + additional_props = json.loads(n_props["properties"]) + n_props.update(additional_props) + del n_props["properties"] + except json.JSONDecodeError: + logger.warning(f"Failed to parse properties JSON for node {n_id}") + + if m_props.get("properties"): + try: + additional_props = json.loads(m_props["properties"]) + m_props.update(additional_props) + del m_props["properties"] + except json.JSONDecodeError: + logger.warning(f"Failed to parse properties JSON for node {m_id}") + + nodes_dict[n_id] = (n_id, n_props) + nodes_dict[m_id] = (m_id, m_props) + + edge_props = {} + if r_props_raw: + try: + edge_props = json.loads(r_props_raw) + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse edge properties for {n_id}->{m_id}") + + source_id = edge_props.get("source_node_id", n_id) + target_id = edge_props.get("target_node_id", m_id) + edges.append((source_id, target_id, r_type, edge_props)) + + retrieval_time = time.time() - start_time + logger.info( + f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s" + ) + + return list(nodes_dict.values()), edges + + except Exception as e: + logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}") + raise + async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]: """ Get metrics on graph structure and connectivity. diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 6216e107e..f3bb8e173 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -964,6 +964,63 @@ class Neo4jAdapter(GraphDBInterface): logger.error(f"Error during graph data retrieval: {str(e)}") raise + async def get_id_filtered_graph_data(self, target_ids: list[str]): + """ + Retrieve graph data filtered by specific node IDs, including their direct neighbors + and only edges where one endpoint matches those IDs. + + This version uses a single Cypher query for efficiency. + """ + import time + + start_time = time.time() + + try: + if not target_ids: + logger.warning("No target IDs provided for ID-filtered graph retrieval.") + return [], [] + + query = """ + MATCH ()-[r]-() + WHERE startNode(r).id IN $target_ids + OR endNode(r).id IN $target_ids + WITH DISTINCT r, startNode(r) AS a, endNode(r) AS b + RETURN + properties(a) AS n_properties, + properties(b) AS m_properties, + type(r) AS type, + properties(r) AS properties + """ + + result = await self.query(query, {"target_ids": target_ids}) + + nodes_dict = {} + edges = [] + + for record in result: + n_props = record["n_properties"] + m_props = record["m_properties"] + r_props = record["properties"] + r_type = record["type"] + + nodes_dict[n_props["id"]] = (n_props["id"], n_props) + nodes_dict[m_props["id"]] = (m_props["id"], m_props) + + source_id = r_props.get("source_node_id", n_props["id"]) + target_id = r_props.get("target_node_id", m_props["id"]) + edges.append((source_id, target_id, r_type, r_props)) + + retrieval_time = time.time() - start_time + logger.info( + f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s" + ) + + return list(nodes_dict.values()), edges + + except Exception as e: + logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}") + raise + async def get_nodeset_subgraph( self, node_type: Type[Any], node_name: List[str] ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]: diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index cb7562422..2e0b82e8d 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -56,6 +56,68 @@ class CogneeGraph(CogneeAbstractGraph): def get_edges(self) -> List[Edge]: return self.edges + async def _get_nodeset_subgraph( + self, + adapter, + node_type, + node_name, + ): + """Retrieve subgraph based on node type and name.""" + logger.info("Retrieving graph filtered by node type and node name (NodeSet).") + nodes_data, edges_data = await adapter.get_nodeset_subgraph( + node_type=node_type, node_name=node_name + ) + if not nodes_data or not edges_data: + raise EntityNotFoundError( + message="Nodeset does not exist, or empty nodeset projected from the database." + ) + return nodes_data, edges_data + + async def _get_full_or_id_filtered_graph( + self, + adapter, + relevant_ids_to_filter, + ): + """Retrieve full or ID-filtered graph with fallback.""" + if relevant_ids_to_filter is None: + logger.info("Retrieving full graph.") + nodes_data, edges_data = await adapter.get_graph_data() + if not nodes_data or not edges_data: + raise EntityNotFoundError(message="Empty graph projected from the database.") + return nodes_data, edges_data + + get_graph_data_fn = getattr(adapter, "get_id_filtered_graph_data", adapter.get_graph_data) + if getattr(adapter.__class__, "get_id_filtered_graph_data", None): + logger.info("Retrieving ID-filtered graph from database.") + nodes_data, edges_data = await get_graph_data_fn(target_ids=relevant_ids_to_filter) + else: + logger.info("Retrieving full graph from database.") + nodes_data, edges_data = await get_graph_data_fn() + if hasattr(adapter, "get_id_filtered_graph_data") and (not nodes_data or not edges_data): + logger.warning( + "Id filtered graph returned empty, falling back to full graph retrieval." + ) + logger.info("Retrieving full graph") + nodes_data, edges_data = await adapter.get_graph_data() + + if not nodes_data or not edges_data: + raise EntityNotFoundError("Empty graph projected from the database.") + return nodes_data, edges_data + + async def _get_filtered_graph( + self, + adapter, + memory_fragment_filter, + ): + """Retrieve graph filtered by attributes.""" + logger.info("Retrieving graph filtered by memory fragment") + nodes_data, edges_data = await adapter.get_filtered_graph_data( + attribute_filters=memory_fragment_filter + ) + if not nodes_data or not edges_data: + raise EntityNotFoundError(message="Empty filtered graph projected from the database.") + return nodes_data, edges_data + async def project_graph_from_db( self, adapter: Union[GraphDBInterface], @@ -67,40 +129,39 @@ class CogneeGraph(CogneeAbstractGraph): memory_fragment_filter=[], node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + relevant_ids_to_filter: Optional[List[str]] = None, + triplet_distance_penalty: float = 3.5, ) -> None: if node_dimension < 1 or edge_dimension < 1: raise InvalidDimensionsError() try: + if node_type is not None and node_name not in [None, [], ""]: + nodes_data, edges_data = await self._get_nodeset_subgraph( + adapter, node_type, node_name + ) + elif len(memory_fragment_filter) == 0: + nodes_data, edges_data = await self._get_full_or_id_filtered_graph( + adapter, relevant_ids_to_filter + ) + else: + nodes_data, edges_data = await self._get_filtered_graph( + adapter, memory_fragment_filter + ) + import time start_time = time.time() - - # Determine projection strategy - if node_type is not None and node_name not in [None, [], ""]: - nodes_data, edges_data = await adapter.get_nodeset_subgraph( - node_type=node_type, node_name=node_name - ) - if not nodes_data or not edges_data: - raise EntityNotFoundError( - message="Nodeset does not exist, or empty nodetes projected from the database." - ) - elif len(memory_fragment_filter) == 0: - nodes_data, edges_data = await adapter.get_graph_data() - if not nodes_data or not edges_data: - raise EntityNotFoundError(message="Empty graph projected from the database.") - else: - nodes_data, edges_data = await adapter.get_filtered_graph_data( - attribute_filters=memory_fragment_filter - ) - if not nodes_data or not edges_data: - raise EntityNotFoundError( - message="Empty filtered graph projected from the database." - ) - # Process nodes for node_id, properties in nodes_data: node_attributes = {key: properties.get(key) for key in node_properties_to_project} - self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension)) + self.add_node( + Node( + str(node_id), + node_attributes, + dimension=node_dimension, + node_penalty=triplet_distance_penalty, + ) + ) # Process edges for source_id, target_id, relationship_type, properties in edges_data: @@ -118,6 +179,7 @@ class CogneeGraph(CogneeAbstractGraph): attributes=edge_attributes, directed=directed, dimension=edge_dimension, + edge_penalty=triplet_distance_penalty, ) self.add_edge(edge) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py index 0ca9c4fb9..62ef8d9fd 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py @@ -20,13 +20,17 @@ class Node: status: np.ndarray def __init__( - self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1 + self, + node_id: str, + attributes: Optional[Dict[str, Any]] = None, + dimension: int = 1, + node_penalty: float = 3.5, ): if dimension <= 0: raise InvalidDimensionsError() self.id = node_id self.attributes = attributes if attributes is not None else {} - self.attributes["vector_distance"] = float("inf") + self.attributes["vector_distance"] = node_penalty self.skeleton_neighbours = [] self.skeleton_edges = [] self.status = np.ones(dimension, dtype=int) @@ -105,13 +109,14 @@ class Edge: attributes: Optional[Dict[str, Any]] = None, directed: bool = True, dimension: int = 1, + edge_penalty: float = 3.5, ): if dimension <= 0: raise InvalidDimensionsError() self.node1 = node1 self.node2 = node2 self.attributes = attributes if attributes is not None else {} - self.attributes["vector_distance"] = float("inf") + self.attributes["vector_distance"] = edge_penalty self.directed = directed self.status = np.ones(dimension, dtype=int) diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index b07d11fd2..fc49a139b 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -39,6 +39,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ): super().__init__( user_prompt_path=user_prompt_path, @@ -48,6 +50,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) async def get_completion( diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index eb8f502cb..70fcb6cdb 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -65,6 +65,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ): super().__init__( user_prompt_path=user_prompt_path, @@ -74,6 +76,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): node_type=node_type, node_name=node_name, save_interaction=save_interaction, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) self.validation_system_prompt_path = validation_system_prompt_path self.validation_user_prompt_path = validation_user_prompt_path diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index df77a11ac..89e9e47ce 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -47,6 +47,8 @@ class GraphCompletionRetriever(BaseGraphRetriever): node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ): """Initialize retriever with prompt paths and search parameters.""" self.save_interaction = save_interaction @@ -54,8 +56,10 @@ class GraphCompletionRetriever(BaseGraphRetriever): self.system_prompt_path = system_prompt_path self.system_prompt = system_prompt self.top_k = top_k if top_k is not None else 5 + self.wide_search_top_k = wide_search_top_k self.node_type = node_type self.node_name = node_name + self.triplet_distance_penalty = triplet_distance_penalty async def resolve_edges_to_text(self, retrieved_edges: list) -> str: """ @@ -105,6 +109,8 @@ class GraphCompletionRetriever(BaseGraphRetriever): collections=vector_index_collections or None, node_type=self.node_type, node_name=self.node_name, + wide_search_top_k=self.wide_search_top_k, + triplet_distance_penalty=self.triplet_distance_penalty, ) return found_triplets @@ -141,6 +147,10 @@ class GraphCompletionRetriever(BaseGraphRetriever): return triplets + async def convert_retrieved_objects_to_context(self, triplets: List[Edge]): + context = await self.resolve_edges_to_text(triplets) + return context + async def get_completion( self, query: str, diff --git a/cognee/modules/retrieval/graph_summary_completion_retriever.py b/cognee/modules/retrieval/graph_summary_completion_retriever.py index 051f39b22..e31ad126e 100644 --- a/cognee/modules/retrieval/graph_summary_completion_retriever.py +++ b/cognee/modules/retrieval/graph_summary_completion_retriever.py @@ -26,6 +26,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ): """Initialize retriever with default prompt paths and search parameters.""" super().__init__( @@ -36,6 +38,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) self.summarize_prompt_path = summarize_prompt_path diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py index f3da02c15..87d2ab009 100644 --- a/cognee/modules/retrieval/temporal_retriever.py +++ b/cognee/modules/retrieval/temporal_retriever.py @@ -47,6 +47,8 @@ class TemporalRetriever(GraphCompletionRetriever): top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ): super().__init__( user_prompt_path=user_prompt_path, @@ -54,6 +56,8 @@ class TemporalRetriever(GraphCompletionRetriever): top_k=top_k, node_type=node_type, node_name=node_name, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index f8bdbb97d..2f8a545f7 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -58,6 +58,8 @@ async def get_memory_fragment( properties_to_project: Optional[List[str]] = None, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + relevant_ids_to_filter: Optional[List[str]] = None, + triplet_distance_penalty: Optional[float] = 3.5, ) -> CogneeGraph: """Creates and initializes a CogneeGraph memory fragment with optional property projections.""" if properties_to_project is None: @@ -74,6 +76,8 @@ async def get_memory_fragment( edge_properties_to_project=["relationship_name", "edge_text"], node_type=node_type, node_name=node_name, + relevant_ids_to_filter=relevant_ids_to_filter, + triplet_distance_penalty=triplet_distance_penalty, ) except EntityNotFoundError: @@ -95,6 +99,8 @@ async def brute_force_triplet_search( memory_fragment: Optional[CogneeGraph] = None, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> List[Edge]: """ Performs a brute force search to retrieve the top triplets from the graph. @@ -107,6 +113,8 @@ async def brute_force_triplet_search( memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse. node_type: node type to filter node_name: node name to filter + wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections + triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection Returns: list: The top triplet results. @@ -116,10 +124,10 @@ async def brute_force_triplet_search( if top_k <= 0: raise ValueError("top_k must be a positive integer.") - if memory_fragment is None: - memory_fragment = await get_memory_fragment( - properties_to_project, node_type=node_type, node_name=node_name - ) + # Setting wide search limit based on the parameters + non_global_search = node_name is None + + wide_search_limit = wide_search_top_k if non_global_search else None if collections is None: collections = [ @@ -140,7 +148,7 @@ async def brute_force_triplet_search( async def search_in_collection(collection_name: str): try: return await vector_engine.search( - collection_name=collection_name, query_vector=query_vector, limit=None + collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit ) except CollectionNotFoundError: return [] @@ -156,15 +164,38 @@ async def brute_force_triplet_search( return [] # Final statistics - projection_time = time.time() - start_time + vector_collection_search_time = time.time() - start_time logger.info( - f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s" + f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s" ) node_distances = {collection: result for collection, result in zip(collections, results)} edge_distances = node_distances.get("EdgeType_relationship_name", None) + if wide_search_limit is not None: + relevant_ids_to_filter = list( + { + str(getattr(scored_node, "id")) + for collection_name, score_collection in node_distances.items() + if collection_name != "EdgeType_relationship_name" + and isinstance(score_collection, (list, tuple)) + for scored_node in score_collection + if getattr(scored_node, "id", None) + } + ) + else: + relevant_ids_to_filter = None + + if memory_fragment is None: + memory_fragment = await get_memory_fragment( + properties_to_project=properties_to_project, + node_type=node_type, + node_name=node_name, + relevant_ids_to_filter=relevant_ids_to_filter, + triplet_distance_penalty=triplet_distance_penalty, + ) + await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances) await memory_fragment.map_vector_distances_to_graph_edges( vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances diff --git a/cognee/modules/search/methods/get_search_type_tools.py b/cognee/modules/search/methods/get_search_type_tools.py index 72e2db89a..165ec379b 100644 --- a/cognee/modules/search/methods/get_search_type_tools.py +++ b/cognee/modules/search/methods/get_search_type_tools.py @@ -37,6 +37,8 @@ async def get_search_type_tools( node_name: Optional[List[str]] = None, save_interaction: bool = False, last_k: Optional[int] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> list: search_tasks: dict[SearchType, List[Callable]] = { SearchType.SUMMARIES: [ @@ -67,6 +69,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_completion, GraphCompletionRetriever( system_prompt_path=system_prompt_path, @@ -75,6 +79,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_context, ], SearchType.GRAPH_COMPLETION_COT: [ @@ -85,6 +91,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_completion, GraphCompletionCotRetriever( system_prompt_path=system_prompt_path, @@ -93,6 +101,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_context, ], SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: [ @@ -103,6 +113,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_completion, GraphCompletionContextExtensionRetriever( system_prompt_path=system_prompt_path, @@ -111,6 +123,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_context, ], SearchType.GRAPH_SUMMARY_COMPLETION: [ @@ -121,6 +135,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_completion, GraphSummaryCompletionRetriever( system_prompt_path=system_prompt_path, @@ -129,6 +145,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_context, ], SearchType.CODE: [ @@ -145,8 +163,16 @@ async def get_search_type_tools( ], SearchType.FEEDBACK: [UserQAFeedback(last_k=last_k).add_feedback], SearchType.TEMPORAL: [ - TemporalRetriever(top_k=top_k).get_completion, - TemporalRetriever(top_k=top_k).get_context, + TemporalRetriever( + top_k=top_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, + ).get_completion, + TemporalRetriever( + top_k=top_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, + ).get_context, ], SearchType.CHUNKS_LEXICAL: ( lambda _r=JaccardChunksRetriever(top_k=top_k): [ diff --git a/cognee/modules/search/methods/no_access_control_search.py b/cognee/modules/search/methods/no_access_control_search.py index fcb02da46..3a703bbc9 100644 --- a/cognee/modules/search/methods/no_access_control_search.py +++ b/cognee/modules/search/methods/no_access_control_search.py @@ -24,6 +24,8 @@ async def no_access_control_search( last_k: Optional[int] = None, only_context: bool = False, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]: search_tools = await get_search_type_tools( query_type=query_type, @@ -35,6 +37,8 @@ async def no_access_control_search( node_name=node_name, save_interaction=save_interaction, last_k=last_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) graph_engine = await get_graph_engine() is_empty = await graph_engine.is_empty() diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index b4278424b..9f180d607 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -47,6 +47,8 @@ async def search( only_context: bool = False, use_combined_context: bool = False, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Union[CombinedSearchResult, List[SearchResult]]: """ @@ -90,6 +92,8 @@ async def search( only_context=only_context, use_combined_context=use_combined_context, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) else: search_results = [ @@ -105,6 +109,8 @@ async def search( last_k=last_k, only_context=only_context, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) ] @@ -219,6 +225,8 @@ async def authorized_search( only_context: bool = False, use_combined_context: bool = False, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Union[ Tuple[Any, Union[List[Edge], str], List[Dataset]], List[Tuple[Any, Union[List[Edge], str], List[Dataset]]], @@ -246,6 +254,8 @@ async def authorized_search( last_k=last_k, only_context=True, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) context = {} @@ -267,6 +277,8 @@ async def authorized_search( node_name=node_name, save_interaction=save_interaction, last_k=last_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) search_tools = specific_search_tools if len(search_tools) == 2: @@ -306,6 +318,7 @@ async def authorized_search( last_k=last_k, only_context=only_context, session_id=session_id, + wide_search_top_k=wide_search_top_k, ) return search_results @@ -325,6 +338,8 @@ async def search_in_datasets_context( only_context: bool = False, context: Optional[Any] = None, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]: """ Searches all provided datasets and handles setting up of appropriate database context based on permissions. @@ -345,6 +360,8 @@ async def search_in_datasets_context( only_context: bool = False, context: Optional[Any] = None, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]: # Set database configuration in async context for each dataset user has access for await set_database_global_context_variables(dataset.id, dataset.owner_id) @@ -378,6 +395,8 @@ async def search_in_datasets_context( node_name=node_name, save_interaction=save_interaction, last_k=last_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) search_tools = specific_search_tools if len(search_tools) == 2: @@ -413,6 +432,8 @@ async def search_in_datasets_context( only_context=only_context, context=context, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) ) diff --git a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py index 37ba113b5..1d2b79cf9 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py @@ -9,7 +9,7 @@ def test_node_initialization(): """Test that a Node is initialized correctly.""" node = Node("node1", {"attr1": "value1"}, dimension=2) assert node.id == "node1" - assert node.attributes == {"attr1": "value1", "vector_distance": np.inf} + assert node.attributes == {"attr1": "value1", "vector_distance": 3.5} assert len(node.status) == 2 assert np.all(node.status == 1) @@ -96,7 +96,7 @@ def test_edge_initialization(): edge = Edge(node1, node2, {"weight": 10}, directed=False, dimension=2) assert edge.node1 == node1 assert edge.node2 == node2 - assert edge.attributes == {"vector_distance": np.inf, "weight": 10} + assert edge.attributes == {"vector_distance": 3.5, "weight": 10} assert edge.directed is False assert len(edge.status) == 2 assert np.all(edge.status == 1) diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index 6888648c3..711479387 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -1,4 +1,5 @@ import pytest +from unittest.mock import AsyncMock from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph @@ -11,6 +12,30 @@ def setup_graph(): return CogneeGraph() +@pytest.fixture +def mock_adapter(): + """Fixture to create a mock adapter for database operations.""" + adapter = AsyncMock() + return adapter + + +@pytest.fixture +def mock_vector_engine(): + """Fixture to create a mock vector engine.""" + engine = AsyncMock() + engine.search = AsyncMock() + return engine + + +class MockScoredResult: + """Mock class for vector search results.""" + + def __init__(self, id, score, payload=None): + self.id = id + self.score = score + self.payload = payload or {} + + def test_add_node_success(setup_graph): """Test successful addition of a node.""" graph = setup_graph @@ -73,3 +98,433 @@ def test_get_edges_nonexistent_node(setup_graph): graph = setup_graph with pytest.raises(EntityNotFoundError, match="Node with id nonexistent does not exist."): graph.get_edges_from_node("nonexistent") + + +@pytest.mark.asyncio +async def test_project_graph_from_db_full_graph(setup_graph, mock_adapter): + """Test projecting a full graph from database.""" + graph = setup_graph + + nodes_data = [ + ("1", {"name": "Node1", "description": "First node"}), + ("2", {"name": "Node2", "description": "Second node"}), + ] + edges_data = [ + ("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}), + ] + + mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data)) + + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name", "description"], + edge_properties_to_project=["relationship_name"], + ) + + assert len(graph.nodes) == 2 + assert len(graph.edges) == 1 + assert graph.get_node("1") is not None + assert graph.get_node("2") is not None + assert graph.edges[0].node1.id == "1" + assert graph.edges[0].node2.id == "2" + + +@pytest.mark.asyncio +async def test_project_graph_from_db_id_filtered(setup_graph, mock_adapter): + """Test projecting an ID-filtered graph from database.""" + graph = setup_graph + + nodes_data = [ + ("1", {"name": "Node1"}), + ("2", {"name": "Node2"}), + ] + edges_data = [ + ("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}), + ] + + mock_adapter.get_id_filtered_graph_data = AsyncMock(return_value=(nodes_data, edges_data)) + + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name"], + edge_properties_to_project=["relationship_name"], + relevant_ids_to_filter=["1", "2"], + ) + + assert len(graph.nodes) == 2 + assert len(graph.edges) == 1 + mock_adapter.get_id_filtered_graph_data.assert_called_once() + + +@pytest.mark.asyncio +async def test_project_graph_from_db_nodeset_subgraph(setup_graph, mock_adapter): + """Test projecting a nodeset subgraph filtered by node type and name.""" + graph = setup_graph + + nodes_data = [ + ("1", {"name": "Alice", "type": "Person"}), + ("2", {"name": "Bob", "type": "Person"}), + ] + edges_data = [ + ("1", "2", "KNOWS", {"relationship_name": "knows"}), + ] + + mock_adapter.get_nodeset_subgraph = AsyncMock(return_value=(nodes_data, edges_data)) + + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name", "type"], + edge_properties_to_project=["relationship_name"], + node_type="Person", + node_name=["Alice"], + ) + + assert len(graph.nodes) == 2 + assert graph.get_node("1") is not None + assert len(graph.edges) == 1 + mock_adapter.get_nodeset_subgraph.assert_called_once() + + +@pytest.mark.asyncio +async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter): + """Test projecting empty graph raises EntityNotFoundError.""" + graph = setup_graph + + mock_adapter.get_graph_data = AsyncMock(return_value=([], [])) + + with pytest.raises(EntityNotFoundError, match="Empty graph projected from the database."): + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name"], + edge_properties_to_project=[], + ) + + +@pytest.mark.asyncio +async def test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter): + """Test that edges referencing missing nodes raise error.""" + graph = setup_graph + + nodes_data = [ + ("1", {"name": "Node1"}), + ] + edges_data = [ + ("1", "999", "CONNECTS_TO", {"relationship_name": "connects"}), + ] + + mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data)) + + with pytest.raises(EntityNotFoundError, match="Edge references nonexistent nodes"): + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name"], + edge_properties_to_project=["relationship_name"], + ) + + +@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_nodes(setup_graph): + """Test mapping vector distances to graph nodes.""" + graph = setup_graph + + node1 = Node("1", {"name": "Node1"}) + node2 = Node("2", {"name": "Node2"}) + graph.add_node(node1) + graph.add_node(node2) + + node_distances = { + "Entity_name": [ + MockScoredResult("1", 0.95), + MockScoredResult("2", 0.87), + ] + } + + await graph.map_vector_distances_to_graph_nodes(node_distances) + + assert graph.get_node("1").attributes.get("vector_distance") == 0.95 + assert graph.get_node("2").attributes.get("vector_distance") == 0.87 + + +@pytest.mark.asyncio +async def test_map_vector_distances_partial_node_coverage(setup_graph): + """Test mapping vector distances when only some nodes have results.""" + graph = setup_graph + + node1 = Node("1", {"name": "Node1"}) + node2 = Node("2", {"name": "Node2"}) + node3 = Node("3", {"name": "Node3"}) + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + node_distances = { + "Entity_name": [ + MockScoredResult("1", 0.95), + MockScoredResult("2", 0.87), + ] + } + + await graph.map_vector_distances_to_graph_nodes(node_distances) + + assert graph.get_node("1").attributes.get("vector_distance") == 0.95 + assert graph.get_node("2").attributes.get("vector_distance") == 0.87 + assert graph.get_node("3").attributes.get("vector_distance") == 3.5 + + +@pytest.mark.asyncio +async def test_map_vector_distances_multiple_categories(setup_graph): + """Test mapping vector distances from multiple collection categories.""" + graph = setup_graph + + # Create nodes + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + node4 = Node("4") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + graph.add_node(node4) + + node_distances = { + "Entity_name": [ + MockScoredResult("1", 0.95), + MockScoredResult("2", 0.87), + ], + "TextSummary_text": [ + MockScoredResult("3", 0.92), + ], + } + + await graph.map_vector_distances_to_graph_nodes(node_distances) + + assert graph.get_node("1").attributes.get("vector_distance") == 0.95 + assert graph.get_node("2").attributes.get("vector_distance") == 0.87 + assert graph.get_node("3").attributes.get("vector_distance") == 0.92 + assert graph.get_node("4").attributes.get("vector_distance") == 3.5 + + +@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, mock_vector_engine): + """Test mapping vector distances to edges when edge_distances provided.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge( + node1, + node2, + attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"}, + ) + graph.add_edge(edge) + + edge_distances = [ + MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}), + ] + + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[0.1, 0.2, 0.3], + edge_distances=edge_distances, + ) + + assert graph.edges[0].attributes.get("vector_distance") == 0.92 + + +@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_edges_search(setup_graph, mock_vector_engine): + """Test mapping edge distances when searching for them.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge( + node1, + node2, + attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"}, + ) + graph.add_edge(edge) + + mock_vector_engine.search.return_value = [ + MockScoredResult("e1", 0.88, payload={"text": "CONNECTS_TO"}), + ] + + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[0.1, 0.2, 0.3], + edge_distances=None, + ) + + mock_vector_engine.search.assert_called_once() + assert graph.edges[0].attributes.get("vector_distance") == 0.88 + + +@pytest.mark.asyncio +async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vector_engine): + """Test mapping edge distances when only some edges have results.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + edge1 = Edge(node1, node2, attributes={"edge_text": "CONNECTS_TO"}) + edge2 = Edge(node2, node3, attributes={"edge_text": "DEPENDS_ON"}) + graph.add_edge(edge1) + graph.add_edge(edge2) + + edge_distances = [ + MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}), + ] + + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[0.1, 0.2, 0.3], + edge_distances=edge_distances, + ) + + assert graph.edges[0].attributes.get("vector_distance") == 0.92 + assert graph.edges[1].attributes.get("vector_distance") == 3.5 + + +@pytest.mark.asyncio +async def test_map_vector_distances_edges_fallback_to_relationship_type( + setup_graph, mock_vector_engine +): + """Test that edge mapping falls back to relationship_type when edge_text is missing.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge( + node1, + node2, + attributes={"relationship_type": "KNOWS"}, + ) + graph.add_edge(edge) + + edge_distances = [ + MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}), + ] + + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[0.1, 0.2, 0.3], + edge_distances=edge_distances, + ) + + assert graph.edges[0].attributes.get("vector_distance") == 0.85 + + +@pytest.mark.asyncio +async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_engine): + """Test edge mapping when no edges match the distance results.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge( + node1, + node2, + attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"}, + ) + graph.add_edge(edge) + + edge_distances = [ + MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}), + ] + + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[0.1, 0.2, 0.3], + edge_distances=edge_distances, + ) + + assert graph.edges[0].attributes.get("vector_distance") == 3.5 + + +@pytest.mark.asyncio +async def test_map_vector_distances_invalid_query_vector(setup_graph, mock_vector_engine): + """Test that invalid query vector raises error.""" + graph = setup_graph + + with pytest.raises(ValueError, match="Failed to generate query embedding"): + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[], + edge_distances=None, + ) + + +@pytest.mark.asyncio +async def test_calculate_top_triplet_importances(setup_graph): + """Test calculating top triplet importances by score.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + node4 = Node("4") + + node1.add_attribute("vector_distance", 0.9) + node2.add_attribute("vector_distance", 0.8) + node3.add_attribute("vector_distance", 0.7) + node4.add_attribute("vector_distance", 0.6) + + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + graph.add_node(node4) + + edge1 = Edge(node1, node2) + edge2 = Edge(node2, node3) + edge3 = Edge(node3, node4) + + edge1.add_attribute("vector_distance", 0.85) + edge2.add_attribute("vector_distance", 0.75) + edge3.add_attribute("vector_distance", 0.65) + + graph.add_edge(edge1) + graph.add_edge(edge2) + graph.add_edge(edge3) + + top_triplets = await graph.calculate_top_triplet_importances(k=2) + + assert len(top_triplets) == 2 + + assert top_triplets[0] == edge3 + assert top_triplets[1] == edge2 + + +@pytest.mark.asyncio +async def test_calculate_top_triplet_importances_default_distances(setup_graph): + """Test calculating importances when nodes/edges have no vector distances.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge(node1, node2) + graph.add_edge(edge) + + top_triplets = await graph.calculate_top_triplet_importances(k=1) + + assert len(top_triplets) == 1 + assert top_triplets[0] == edge diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py new file mode 100644 index 000000000..5eb6fb105 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py @@ -0,0 +1,582 @@ +import pytest +from unittest.mock import AsyncMock, patch + +from cognee.modules.retrieval.utils.brute_force_triplet_search import ( + brute_force_triplet_search, + get_memory_fragment, +) +from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph +from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError + + +class MockScoredResult: + """Mock class for vector search results.""" + + def __init__(self, id, score, payload=None): + self.id = id + self.score = score + self.payload = payload or {} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_empty_query(): + """Test that empty query raises ValueError.""" + with pytest.raises(ValueError, match="The query must be a non-empty string."): + await brute_force_triplet_search(query="") + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_none_query(): + """Test that None query raises ValueError.""" + with pytest.raises(ValueError, match="The query must be a non-empty string."): + await brute_force_triplet_search(query=None) + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_negative_top_k(): + """Test that negative top_k raises ValueError.""" + with pytest.raises(ValueError, match="top_k must be a positive integer."): + await brute_force_triplet_search(query="test query", top_k=-1) + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_zero_top_k(): + """Test that zero top_k raises ValueError.""" + with pytest.raises(ValueError, match="top_k must be a positive integer."): + await brute_force_triplet_search(query="test query", top_k=0) + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_wide_search_limit_global_search(): + """Test that wide_search_limit is applied for global search (node_name=None).""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search( + query="test", + node_name=None, # Global search + wide_search_top_k=75, + ) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["limit"] == 75 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_wide_search_limit_filtered_search(): + """Test that wide_search_limit is None for filtered search (node_name provided).""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search( + query="test", + node_name=["Node1"], + wide_search_top_k=50, + ) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["limit"] is None + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_wide_search_default(): + """Test that wide_search_top_k defaults to 100.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test", node_name=None) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["limit"] == 100 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_default_collections(): + """Test that default collections are used when none provided.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test") + + expected_collections = [ + "Entity_name", + "TextSummary_text", + "EntityType_name", + "DocumentChunk_text", + ] + + call_collections = [ + call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list + ] + assert call_collections == expected_collections + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_custom_collections(): + """Test that custom collections are used when provided.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + custom_collections = ["CustomCol1", "CustomCol2"] + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test", collections=custom_collections) + + call_collections = [ + call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list + ] + assert call_collections == custom_collections + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_all_collections_empty(): + """Test that empty list is returned when all collections return no results.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + results = await brute_force_triplet_search(query="test") + assert results == [] + + +# Tests for query embedding + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_embeds_query(): + """Test that query is embedded before searching.""" + query_text = "test query" + expected_vector = [0.1, 0.2, 0.3] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query=query_text) + + mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text]) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["query_vector"] == expected_vector + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_extracts_node_ids_global_search(): + """Test that node IDs are extracted from search results for global search.""" + scored_results = [ + MockScoredResult("node1", 0.95), + MockScoredResult("node2", 0.87), + MockScoredResult("node3", 0.92), + ] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=scored_results) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_reuses_provided_fragment(): + """Test that provided memory fragment is reused instead of creating new one.""" + provided_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment" + ) as mock_get_fragment, + ): + await brute_force_triplet_search( + query="test", + memory_fragment=provided_fragment, + node_name=["node"], + ) + + mock_get_fragment.assert_not_called() + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_creates_fragment_when_not_provided(): + """Test that memory fragment is created when not provided.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment, + ): + await brute_force_triplet_search(query="test", node_name=["node"]) + + mock_get_fragment.assert_called_once() + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation(): + """Test that custom top_k is passed to importance calculation.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ), + ): + custom_top_k = 15 + await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"]) + + mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k) + + +@pytest.mark.asyncio +async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found(): + """Test that get_memory_fragment returns empty graph when entity not found.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.project_graph_from_db = AsyncMock( + side_effect=EntityNotFoundError("Entity not found") + ) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", + return_value=mock_graph_engine, + ): + fragment = await get_memory_fragment() + + assert isinstance(fragment, CogneeGraph) + assert len(fragment.nodes) == 0 + + +@pytest.mark.asyncio +async def test_get_memory_fragment_returns_empty_graph_on_error(): + """Test that get_memory_fragment returns empty graph on generic error.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error")) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", + return_value=mock_graph_engine, + ): + fragment = await get_memory_fragment() + + assert isinstance(fragment, CogneeGraph) + assert len(fragment.nodes) == 0 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_deduplicates_node_ids(): + """Test that duplicate node IDs across collections are deduplicated.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [ + MockScoredResult("node1", 0.95), + MockScoredResult("node2", 0.87), + ] + elif collection_name == "TextSummary_text": + return [ + MockScoredResult("node1", 0.90), + MockScoredResult("node3", 0.92), + ] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"} + assert len(call_kwargs["relevant_ids_to_filter"]) == 3 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_excludes_edge_collection(): + """Test that EdgeType_relationship_name collection is excluded from ID extraction.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [MockScoredResult("node1", 0.95)] + elif collection_name == "EdgeType_relationship_name": + return [MockScoredResult("edge1", 0.88)] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search( + query="test", + node_name=None, + collections=["Entity_name", "EdgeType_relationship_name"], + ) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert call_kwargs["relevant_ids_to_filter"] == ["node1"] + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_skips_nodes_without_ids(): + """Test that nodes without ID attribute are skipped.""" + + class ScoredResultNoId: + """Mock result without id attribute.""" + + def __init__(self, score): + self.score = score + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [ + MockScoredResult("node1", 0.95), + ScoredResultNoId(0.90), + MockScoredResult("node2", 0.87), + ] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_handles_tuple_results(): + """Test that both list and tuple results are handled correctly.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return ( + MockScoredResult("node1", 0.95), + MockScoredResult("node2", 0.87), + ) + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_mixed_empty_collections(): + """Test ID extraction with mixed empty and non-empty collections.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [MockScoredResult("node1", 0.95)] + elif collection_name == "TextSummary_text": + return [] + elif collection_name == "EntityType_name": + return [MockScoredResult("node2", 0.92)] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}