diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py
index d4e5fbbe6..354331c57 100644
--- a/cognee/api/v1/search/search.py
+++ b/cognee/api/v1/search/search.py
@@ -31,6 +31,8 @@ async def search(
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[List[SearchResult], CombinedSearchResult]:
"""
Search and query the knowledge graph for insights, information, and connections.
@@ -200,6 +202,8 @@ async def search(
only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
return filtered_search_results
diff --git a/cognee/eval_framework/Dockerfile b/cognee/eval_framework/Dockerfile
new file mode 100644
index 000000000..e83be3da4
--- /dev/null
+++ b/cognee/eval_framework/Dockerfile
@@ -0,0 +1,29 @@
+FROM python:3.11-slim
+
+# Set environment variables
+ENV PIP_NO_CACHE_DIR=true
+ENV PATH="${PATH}:/root/.poetry/bin"
+ENV PYTHONPATH=/app
+ENV SKIP_MIGRATIONS=true
+
+# System dependencies
+RUN apt-get update && apt-get install -y \
+ gcc \
+ libpq-dev \
+ git \
+ curl \
+ build-essential \
+ && rm -rf /var/lib/apt/lists/*
+
+WORKDIR /app
+
+COPY pyproject.toml poetry.lock README.md /app/
+
+RUN pip install poetry
+
+RUN poetry config virtualenvs.create false
+
+RUN poetry install --extras distributed --extras evals --extras deepeval --no-root
+
+COPY cognee/ /app/cognee
+COPY distributed/ /app/distributed
diff --git a/cognee/eval_framework/answer_generation/answer_generation_executor.py b/cognee/eval_framework/answer_generation/answer_generation_executor.py
index 6f166657e..29b3ede68 100644
--- a/cognee/eval_framework/answer_generation/answer_generation_executor.py
+++ b/cognee/eval_framework/answer_generation/answer_generation_executor.py
@@ -35,6 +35,16 @@ class AnswerGeneratorExecutor:
retrieval_context = await retriever.get_context(query_text)
search_results = await retriever.get_completion(query_text, retrieval_context)
+ ############
+            # TODO: Quick fix until retriever results are structured properly — do not leave it like this; needed now because the combined retriever changed its result structure.
+ if isinstance(retrieval_context, list):
+ retrieval_context = await retriever.convert_retrieved_objects_to_context(
+ triplets=retrieval_context
+ )
+
+ if isinstance(search_results, str):
+ search_results = [search_results]
+ #############
answer = {
"question": query_text,
"answer": search_results[0],
diff --git a/cognee/eval_framework/answer_generation/run_question_answering_module.py b/cognee/eval_framework/answer_generation/run_question_answering_module.py
index d0a2ebe1e..6b55d84b2 100644
--- a/cognee/eval_framework/answer_generation/run_question_answering_module.py
+++ b/cognee/eval_framework/answer_generation/run_question_answering_module.py
@@ -35,7 +35,7 @@ async def create_and_insert_answers_table(questions_payload):
async def run_question_answering(
- params: dict, system_prompt="answer_simple_question.txt", top_k: Optional[int] = None
+ params: dict, system_prompt="answer_simple_question_benchmark.txt", top_k: Optional[int] = None
) -> List[dict]:
if params.get("answering_questions"):
logger.info("Question answering started...")
diff --git a/cognee/eval_framework/eval_config.py b/cognee/eval_framework/eval_config.py
index 6edcc0454..9e6f26688 100644
--- a/cognee/eval_framework/eval_config.py
+++ b/cognee/eval_framework/eval_config.py
@@ -14,7 +14,7 @@ class EvalConfig(BaseSettings):
# Question answering params
answering_questions: bool = True
- qa_engine: str = "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
+ qa_engine: str = "cognee_graph_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
# Evaluation params
evaluating_answers: bool = True
@@ -25,7 +25,7 @@ class EvalConfig(BaseSettings):
"EM",
"f1",
] # Use only 'correctness' for DirectLLM
- deepeval_model: str = "gpt-5-mini"
+ deepeval_model: str = "gpt-4o-mini"
# Metrics params
calculate_metrics: bool = True
diff --git a/cognee/eval_framework/modal_run_eval.py b/cognee/eval_framework/modal_run_eval.py
index aca2686a5..bc2ff77c5 100644
--- a/cognee/eval_framework/modal_run_eval.py
+++ b/cognee/eval_framework/modal_run_eval.py
@@ -2,7 +2,6 @@ import modal
import os
import asyncio
import datetime
-import hashlib
import json
from cognee.shared.logging_utils import get_logger
from cognee.eval_framework.eval_config import EvalConfig
@@ -10,6 +9,9 @@ from cognee.eval_framework.corpus_builder.run_corpus_builder import run_corpus_b
from cognee.eval_framework.answer_generation.run_question_answering_module import (
run_question_answering,
)
+import pathlib
+from os import path
+from modal import Image
from cognee.eval_framework.evaluation.run_evaluation_module import run_evaluation
from cognee.eval_framework.metrics_dashboard import create_dashboard
@@ -38,22 +40,19 @@ def read_and_combine_metrics(eval_params: dict) -> dict:
app = modal.App("modal-run-eval")
-image = (
- modal.Image.from_dockerfile(path="Dockerfile_modal", force_build=False)
- .copy_local_file("pyproject.toml", "pyproject.toml")
- .copy_local_file("poetry.lock", "poetry.lock")
- .env(
- {
- "ENV": os.getenv("ENV"),
- "LLM_API_KEY": os.getenv("LLM_API_KEY"),
- "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
- }
- )
- .pip_install("protobuf", "h2", "deepeval", "gdown", "plotly")
+image = Image.from_dockerfile(
+ path=pathlib.Path(path.join(path.dirname(__file__), "Dockerfile")).resolve(),
+ force_build=False,
+).add_local_python_source("cognee")
+
+
+@app.function(
+ image=image,
+ max_containers=10,
+ timeout=86400,
+ volumes={"/data": vol},
+ secrets=[modal.Secret.from_name("eval_secrets")],
)
-
-
-@app.function(image=image, concurrency_limit=10, timeout=86400, volumes={"/data": vol})
async def modal_run_eval(eval_params=None):
"""Runs evaluation pipeline and returns combined metrics results."""
if eval_params is None:
@@ -105,18 +104,7 @@ async def main():
configs = [
EvalConfig(
task_getter_type="Default",
- number_of_samples_in_corpus=10,
- benchmark="HotPotQA",
- qa_engine="cognee_graph_completion",
- building_corpus_from_scratch=True,
- answering_questions=True,
- evaluating_answers=True,
- calculate_metrics=True,
- dashboard=True,
- ),
- EvalConfig(
- task_getter_type="Default",
- number_of_samples_in_corpus=10,
+ number_of_samples_in_corpus=25,
benchmark="TwoWikiMultiHop",
qa_engine="cognee_graph_completion",
building_corpus_from_scratch=True,
@@ -127,7 +115,7 @@ async def main():
),
EvalConfig(
task_getter_type="Default",
- number_of_samples_in_corpus=10,
+ number_of_samples_in_corpus=25,
benchmark="Musique",
qa_engine="cognee_graph_completion",
building_corpus_from_scratch=True,
diff --git a/cognee/infrastructure/databases/graph/graph_db_interface.py b/cognee/infrastructure/databases/graph/graph_db_interface.py
index 67df1a27c..8f8c96e79 100644
--- a/cognee/infrastructure/databases/graph/graph_db_interface.py
+++ b/cognee/infrastructure/databases/graph/graph_db_interface.py
@@ -398,3 +398,18 @@ class GraphDBInterface(ABC):
- node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections.
"""
raise NotImplementedError
+
+ @abstractmethod
+ async def get_filtered_graph_data(
+ self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
+ ) -> Tuple[List[Node], List[EdgeData]]:
+ """
+ Retrieve nodes and edges filtered by the provided attribute criteria.
+
+ Parameters:
+ -----------
+
+ - attribute_filters: A list of dictionaries where keys are attribute names and values
+ are lists of attribute values to filter by.
+ """
+ raise NotImplementedError
diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py
index 8dd160665..9dbc9c1bc 100644
--- a/cognee/infrastructure/databases/graph/kuzu/adapter.py
+++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py
@@ -12,6 +12,7 @@ from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, List, Union, Optional, Tuple, Type
+from cognee.exceptions import CogneeValidationError
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.run_sync import run_sync
from cognee.infrastructure.files.storage import get_file_storage
@@ -1186,6 +1187,11 @@ class KuzuAdapter(GraphDBInterface):
A tuple with two elements: a list of tuples of (node_id, properties) and a list of
tuples of (source_id, target_id, relationship_name, properties).
"""
+
+ import time
+
+ start_time = time.time()
+
try:
nodes_query = """
MATCH (n:Node)
@@ -1249,6 +1255,11 @@ class KuzuAdapter(GraphDBInterface):
},
)
)
+
+ retrieval_time = time.time() - start_time
+ logger.info(
+ f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds"
+ )
return formatted_nodes, formatted_edges
except Exception as e:
logger.error(f"Failed to get graph data: {e}")
@@ -1417,6 +1428,92 @@ class KuzuAdapter(GraphDBInterface):
formatted_edges.append((source_id, target_id, rel_type, props))
return formatted_nodes, formatted_edges
+ async def get_id_filtered_graph_data(self, target_ids: list[str]):
+ """
+ Retrieve graph data filtered by specific node IDs, including their direct neighbors
+ and only edges where one endpoint matches those IDs.
+
+ Returns:
+ nodes: List[dict] -> Each dict includes "id" and all node properties
+ edges: List[dict] -> Each dict includes "source", "target", "type", "properties"
+ """
+ import time
+
+ start_time = time.time()
+
+ try:
+ if not target_ids:
+ logger.warning("No target IDs provided for ID-filtered graph retrieval.")
+ return [], []
+
+ if not all(isinstance(x, str) for x in target_ids):
+ raise CogneeValidationError("target_ids must be a list of strings")
+
+ query = """
+ MATCH (n:Node)-[r]->(m:Node)
+ WHERE n.id IN $target_ids OR m.id IN $target_ids
+ RETURN n.id, {
+ name: n.name,
+ type: n.type,
+ properties: n.properties
+ }, m.id, {
+ name: m.name,
+ type: m.type,
+ properties: m.properties
+ }, r.relationship_name, r.properties
+ """
+
+ result = await self.query(query, {"target_ids": target_ids})
+
+ if not result:
+ logger.info("No data returned for the supplied IDs")
+ return [], []
+
+ nodes_dict = {}
+ edges = []
+
+ for n_id, n_props, m_id, m_props, r_type, r_props_raw in result:
+ if n_props.get("properties"):
+ try:
+ additional_props = json.loads(n_props["properties"])
+ n_props.update(additional_props)
+ del n_props["properties"]
+ except json.JSONDecodeError:
+ logger.warning(f"Failed to parse properties JSON for node {n_id}")
+
+ if m_props.get("properties"):
+ try:
+ additional_props = json.loads(m_props["properties"])
+ m_props.update(additional_props)
+ del m_props["properties"]
+ except json.JSONDecodeError:
+ logger.warning(f"Failed to parse properties JSON for node {m_id}")
+
+ nodes_dict[n_id] = (n_id, n_props)
+ nodes_dict[m_id] = (m_id, m_props)
+
+ edge_props = {}
+ if r_props_raw:
+ try:
+ edge_props = json.loads(r_props_raw)
+ except (json.JSONDecodeError, TypeError):
+ logger.warning(f"Failed to parse edge properties for {n_id}->{m_id}")
+
+ source_id = edge_props.get("source_node_id", n_id)
+ target_id = edge_props.get("target_node_id", m_id)
+ edges.append((source_id, target_id, r_type, edge_props))
+
+ retrieval_time = time.time() - start_time
+ logger.info(
+ f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
+ )
+
+ return list(nodes_dict.values()), edges
+
+ except Exception as e:
+ logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
+ raise
+
async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
"""
Get metrics on graph structure and connectivity.
diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
index 6216e107e..f3bb8e173 100644
--- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
+++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
@@ -964,6 +964,63 @@ class Neo4jAdapter(GraphDBInterface):
logger.error(f"Error during graph data retrieval: {str(e)}")
raise
+ async def get_id_filtered_graph_data(self, target_ids: list[str]):
+ """
+ Retrieve graph data filtered by specific node IDs, including their direct neighbors
+ and only edges where one endpoint matches those IDs.
+
+ This version uses a single Cypher query for efficiency.
+ """
+ import time
+
+ start_time = time.time()
+
+ try:
+ if not target_ids:
+ logger.warning("No target IDs provided for ID-filtered graph retrieval.")
+ return [], []
+
+ query = """
+ MATCH ()-[r]-()
+ WHERE startNode(r).id IN $target_ids
+ OR endNode(r).id IN $target_ids
+ WITH DISTINCT r, startNode(r) AS a, endNode(r) AS b
+ RETURN
+ properties(a) AS n_properties,
+ properties(b) AS m_properties,
+ type(r) AS type,
+ properties(r) AS properties
+ """
+
+ result = await self.query(query, {"target_ids": target_ids})
+
+ nodes_dict = {}
+ edges = []
+
+ for record in result:
+ n_props = record["n_properties"]
+ m_props = record["m_properties"]
+ r_props = record["properties"]
+ r_type = record["type"]
+
+ nodes_dict[n_props["id"]] = (n_props["id"], n_props)
+ nodes_dict[m_props["id"]] = (m_props["id"], m_props)
+
+ source_id = r_props.get("source_node_id", n_props["id"])
+ target_id = r_props.get("target_node_id", m_props["id"])
+ edges.append((source_id, target_id, r_type, r_props))
+
+ retrieval_time = time.time() - start_time
+ logger.info(
+ f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
+ )
+
+ return list(nodes_dict.values()), edges
+
+ except Exception as e:
+ logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
+ raise
+
async def get_nodeset_subgraph(
self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py
index cb7562422..2e0b82e8d 100644
--- a/cognee/modules/graph/cognee_graph/CogneeGraph.py
+++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py
@@ -56,6 +56,68 @@ class CogneeGraph(CogneeAbstractGraph):
def get_edges(self) -> List[Edge]:
return self.edges
+ async def _get_nodeset_subgraph(
+ self,
+ adapter,
+ node_type,
+ node_name,
+ ):
+ """Retrieve subgraph based on node type and name."""
+ logger.info("Retrieving graph filtered by node type and node name (NodeSet).")
+ nodes_data, edges_data = await adapter.get_nodeset_subgraph(
+ node_type=node_type, node_name=node_name
+ )
+ if not nodes_data or not edges_data:
+ raise EntityNotFoundError(
+ message="Nodeset does not exist, or empty nodeset projected from the database."
+ )
+ return nodes_data, edges_data
+
+ async def _get_full_or_id_filtered_graph(
+ self,
+ adapter,
+ relevant_ids_to_filter,
+ ):
+ """Retrieve full or ID-filtered graph with fallback."""
+ if relevant_ids_to_filter is None:
+ logger.info("Retrieving full graph.")
+ nodes_data, edges_data = await adapter.get_graph_data()
+ if not nodes_data or not edges_data:
+ raise EntityNotFoundError(message="Empty graph projected from the database.")
+ return nodes_data, edges_data
+
+ get_graph_data_fn = getattr(adapter, "get_id_filtered_graph_data", adapter.get_graph_data)
+ if getattr(adapter.__class__, "get_id_filtered_graph_data", None):
+ logger.info("Retrieving ID-filtered graph from database.")
+ nodes_data, edges_data = await get_graph_data_fn(target_ids=relevant_ids_to_filter)
+ else:
+ logger.info("Retrieving full graph from database.")
+ nodes_data, edges_data = await get_graph_data_fn()
+ if hasattr(adapter, "get_id_filtered_graph_data") and (not nodes_data or not edges_data):
+ logger.warning(
+ "Id filtered graph returned empty, falling back to full graph retrieval."
+ )
+ logger.info("Retrieving full graph")
+ nodes_data, edges_data = await adapter.get_graph_data()
+
+ if not nodes_data or not edges_data:
+ raise EntityNotFoundError("Empty graph projected from the database.")
+ return nodes_data, edges_data
+
+ async def _get_filtered_graph(
+ self,
+ adapter,
+ memory_fragment_filter,
+ ):
+ """Retrieve graph filtered by attributes."""
+ logger.info("Retrieving graph filtered by memory fragment")
+ nodes_data, edges_data = await adapter.get_filtered_graph_data(
+ attribute_filters=memory_fragment_filter
+ )
+ if not nodes_data or not edges_data:
+ raise EntityNotFoundError(message="Empty filtered graph projected from the database.")
+ return nodes_data, edges_data
+
async def project_graph_from_db(
self,
adapter: Union[GraphDBInterface],
@@ -67,40 +129,39 @@ class CogneeGraph(CogneeAbstractGraph):
memory_fragment_filter=[],
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
+ relevant_ids_to_filter: Optional[List[str]] = None,
+ triplet_distance_penalty: float = 3.5,
) -> None:
if node_dimension < 1 or edge_dimension < 1:
raise InvalidDimensionsError()
try:
+ if node_type is not None and node_name not in [None, [], ""]:
+ nodes_data, edges_data = await self._get_nodeset_subgraph(
+ adapter, node_type, node_name
+ )
+ elif len(memory_fragment_filter) == 0:
+ nodes_data, edges_data = await self._get_full_or_id_filtered_graph(
+ adapter, relevant_ids_to_filter
+ )
+ else:
+ nodes_data, edges_data = await self._get_filtered_graph(
+ adapter, memory_fragment_filter
+ )
+
import time
start_time = time.time()
-
- # Determine projection strategy
- if node_type is not None and node_name not in [None, [], ""]:
- nodes_data, edges_data = await adapter.get_nodeset_subgraph(
- node_type=node_type, node_name=node_name
- )
- if not nodes_data or not edges_data:
- raise EntityNotFoundError(
- message="Nodeset does not exist, or empty nodetes projected from the database."
- )
- elif len(memory_fragment_filter) == 0:
- nodes_data, edges_data = await adapter.get_graph_data()
- if not nodes_data or not edges_data:
- raise EntityNotFoundError(message="Empty graph projected from the database.")
- else:
- nodes_data, edges_data = await adapter.get_filtered_graph_data(
- attribute_filters=memory_fragment_filter
- )
- if not nodes_data or not edges_data:
- raise EntityNotFoundError(
- message="Empty filtered graph projected from the database."
- )
-
# Process nodes
for node_id, properties in nodes_data:
node_attributes = {key: properties.get(key) for key in node_properties_to_project}
- self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))
+ self.add_node(
+ Node(
+ str(node_id),
+ node_attributes,
+ dimension=node_dimension,
+ node_penalty=triplet_distance_penalty,
+ )
+ )
# Process edges
for source_id, target_id, relationship_type, properties in edges_data:
@@ -118,6 +179,7 @@ class CogneeGraph(CogneeAbstractGraph):
attributes=edge_attributes,
directed=directed,
dimension=edge_dimension,
+ edge_penalty=triplet_distance_penalty,
)
self.add_edge(edge)
diff --git a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py
index 0ca9c4fb9..62ef8d9fd 100644
--- a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py
+++ b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py
@@ -20,13 +20,17 @@ class Node:
status: np.ndarray
def __init__(
- self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1
+ self,
+ node_id: str,
+ attributes: Optional[Dict[str, Any]] = None,
+ dimension: int = 1,
+ node_penalty: float = 3.5,
):
if dimension <= 0:
raise InvalidDimensionsError()
self.id = node_id
self.attributes = attributes if attributes is not None else {}
- self.attributes["vector_distance"] = float("inf")
+ self.attributes["vector_distance"] = node_penalty
self.skeleton_neighbours = []
self.skeleton_edges = []
self.status = np.ones(dimension, dtype=int)
@@ -105,13 +109,14 @@ class Edge:
attributes: Optional[Dict[str, Any]] = None,
directed: bool = True,
dimension: int = 1,
+ edge_penalty: float = 3.5,
):
if dimension <= 0:
raise InvalidDimensionsError()
self.node1 = node1
self.node2 = node2
self.attributes = attributes if attributes is not None else {}
- self.attributes["vector_distance"] = float("inf")
+ self.attributes["vector_distance"] = edge_penalty
self.directed = directed
self.status = np.ones(dimension, dtype=int)
diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py
index b07d11fd2..fc49a139b 100644
--- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py
+++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py
@@ -39,6 +39,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
):
super().__init__(
user_prompt_path=user_prompt_path,
@@ -48,6 +50,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
async def get_completion(
diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py
index eb8f502cb..70fcb6cdb 100644
--- a/cognee/modules/retrieval/graph_completion_cot_retriever.py
+++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py
@@ -65,6 +65,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
):
super().__init__(
user_prompt_path=user_prompt_path,
@@ -74,6 +76,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
self.validation_system_prompt_path = validation_system_prompt_path
self.validation_user_prompt_path = validation_user_prompt_path
diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py
index d26091eee..317d7cd9a 100644
--- a/cognee/modules/retrieval/graph_completion_retriever.py
+++ b/cognee/modules/retrieval/graph_completion_retriever.py
@@ -49,6 +49,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
):
"""Initialize retriever with prompt paths and search parameters."""
self.save_interaction = save_interaction
@@ -56,8 +58,10 @@ class GraphCompletionRetriever(BaseGraphRetriever):
self.system_prompt_path = system_prompt_path
self.system_prompt = system_prompt
self.top_k = top_k if top_k is not None else 5
+ self.wide_search_top_k = wide_search_top_k
self.node_type = node_type
self.node_name = node_name
+ self.triplet_distance_penalty = triplet_distance_penalty
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
"""
@@ -107,6 +111,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
collections=vector_index_collections or None,
node_type=self.node_type,
node_name=self.node_name,
+ wide_search_top_k=self.wide_search_top_k,
+ triplet_distance_penalty=self.triplet_distance_penalty,
)
return found_triplets
@@ -146,6 +152,10 @@ class GraphCompletionRetriever(BaseGraphRetriever):
await update_node_access_timestamps(entity_nodes)
return triplets
+ async def convert_retrieved_objects_to_context(self, triplets: List[Edge]):
+ context = await self.resolve_edges_to_text(triplets)
+ return context
+
async def get_completion(
self,
query: str,
diff --git a/cognee/modules/retrieval/graph_summary_completion_retriever.py b/cognee/modules/retrieval/graph_summary_completion_retriever.py
index 051f39b22..e31ad126e 100644
--- a/cognee/modules/retrieval/graph_summary_completion_retriever.py
+++ b/cognee/modules/retrieval/graph_summary_completion_retriever.py
@@ -26,6 +26,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
):
"""Initialize retriever with default prompt paths and search parameters."""
super().__init__(
@@ -36,6 +38,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
self.summarize_prompt_path = summarize_prompt_path
diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py
index f3da02c15..87d2ab009 100644
--- a/cognee/modules/retrieval/temporal_retriever.py
+++ b/cognee/modules/retrieval/temporal_retriever.py
@@ -47,6 +47,8 @@ class TemporalRetriever(GraphCompletionRetriever):
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
):
super().__init__(
user_prompt_path=user_prompt_path,
@@ -54,6 +56,8 @@ class TemporalRetriever(GraphCompletionRetriever):
top_k=top_k,
node_type=node_type,
node_name=node_name,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py
index f8bdbb97d..2f8a545f7 100644
--- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py
+++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py
@@ -58,6 +58,8 @@ async def get_memory_fragment(
properties_to_project: Optional[List[str]] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
+ relevant_ids_to_filter: Optional[List[str]] = None,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> CogneeGraph:
"""Creates and initializes a CogneeGraph memory fragment with optional property projections."""
if properties_to_project is None:
@@ -74,6 +76,8 @@ async def get_memory_fragment(
edge_properties_to_project=["relationship_name", "edge_text"],
node_type=node_type,
node_name=node_name,
+ relevant_ids_to_filter=relevant_ids_to_filter,
+ triplet_distance_penalty=triplet_distance_penalty,
)
except EntityNotFoundError:
@@ -95,6 +99,8 @@ async def brute_force_triplet_search(
memory_fragment: Optional[CogneeGraph] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> List[Edge]:
"""
Performs a brute force search to retrieve the top triplets from the graph.
@@ -107,6 +113,8 @@ async def brute_force_triplet_search(
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
node_type: node type to filter
node_name: node name to filter
+ wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections
+        triplet_distance_penalty (Optional[float]): Default vector-distance value assigned to projected nodes and edges before real distances are mapped
Returns:
list: The top triplet results.
@@ -116,10 +124,10 @@ async def brute_force_triplet_search(
if top_k <= 0:
raise ValueError("top_k must be a positive integer.")
- if memory_fragment is None:
- memory_fragment = await get_memory_fragment(
- properties_to_project, node_type=node_type, node_name=node_name
- )
+ # Setting wide search limit based on the parameters
+ non_global_search = node_name is None
+
+ wide_search_limit = wide_search_top_k if non_global_search else None
if collections is None:
collections = [
@@ -140,7 +148,7 @@ async def brute_force_triplet_search(
async def search_in_collection(collection_name: str):
try:
return await vector_engine.search(
- collection_name=collection_name, query_vector=query_vector, limit=None
+ collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit
)
except CollectionNotFoundError:
return []
@@ -156,15 +164,38 @@ async def brute_force_triplet_search(
return []
# Final statistics
- projection_time = time.time() - start_time
+ vector_collection_search_time = time.time() - start_time
logger.info(
- f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s"
+ f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s"
)
node_distances = {collection: result for collection, result in zip(collections, results)}
edge_distances = node_distances.get("EdgeType_relationship_name", None)
+ if wide_search_limit is not None:
+ relevant_ids_to_filter = list(
+ {
+ str(getattr(scored_node, "id"))
+ for collection_name, score_collection in node_distances.items()
+ if collection_name != "EdgeType_relationship_name"
+ and isinstance(score_collection, (list, tuple))
+ for scored_node in score_collection
+ if getattr(scored_node, "id", None)
+ }
+ )
+ else:
+ relevant_ids_to_filter = None
+
+ if memory_fragment is None:
+ memory_fragment = await get_memory_fragment(
+ properties_to_project=properties_to_project,
+ node_type=node_type,
+ node_name=node_name,
+ relevant_ids_to_filter=relevant_ids_to_filter,
+ triplet_distance_penalty=triplet_distance_penalty,
+ )
+
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
await memory_fragment.map_vector_distances_to_graph_edges(
vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances
diff --git a/cognee/modules/search/methods/get_search_type_tools.py b/cognee/modules/search/methods/get_search_type_tools.py
index 72e2db89a..165ec379b 100644
--- a/cognee/modules/search/methods/get_search_type_tools.py
+++ b/cognee/modules/search/methods/get_search_type_tools.py
@@ -37,6 +37,8 @@ async def get_search_type_tools(
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
last_k: Optional[int] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> list:
search_tasks: dict[SearchType, List[Callable]] = {
SearchType.SUMMARIES: [
@@ -67,6 +69,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
GraphCompletionRetriever(
system_prompt_path=system_prompt_path,
@@ -75,6 +79,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.GRAPH_COMPLETION_COT: [
@@ -85,6 +91,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
GraphCompletionCotRetriever(
system_prompt_path=system_prompt_path,
@@ -93,6 +101,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: [
@@ -103,6 +113,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
GraphCompletionContextExtensionRetriever(
system_prompt_path=system_prompt_path,
@@ -111,6 +123,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.GRAPH_SUMMARY_COMPLETION: [
@@ -121,6 +135,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path,
@@ -129,6 +145,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.CODE: [
@@ -145,8 +163,16 @@ async def get_search_type_tools(
],
SearchType.FEEDBACK: [UserQAFeedback(last_k=last_k).add_feedback],
SearchType.TEMPORAL: [
- TemporalRetriever(top_k=top_k).get_completion,
- TemporalRetriever(top_k=top_k).get_context,
+ TemporalRetriever(
+ top_k=top_k,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
+ ).get_completion,
+ TemporalRetriever(
+ top_k=top_k,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
+ ).get_context,
],
SearchType.CHUNKS_LEXICAL: (
lambda _r=JaccardChunksRetriever(top_k=top_k): [
diff --git a/cognee/modules/search/methods/no_access_control_search.py b/cognee/modules/search/methods/no_access_control_search.py
index fcb02da46..3a703bbc9 100644
--- a/cognee/modules/search/methods/no_access_control_search.py
+++ b/cognee/modules/search/methods/no_access_control_search.py
@@ -24,6 +24,8 @@ async def no_access_control_search(
last_k: Optional[int] = None,
only_context: bool = False,
session_id: Optional[str] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
search_tools = await get_search_type_tools(
query_type=query_type,
@@ -35,6 +37,8 @@ async def no_access_control_search(
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
graph_engine = await get_graph_engine()
is_empty = await graph_engine.is_empty()
diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py
index b4278424b..9f180d607 100644
--- a/cognee/modules/search/methods/search.py
+++ b/cognee/modules/search/methods/search.py
@@ -47,6 +47,8 @@ async def search(
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[CombinedSearchResult, List[SearchResult]]:
"""
@@ -90,6 +92,8 @@ async def search(
only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
else:
search_results = [
@@ -105,6 +109,8 @@ async def search(
last_k=last_k,
only_context=only_context,
session_id=session_id,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
]
@@ -219,6 +225,8 @@ async def authorized_search(
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[
Tuple[Any, Union[List[Edge], str], List[Dataset]],
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
@@ -246,6 +254,8 @@ async def authorized_search(
last_k=last_k,
only_context=True,
session_id=session_id,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
context = {}
@@ -267,6 +277,8 @@ async def authorized_search(
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
search_tools = specific_search_tools
if len(search_tools) == 2:
@@ -306,6 +318,8 @@
last_k=last_k,
only_context=only_context,
session_id=session_id,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
return search_results
@@ -325,6 +338,8 @@ async def search_in_datasets_context(
only_context: bool = False,
context: Optional[Any] = None,
session_id: Optional[str] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
"""
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
@@ -345,6 +360,8 @@ async def search_in_datasets_context(
only_context: bool = False,
context: Optional[Any] = None,
session_id: Optional[str] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
# Set database configuration in async context for each dataset user has access for
await set_database_global_context_variables(dataset.id, dataset.owner_id)
@@ -378,6 +395,8 @@ async def search_in_datasets_context(
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
search_tools = specific_search_tools
if len(search_tools) == 2:
@@ -413,6 +432,8 @@ async def search_in_datasets_context(
only_context=only_context,
context=context,
session_id=session_id,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
)
diff --git a/cognee/tests/unit/api/test_ontology_endpoint.py b/cognee/tests/unit/api/test_ontology_endpoint.py
index d53c5ab44..af3a4d90e 100644
--- a/cognee/tests/unit/api/test_ontology_endpoint.py
+++ b/cognee/tests/unit/api/test_ontology_endpoint.py
@@ -25,7 +25,10 @@ def mock_user():
def mock_default_user():
"""Mock default user for testing."""
return SimpleNamespace(
- id=uuid.uuid4(), email="default@example.com", is_active=True, tenant_id=uuid.uuid4()
+ id=str(uuid.uuid4()),
+ email="default@example.com",
+ is_active=True,
+ tenant_id=str(uuid.uuid4()),
)
@@ -108,6 +111,7 @@ def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_
"""Test uploading multiple ontology files in single request"""
import io
+ mock_get_default_user.return_value = mock_default_user
# Create mock files
file1_content = b""
file2_content = b""
@@ -137,6 +141,7 @@ def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, mock_defa
import io
import json
+ mock_get_default_user.return_value = mock_default_user
file_content = b""
files = [("ontology_file", ("single.owl", io.BytesIO(file_content), "application/xml"))]
@@ -173,6 +178,7 @@ def test_complete_multifile_workflow(mock_get_default_user, client, mock_default
import io
import json
+ mock_get_default_user.return_value = mock_default_user
# Step 1: Upload multiple ontologies
file1_content = b"""