Merge branch 'dev' into issue-1767
Commit 57195fb5a1
23 changed files with 1482 additions and 70 deletions
@@ -31,6 +31,8 @@ async def search(
    only_context: bool = False,
    use_combined_context: bool = False,
    session_id: Optional[str] = None,
    wide_search_top_k: Optional[int] = 100,
    triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[List[SearchResult], CombinedSearchResult]:
    """
    Search and query the knowledge graph for insights, information, and connections.
@@ -200,6 +202,8 @@ async def search(
        only_context=only_context,
        use_combined_context=use_combined_context,
        session_id=session_id,
        wide_search_top_k=wide_search_top_k,
        triplet_distance_penalty=triplet_distance_penalty,
    )

    return filtered_search_results
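For orientation, a minimal usage sketch of the two new search parameters added above. This assumes the top-level `cognee.search` wrapper forwards these keyword arguments and that `SearchType` is importable from the path shown; the import paths and values are illustrative, not part of the commit.

```python
import asyncio

import cognee
from cognee.modules.search.types import SearchType  # assumed import path


async def main():
    results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text="What does the knowledge graph say about X?",
        wide_search_top_k=100,           # cap on initial vector hits per collection
        triplet_distance_penalty=3.5,    # default distance for unmatched nodes/edges
    )
    print(results)


asyncio.run(main())
```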
cognee/eval_framework/Dockerfile (new file, 29 lines)
@@ -0,0 +1,29 @@
FROM python:3.11-slim

# Set environment variables
ENV PIP_NO_CACHE_DIR=true
ENV PATH="${PATH}:/root/.poetry/bin"
ENV PYTHONPATH=/app
ENV SKIP_MIGRATIONS=true

# System dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    libpq-dev \
    git \
    curl \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

COPY pyproject.toml poetry.lock README.md /app/

RUN pip install poetry

RUN poetry config virtualenvs.create false

RUN poetry install --extras distributed --extras evals --extras deepeval --no-root

COPY cognee/ /app/cognee
COPY distributed/ /app/distributed
@@ -35,6 +35,16 @@ class AnswerGeneratorExecutor:
        retrieval_context = await retriever.get_context(query_text)
        search_results = await retriever.get_completion(query_text, retrieval_context)

        ############
        # TODO: Quick fix for the changed combined retriever structure; remove once retriever results are structured properly.
        if isinstance(retrieval_context, list):
            retrieval_context = await retriever.convert_retrieved_objects_to_context(
                triplets=retrieval_context
            )

        if isinstance(search_results, str):
            search_results = [search_results]
        #############
        answer = {
            "question": query_text,
            "answer": search_results[0],
@@ -35,7 +35,7 @@ async def create_and_insert_answers_table(questions_payload):


async def run_question_answering(
    params: dict, system_prompt="answer_simple_question.txt", top_k: Optional[int] = None
    params: dict, system_prompt="answer_simple_question_benchmark.txt", top_k: Optional[int] = None
) -> List[dict]:
    if params.get("answering_questions"):
        logger.info("Question answering started...")
@@ -14,7 +14,7 @@ class EvalConfig(BaseSettings):

    # Question answering params
    answering_questions: bool = True
    qa_engine: str = "cognee_completion"  # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
    qa_engine: str = "cognee_graph_completion"  # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'

    # Evaluation params
    evaluating_answers: bool = True
@@ -25,7 +25,7 @@ class EvalConfig(BaseSettings):
        "EM",
        "f1",
    ]  # Use only 'correctness' for DirectLLM
    deepeval_model: str = "gpt-5-mini"
    deepeval_model: str = "gpt-4o-mini"

    # Metrics params
    calculate_metrics: bool = True
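As a quick reference, a minimal sketch of constructing the config with the changed defaults pinned explicitly. The import path appears elsewhere in this diff; the specific values are illustrative only.

```python
from cognee.eval_framework.eval_config import EvalConfig

# Illustrative: pin the settings this commit changes rather than relying on defaults.
config = EvalConfig(
    qa_engine="cognee_graph_completion",
    deepeval_model="gpt-4o-mini",
)
print(config.qa_engine, config.deepeval_model)
```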
@@ -2,7 +2,6 @@ import modal
import os
import asyncio
import datetime
import hashlib
import json
from cognee.shared.logging_utils import get_logger
from cognee.eval_framework.eval_config import EvalConfig
@@ -10,6 +9,9 @@ from cognee.eval_framework.corpus_builder.run_corpus_builder import run_corpus_b
from cognee.eval_framework.answer_generation.run_question_answering_module import (
    run_question_answering,
)
import pathlib
from os import path
from modal import Image
from cognee.eval_framework.evaluation.run_evaluation_module import run_evaluation
from cognee.eval_framework.metrics_dashboard import create_dashboard

@@ -38,22 +40,19 @@ def read_and_combine_metrics(eval_params: dict) -> dict:

app = modal.App("modal-run-eval")

image = (
    modal.Image.from_dockerfile(path="Dockerfile_modal", force_build=False)
    .copy_local_file("pyproject.toml", "pyproject.toml")
    .copy_local_file("poetry.lock", "poetry.lock")
    .env(
        {
            "ENV": os.getenv("ENV"),
            "LLM_API_KEY": os.getenv("LLM_API_KEY"),
            "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
        }
    )
    .pip_install("protobuf", "h2", "deepeval", "gdown", "plotly")
image = Image.from_dockerfile(
    path=pathlib.Path(path.join(path.dirname(__file__), "Dockerfile")).resolve(),
    force_build=False,
).add_local_python_source("cognee")


@app.function(
    image=image,
    max_containers=10,
    timeout=86400,
    volumes={"/data": vol},
    secrets=[modal.Secret.from_name("eval_secrets")],
)


@app.function(image=image, concurrency_limit=10, timeout=86400, volumes={"/data": vol})
async def modal_run_eval(eval_params=None):
    """Runs evaluation pipeline and returns combined metrics results."""
    if eval_params is None:
@@ -105,18 +104,7 @@ async def main():
    configs = [
        EvalConfig(
            task_getter_type="Default",
            number_of_samples_in_corpus=10,
            benchmark="HotPotQA",
            qa_engine="cognee_graph_completion",
            building_corpus_from_scratch=True,
            answering_questions=True,
            evaluating_answers=True,
            calculate_metrics=True,
            dashboard=True,
        ),
        EvalConfig(
            task_getter_type="Default",
            number_of_samples_in_corpus=10,
            number_of_samples_in_corpus=25,
            benchmark="TwoWikiMultiHop",
            qa_engine="cognee_graph_completion",
            building_corpus_from_scratch=True,
@@ -127,7 +115,7 @@ async def main():
        ),
        EvalConfig(
            task_getter_type="Default",
            number_of_samples_in_corpus=10,
            number_of_samples_in_corpus=25,
            benchmark="Musique",
            qa_engine="cognee_graph_completion",
            building_corpus_from_scratch=True,
@@ -398,3 +398,18 @@ class GraphDBInterface(ABC):
        - node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections.
        """
        raise NotImplementedError

    @abstractmethod
    async def get_filtered_graph_data(
        self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
    ) -> Tuple[List[Node], List[EdgeData]]:
        """
        Retrieve nodes and edges filtered by the provided attribute criteria.

        Parameters:
        -----------

        - attribute_filters: A list of dictionaries where keys are attribute names and values
          are lists of attribute values to filter by.
        """
        raise NotImplementedError
@@ -12,6 +12,7 @@ from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, List, Union, Optional, Tuple, Type

from cognee.exceptions import CogneeValidationError
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.run_sync import run_sync
from cognee.infrastructure.files.storage import get_file_storage
@@ -1186,6 +1187,11 @@ class KuzuAdapter(GraphDBInterface):
        A tuple with two elements: a list of tuples of (node_id, properties) and a list of
        tuples of (source_id, target_id, relationship_name, properties).
        """

        import time

        start_time = time.time()

        try:
            nodes_query = """
            MATCH (n:Node)
@@ -1249,6 +1255,11 @@ class KuzuAdapter(GraphDBInterface):
                    },
                )
            )

            retrieval_time = time.time() - start_time
            logger.info(
                f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds"
            )
            return formatted_nodes, formatted_edges
        except Exception as e:
            logger.error(f"Failed to get graph data: {e}")
@@ -1417,6 +1428,92 @@ class KuzuAdapter(GraphDBInterface):
            formatted_edges.append((source_id, target_id, rel_type, props))
        return formatted_nodes, formatted_edges

    async def get_id_filtered_graph_data(self, target_ids: list[str]):
        """
        Retrieve graph data filtered by specific node IDs, including their direct neighbors
        and only edges where one endpoint matches those IDs.

        Returns:
            nodes: List[dict] -> Each dict includes "id" and all node properties
            edges: List[dict] -> Each dict includes "source", "target", "type", "properties"
        """
        import time

        start_time = time.time()

        try:
            if not target_ids:
                logger.warning("No target IDs provided for ID-filtered graph retrieval.")
                return [], []

            if not all(isinstance(x, str) for x in target_ids):
                raise CogneeValidationError("target_ids must be a list of strings")

            query = """
            MATCH (n:Node)-[r]->(m:Node)
            WHERE n.id IN $target_ids OR m.id IN $target_ids
            RETURN n.id, {
                name: n.name,
                type: n.type,
                properties: n.properties
            }, m.id, {
                name: m.name,
                type: m.type,
                properties: m.properties
            }, r.relationship_name, r.properties
            """

            result = await self.query(query, {"target_ids": target_ids})

            if not result:
                logger.info("No data returned for the supplied IDs")
                return [], []

            nodes_dict = {}
            edges = []

            for n_id, n_props, m_id, m_props, r_type, r_props_raw in result:
                if n_props.get("properties"):
                    try:
                        additional_props = json.loads(n_props["properties"])
                        n_props.update(additional_props)
                        del n_props["properties"]
                    except json.JSONDecodeError:
                        logger.warning(f"Failed to parse properties JSON for node {n_id}")

                if m_props.get("properties"):
                    try:
                        additional_props = json.loads(m_props["properties"])
                        m_props.update(additional_props)
                        del m_props["properties"]
                    except json.JSONDecodeError:
                        logger.warning(f"Failed to parse properties JSON for node {m_id}")

                nodes_dict[n_id] = (n_id, n_props)
                nodes_dict[m_id] = (m_id, m_props)

                edge_props = {}
                if r_props_raw:
                    try:
                        edge_props = json.loads(r_props_raw)
                    except (json.JSONDecodeError, TypeError):
                        logger.warning(f"Failed to parse edge properties for {n_id}->{m_id}")

                source_id = edge_props.get("source_node_id", n_id)
                target_id = edge_props.get("target_node_id", m_id)
                edges.append((source_id, target_id, r_type, edge_props))

            retrieval_time = time.time() - start_time
            logger.info(
                f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
            )

            return list(nodes_dict.values()), edges

        except Exception as e:
            logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
            raise

    async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
        """
        Get metrics on graph structure and connectivity.
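For orientation, a minimal sketch of exercising the new adapter method. It assumes a configured graph engine and that `get_graph_engine` is importable from the path shown; the method name comes from the hunks above and below, everything else (IDs, import path) is illustrative.

```python
import asyncio

from cognee.infrastructure.databases.graph import get_graph_engine  # assumed import path


async def main():
    adapter = await get_graph_engine()

    # Hypothetical node IDs; in the real search flow these come from the wide vector search.
    target_ids = ["node-id-1", "node-id-2"]

    if hasattr(adapter, "get_id_filtered_graph_data"):
        # Returns (nodes, edges): nodes as (node_id, properties) tuples,
        # edges as (source_id, target_id, relationship_name, properties) tuples.
        nodes, edges = await adapter.get_id_filtered_graph_data(target_ids=target_ids)
    else:
        nodes, edges = await adapter.get_graph_data()

    print(len(nodes), len(edges))


asyncio.run(main())
```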
@@ -964,6 +964,63 @@ class Neo4jAdapter(GraphDBInterface):
            logger.error(f"Error during graph data retrieval: {str(e)}")
            raise

    async def get_id_filtered_graph_data(self, target_ids: list[str]):
        """
        Retrieve graph data filtered by specific node IDs, including their direct neighbors
        and only edges where one endpoint matches those IDs.

        This version uses a single Cypher query for efficiency.
        """
        import time

        start_time = time.time()

        try:
            if not target_ids:
                logger.warning("No target IDs provided for ID-filtered graph retrieval.")
                return [], []

            query = """
            MATCH ()-[r]-()
            WHERE startNode(r).id IN $target_ids
               OR endNode(r).id IN $target_ids
            WITH DISTINCT r, startNode(r) AS a, endNode(r) AS b
            RETURN
                properties(a) AS n_properties,
                properties(b) AS m_properties,
                type(r) AS type,
                properties(r) AS properties
            """

            result = await self.query(query, {"target_ids": target_ids})

            nodes_dict = {}
            edges = []

            for record in result:
                n_props = record["n_properties"]
                m_props = record["m_properties"]
                r_props = record["properties"]
                r_type = record["type"]

                nodes_dict[n_props["id"]] = (n_props["id"], n_props)
                nodes_dict[m_props["id"]] = (m_props["id"], m_props)

                source_id = r_props.get("source_node_id", n_props["id"])
                target_id = r_props.get("target_node_id", m_props["id"])
                edges.append((source_id, target_id, r_type, r_props))

            retrieval_time = time.time() - start_time
            logger.info(
                f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
            )

            return list(nodes_dict.values()), edges

        except Exception as e:
            logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
            raise

    async def get_nodeset_subgraph(
        self, node_type: Type[Any], node_name: List[str]
    ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
@@ -56,6 +56,68 @@ class CogneeGraph(CogneeAbstractGraph):
    def get_edges(self) -> List[Edge]:
        return self.edges

    async def _get_nodeset_subgraph(
        self,
        adapter,
        node_type,
        node_name,
    ):
        """Retrieve subgraph based on node type and name."""
        logger.info("Retrieving graph filtered by node type and node name (NodeSet).")
        nodes_data, edges_data = await adapter.get_nodeset_subgraph(
            node_type=node_type, node_name=node_name
        )
        if not nodes_data or not edges_data:
            raise EntityNotFoundError(
                message="Nodeset does not exist, or empty nodeset projected from the database."
            )
        return nodes_data, edges_data

    async def _get_full_or_id_filtered_graph(
        self,
        adapter,
        relevant_ids_to_filter,
    ):
        """Retrieve full or ID-filtered graph with fallback."""
        if relevant_ids_to_filter is None:
            logger.info("Retrieving full graph.")
            nodes_data, edges_data = await adapter.get_graph_data()
            if not nodes_data or not edges_data:
                raise EntityNotFoundError(message="Empty graph projected from the database.")
            return nodes_data, edges_data

        get_graph_data_fn = getattr(adapter, "get_id_filtered_graph_data", adapter.get_graph_data)
        if getattr(adapter.__class__, "get_id_filtered_graph_data", None):
            logger.info("Retrieving ID-filtered graph from database.")
            nodes_data, edges_data = await get_graph_data_fn(target_ids=relevant_ids_to_filter)
        else:
            logger.info("Retrieving full graph from database.")
            nodes_data, edges_data = await get_graph_data_fn()
        if hasattr(adapter, "get_id_filtered_graph_data") and (not nodes_data or not edges_data):
            logger.warning(
                "Id filtered graph returned empty, falling back to full graph retrieval."
            )
            logger.info("Retrieving full graph")
            nodes_data, edges_data = await adapter.get_graph_data()

        if not nodes_data or not edges_data:
            raise EntityNotFoundError("Empty graph projected from the database.")
        return nodes_data, edges_data

    async def _get_filtered_graph(
        self,
        adapter,
        memory_fragment_filter,
    ):
        """Retrieve graph filtered by attributes."""
        logger.info("Retrieving graph filtered by memory fragment")
        nodes_data, edges_data = await adapter.get_filtered_graph_data(
            attribute_filters=memory_fragment_filter
        )
        if not nodes_data or not edges_data:
            raise EntityNotFoundError(message="Empty filtered graph projected from the database.")
        return nodes_data, edges_data

    async def project_graph_from_db(
        self,
        adapter: Union[GraphDBInterface],
@@ -67,40 +129,39 @@ class CogneeGraph(CogneeAbstractGraph):
        memory_fragment_filter=[],
        node_type: Optional[Type] = None,
        node_name: Optional[List[str]] = None,
        relevant_ids_to_filter: Optional[List[str]] = None,
        triplet_distance_penalty: float = 3.5,
    ) -> None:
        if node_dimension < 1 or edge_dimension < 1:
            raise InvalidDimensionsError()
        try:
            if node_type is not None and node_name not in [None, [], ""]:
                nodes_data, edges_data = await self._get_nodeset_subgraph(
                    adapter, node_type, node_name
                )
            elif len(memory_fragment_filter) == 0:
                nodes_data, edges_data = await self._get_full_or_id_filtered_graph(
                    adapter, relevant_ids_to_filter
                )
            else:
                nodes_data, edges_data = await self._get_filtered_graph(
                    adapter, memory_fragment_filter
                )

            import time

            start_time = time.time()

            # Determine projection strategy
            if node_type is not None and node_name not in [None, [], ""]:
                nodes_data, edges_data = await adapter.get_nodeset_subgraph(
                    node_type=node_type, node_name=node_name
                )
                if not nodes_data or not edges_data:
                    raise EntityNotFoundError(
                        message="Nodeset does not exist, or empty nodetes projected from the database."
                    )
            elif len(memory_fragment_filter) == 0:
                nodes_data, edges_data = await adapter.get_graph_data()
                if not nodes_data or not edges_data:
                    raise EntityNotFoundError(message="Empty graph projected from the database.")
            else:
                nodes_data, edges_data = await adapter.get_filtered_graph_data(
                    attribute_filters=memory_fragment_filter
                )
                if not nodes_data or not edges_data:
                    raise EntityNotFoundError(
                        message="Empty filtered graph projected from the database."
                    )

            # Process nodes
            for node_id, properties in nodes_data:
                node_attributes = {key: properties.get(key) for key in node_properties_to_project}
                self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))
                self.add_node(
                    Node(
                        str(node_id),
                        node_attributes,
                        dimension=node_dimension,
                        node_penalty=triplet_distance_penalty,
                    )
                )

            # Process edges
            for source_id, target_id, relationship_type, properties in edges_data:
@@ -118,6 +179,7 @@ class CogneeGraph(CogneeAbstractGraph):
                    attributes=edge_attributes,
                    directed=directed,
                    dimension=edge_dimension,
                    edge_penalty=triplet_distance_penalty,
                )
                self.add_edge(edge)

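In short, the refactored projection prefers an adapter-specific ID-filtered retrieval when relevant IDs are supplied, and falls back to the full projection when that capability is missing or returns nothing. A condensed, illustrative reduction of that selection logic (the real implementation is the `_get_full_or_id_filtered_graph` helper shown above; the function name here is hypothetical):

```python
async def pick_graph_data(adapter, relevant_ids_to_filter):
    """Illustrative reduction of the fallback strategy in _get_full_or_id_filtered_graph."""
    if relevant_ids_to_filter is None:
        return await adapter.get_graph_data()

    if hasattr(adapter, "get_id_filtered_graph_data"):
        nodes, edges = await adapter.get_id_filtered_graph_data(
            target_ids=relevant_ids_to_filter
        )
        if nodes and edges:
            return nodes, edges

    # Empty ID-filtered result (or no adapter support): fall back to the full graph.
    return await adapter.get_graph_data()
```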
@@ -20,13 +20,17 @@ class Node:
    status: np.ndarray

    def __init__(
        self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1
        self,
        node_id: str,
        attributes: Optional[Dict[str, Any]] = None,
        dimension: int = 1,
        node_penalty: float = 3.5,
    ):
        if dimension <= 0:
            raise InvalidDimensionsError()
        self.id = node_id
        self.attributes = attributes if attributes is not None else {}
        self.attributes["vector_distance"] = float("inf")
        self.attributes["vector_distance"] = node_penalty
        self.skeleton_neighbours = []
        self.skeleton_edges = []
        self.status = np.ones(dimension, dtype=int)
@@ -105,13 +109,14 @@ class Edge:
        attributes: Optional[Dict[str, Any]] = None,
        directed: bool = True,
        dimension: int = 1,
        edge_penalty: float = 3.5,
    ):
        if dimension <= 0:
            raise InvalidDimensionsError()
        self.node1 = node1
        self.node2 = node2
        self.attributes = attributes if attributes is not None else {}
        self.attributes["vector_distance"] = float("inf")
        self.attributes["vector_distance"] = edge_penalty
        self.directed = directed
        self.status = np.ones(dimension, dtype=int)

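The practical effect of the change above: freshly projected nodes and edges now start with a finite default `vector_distance` (the triplet distance penalty, 3.5 by default) instead of infinity, so elements with no vector match still rank, just poorly. A small sketch of the new defaults, assuming the `Node`/`Edge` classes live at the import path used by the tests later in this diff:

```python
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node  # assumed path

# After this change, the distance starts at the penalty value, not infinity.
node = Node("n1", {"name": "example"})
assert node.attributes["vector_distance"] == 3.5

# A custom penalty can be threaded through from the search call.
strict_node = Node("n2", {"name": "example"}, node_penalty=10.0)
edge = Edge(node, strict_node, edge_penalty=10.0)
assert strict_node.attributes["vector_distance"] == 10.0
assert edge.attributes["vector_distance"] == 10.0
```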
|
|||
|
|
@ -39,6 +39,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
):
|
||||
super().__init__(
|
||||
user_prompt_path=user_prompt_path,
|
||||
|
|
@ -48,6 +50,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
async def get_completion(
|
||||
|
|
|
|||
|
|
@ -65,6 +65,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
):
|
||||
super().__init__(
|
||||
user_prompt_path=user_prompt_path,
|
||||
|
|
@ -74,6 +76,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
self.validation_system_prompt_path = validation_system_prompt_path
|
||||
self.validation_user_prompt_path = validation_user_prompt_path
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
):
|
||||
"""Initialize retriever with prompt paths and search parameters."""
|
||||
self.save_interaction = save_interaction
|
||||
|
|
@ -54,8 +56,10 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
self.system_prompt_path = system_prompt_path
|
||||
self.system_prompt = system_prompt
|
||||
self.top_k = top_k if top_k is not None else 5
|
||||
self.wide_search_top_k = wide_search_top_k
|
||||
self.node_type = node_type
|
||||
self.node_name = node_name
|
||||
self.triplet_distance_penalty = triplet_distance_penalty
|
||||
|
||||
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
|
||||
"""
|
||||
|
|
@ -105,6 +109,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
collections=vector_index_collections or None,
|
||||
node_type=self.node_type,
|
||||
node_name=self.node_name,
|
||||
wide_search_top_k=self.wide_search_top_k,
|
||||
triplet_distance_penalty=self.triplet_distance_penalty,
|
||||
)
|
||||
|
||||
return found_triplets
|
||||
|
|
@ -141,6 +147,10 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
|
||||
return triplets
|
||||
|
||||
async def convert_retrieved_objects_to_context(self, triplets: List[Edge]):
|
||||
context = await self.resolve_edges_to_text(triplets)
|
||||
return context
|
||||
|
||||
async def get_completion(
|
||||
self,
|
||||
query: str,
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
|||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
):
|
||||
"""Initialize retriever with default prompt paths and search parameters."""
|
||||
super().__init__(
|
||||
|
|
@ -36,6 +38,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
self.summarize_prompt_path = summarize_prompt_path
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
top_k: Optional[int] = 5,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
):
|
||||
super().__init__(
|
||||
user_prompt_path=user_prompt_path,
|
||||
|
|
@ -54,6 +56,8 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
self.user_prompt_path = user_prompt_path
|
||||
self.system_prompt_path = system_prompt_path
|
||||
|
|
|
|||
@@ -58,6 +58,8 @@ async def get_memory_fragment(
    properties_to_project: Optional[List[str]] = None,
    node_type: Optional[Type] = None,
    node_name: Optional[List[str]] = None,
    relevant_ids_to_filter: Optional[List[str]] = None,
    triplet_distance_penalty: Optional[float] = 3.5,
) -> CogneeGraph:
    """Creates and initializes a CogneeGraph memory fragment with optional property projections."""
    if properties_to_project is None:
@@ -74,6 +76,8 @@ async def get_memory_fragment(
            edge_properties_to_project=["relationship_name", "edge_text"],
            node_type=node_type,
            node_name=node_name,
            relevant_ids_to_filter=relevant_ids_to_filter,
            triplet_distance_penalty=triplet_distance_penalty,
        )

    except EntityNotFoundError:
@@ -95,6 +99,8 @@ async def brute_force_triplet_search(
    memory_fragment: Optional[CogneeGraph] = None,
    node_type: Optional[Type] = None,
    node_name: Optional[List[str]] = None,
    wide_search_top_k: Optional[int] = 100,
    triplet_distance_penalty: Optional[float] = 3.5,
) -> List[Edge]:
    """
    Performs a brute force search to retrieve the top triplets from the graph.
@@ -107,6 +113,8 @@ async def brute_force_triplet_search(
        memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
        node_type: node type to filter
        node_name: node name to filter
        wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections
        triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection

    Returns:
        list: The top triplet results.
@@ -116,10 +124,10 @@ async def brute_force_triplet_search(
    if top_k <= 0:
        raise ValueError("top_k must be a positive integer.")

    if memory_fragment is None:
        memory_fragment = await get_memory_fragment(
            properties_to_project, node_type=node_type, node_name=node_name
        )
    # Setting wide search limit based on the parameters
    non_global_search = node_name is None

    wide_search_limit = wide_search_top_k if non_global_search else None

    if collections is None:
        collections = [
@@ -140,7 +148,7 @@ async def brute_force_triplet_search(
    async def search_in_collection(collection_name: str):
        try:
            return await vector_engine.search(
                collection_name=collection_name, query_vector=query_vector, limit=None
                collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit
            )
        except CollectionNotFoundError:
            return []
@@ -156,15 +164,38 @@ async def brute_force_triplet_search(
        return []

    # Final statistics
    projection_time = time.time() - start_time
    vector_collection_search_time = time.time() - start_time
    logger.info(
        f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s"
        f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s"
    )

    node_distances = {collection: result for collection, result in zip(collections, results)}

    edge_distances = node_distances.get("EdgeType_relationship_name", None)

    if wide_search_limit is not None:
        relevant_ids_to_filter = list(
            {
                str(getattr(scored_node, "id"))
                for collection_name, score_collection in node_distances.items()
                if collection_name != "EdgeType_relationship_name"
                and isinstance(score_collection, (list, tuple))
                for scored_node in score_collection
                if getattr(scored_node, "id", None)
            }
        )
    else:
        relevant_ids_to_filter = None

    if memory_fragment is None:
        memory_fragment = await get_memory_fragment(
            properties_to_project=properties_to_project,
            node_type=node_type,
            node_name=node_name,
            relevant_ids_to_filter=relevant_ids_to_filter,
            triplet_distance_penalty=triplet_distance_penalty,
        )

    await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
    await memory_fragment.map_vector_distances_to_graph_edges(
        vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances
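Putting the pieces of this file together, the new flow is: run the wide vector search first with `wide_search_top_k` as the per-collection limit, collect the IDs of the scored nodes, and only then project the memory fragment restricted to those IDs. A condensed, illustrative sketch of that ordering; `get_memory_fragment` is the module's own helper (import path confirmed by the tests later in this diff), while the wrapper function name and arguments here are hypothetical:

```python
from cognee.modules.retrieval.utils.brute_force_triplet_search import get_memory_fragment


async def project_relevant_fragment(vector_engine, query_vector, collections, wide_search_top_k):
    """Illustrative condensation of the wide-search-then-filter flow above."""
    results = [
        await vector_engine.search(
            collection_name=name, query_vector=query_vector, limit=wide_search_top_k
        )
        for name in collections
    ]
    node_distances = dict(zip(collections, results))

    # IDs of everything the vector search considered relevant (edge collection excluded).
    relevant_ids = {
        str(scored.id)
        for name, scored_list in node_distances.items()
        if name != "EdgeType_relationship_name"
        for scored in scored_list or []
    }

    # Project only the neighborhood of those IDs instead of the full graph.
    return await get_memory_fragment(relevant_ids_to_filter=list(relevant_ids))
```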
|||
|
|
@ -37,6 +37,8 @@ async def get_search_type_tools(
|
|||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> list:
|
||||
search_tasks: dict[SearchType, List[Callable]] = {
|
||||
SearchType.SUMMARIES: [
|
||||
|
|
@ -67,6 +69,8 @@ async def get_search_type_tools(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_completion,
|
||||
GraphCompletionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
|
|
@ -75,6 +79,8 @@ async def get_search_type_tools(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_context,
|
||||
],
|
||||
SearchType.GRAPH_COMPLETION_COT: [
|
||||
|
|
@ -85,6 +91,8 @@ async def get_search_type_tools(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_completion,
|
||||
GraphCompletionCotRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
|
|
@ -93,6 +101,8 @@ async def get_search_type_tools(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_context,
|
||||
],
|
||||
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: [
|
||||
|
|
@ -103,6 +113,8 @@ async def get_search_type_tools(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_completion,
|
||||
GraphCompletionContextExtensionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
|
|
@ -111,6 +123,8 @@ async def get_search_type_tools(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_context,
|
||||
],
|
||||
SearchType.GRAPH_SUMMARY_COMPLETION: [
|
||||
|
|
@ -121,6 +135,8 @@ async def get_search_type_tools(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_completion,
|
||||
GraphSummaryCompletionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
|
|
@ -129,6 +145,8 @@ async def get_search_type_tools(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_context,
|
||||
],
|
||||
SearchType.CODE: [
|
||||
|
|
@ -145,8 +163,16 @@ async def get_search_type_tools(
|
|||
],
|
||||
SearchType.FEEDBACK: [UserQAFeedback(last_k=last_k).add_feedback],
|
||||
SearchType.TEMPORAL: [
|
||||
TemporalRetriever(top_k=top_k).get_completion,
|
||||
TemporalRetriever(top_k=top_k).get_context,
|
||||
TemporalRetriever(
|
||||
top_k=top_k,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_completion,
|
||||
TemporalRetriever(
|
||||
top_k=top_k,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_context,
|
||||
],
|
||||
SearchType.CHUNKS_LEXICAL: (
|
||||
lambda _r=JaccardChunksRetriever(top_k=top_k): [
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ async def no_access_control_search(
|
|||
last_k: Optional[int] = None,
|
||||
only_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
|
||||
search_tools = await get_search_type_tools(
|
||||
query_type=query_type,
|
||||
|
|
@ -35,6 +37,8 @@ async def no_access_control_search(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
graph_engine = await get_graph_engine()
|
||||
is_empty = await graph_engine.is_empty()
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ async def search(
|
|||
only_context: bool = False,
|
||||
use_combined_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> Union[CombinedSearchResult, List[SearchResult]]:
|
||||
"""
|
||||
|
||||
|
|
@ -90,6 +92,8 @@ async def search(
|
|||
only_context=only_context,
|
||||
use_combined_context=use_combined_context,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
else:
|
||||
search_results = [
|
||||
|
|
@ -105,6 +109,8 @@ async def search(
|
|||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -219,6 +225,8 @@ async def authorized_search(
|
|||
only_context: bool = False,
|
||||
use_combined_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> Union[
|
||||
Tuple[Any, Union[List[Edge], str], List[Dataset]],
|
||||
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
|
||||
|
|
@ -246,6 +254,8 @@ async def authorized_search(
|
|||
last_k=last_k,
|
||||
only_context=True,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
context = {}
|
||||
|
|
@ -267,6 +277,8 @@ async def authorized_search(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
search_tools = specific_search_tools
|
||||
if len(search_tools) == 2:
|
||||
|
|
@ -306,6 +318,7 @@ async def authorized_search(
|
|||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
)
|
||||
|
||||
return search_results
|
||||
|
|
@ -325,6 +338,8 @@ async def search_in_datasets_context(
|
|||
only_context: bool = False,
|
||||
context: Optional[Any] = None,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
|
||||
"""
|
||||
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
|
||||
|
|
@ -345,6 +360,8 @@ async def search_in_datasets_context(
|
|||
only_context: bool = False,
|
||||
context: Optional[Any] = None,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
|
||||
# Set database configuration in async context for each dataset user has access for
|
||||
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
||||
|
|
@ -378,6 +395,8 @@ async def search_in_datasets_context(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
search_tools = specific_search_tools
|
||||
if len(search_tools) == 2:
|
||||
|
|
@ -413,6 +432,8 @@ async def search_in_datasets_context(
|
|||
only_context=only_context,
|
||||
context=context,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
@@ -9,7 +9,7 @@ def test_node_initialization():
    """Test that a Node is initialized correctly."""
    node = Node("node1", {"attr1": "value1"}, dimension=2)
    assert node.id == "node1"
    assert node.attributes == {"attr1": "value1", "vector_distance": np.inf}
    assert node.attributes == {"attr1": "value1", "vector_distance": 3.5}
    assert len(node.status) == 2
    assert np.all(node.status == 1)

@@ -96,7 +96,7 @@ def test_edge_initialization():
    edge = Edge(node1, node2, {"weight": 10}, directed=False, dimension=2)
    assert edge.node1 == node1
    assert edge.node2 == node2
    assert edge.attributes == {"vector_distance": np.inf, "weight": 10}
    assert edge.attributes == {"vector_distance": 3.5, "weight": 10}
    assert edge.directed is False
    assert len(edge.status) == 2
    assert np.all(edge.status == 1)
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
|
|
@ -11,6 +12,30 @@ def setup_graph():
|
|||
return CogneeGraph()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_adapter():
|
||||
"""Fixture to create a mock adapter for database operations."""
|
||||
adapter = AsyncMock()
|
||||
return adapter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_engine():
|
||||
"""Fixture to create a mock vector engine."""
|
||||
engine = AsyncMock()
|
||||
engine.search = AsyncMock()
|
||||
return engine
|
||||
|
||||
|
||||
class MockScoredResult:
|
||||
"""Mock class for vector search results."""
|
||||
|
||||
def __init__(self, id, score, payload=None):
|
||||
self.id = id
|
||||
self.score = score
|
||||
self.payload = payload or {}
|
||||
|
||||
|
||||
def test_add_node_success(setup_graph):
|
||||
"""Test successful addition of a node."""
|
||||
graph = setup_graph
|
||||
|
|
@ -73,3 +98,433 @@ def test_get_edges_nonexistent_node(setup_graph):
|
|||
graph = setup_graph
|
||||
with pytest.raises(EntityNotFoundError, match="Node with id nonexistent does not exist."):
|
||||
graph.get_edges_from_node("nonexistent")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_graph_from_db_full_graph(setup_graph, mock_adapter):
|
||||
"""Test projecting a full graph from database."""
|
||||
graph = setup_graph
|
||||
|
||||
nodes_data = [
|
||||
("1", {"name": "Node1", "description": "First node"}),
|
||||
("2", {"name": "Node2", "description": "Second node"}),
|
||||
]
|
||||
edges_data = [
|
||||
("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}),
|
||||
]
|
||||
|
||||
mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
|
||||
|
||||
await graph.project_graph_from_db(
|
||||
adapter=mock_adapter,
|
||||
node_properties_to_project=["name", "description"],
|
||||
edge_properties_to_project=["relationship_name"],
|
||||
)
|
||||
|
||||
assert len(graph.nodes) == 2
|
||||
assert len(graph.edges) == 1
|
||||
assert graph.get_node("1") is not None
|
||||
assert graph.get_node("2") is not None
|
||||
assert graph.edges[0].node1.id == "1"
|
||||
assert graph.edges[0].node2.id == "2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_graph_from_db_id_filtered(setup_graph, mock_adapter):
|
||||
"""Test projecting an ID-filtered graph from database."""
|
||||
graph = setup_graph
|
||||
|
||||
nodes_data = [
|
||||
("1", {"name": "Node1"}),
|
||||
("2", {"name": "Node2"}),
|
||||
]
|
||||
edges_data = [
|
||||
("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}),
|
||||
]
|
||||
|
||||
mock_adapter.get_id_filtered_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
|
||||
|
||||
await graph.project_graph_from_db(
|
||||
adapter=mock_adapter,
|
||||
node_properties_to_project=["name"],
|
||||
edge_properties_to_project=["relationship_name"],
|
||||
relevant_ids_to_filter=["1", "2"],
|
||||
)
|
||||
|
||||
assert len(graph.nodes) == 2
|
||||
assert len(graph.edges) == 1
|
||||
mock_adapter.get_id_filtered_graph_data.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_graph_from_db_nodeset_subgraph(setup_graph, mock_adapter):
|
||||
"""Test projecting a nodeset subgraph filtered by node type and name."""
|
||||
graph = setup_graph
|
||||
|
||||
nodes_data = [
|
||||
("1", {"name": "Alice", "type": "Person"}),
|
||||
("2", {"name": "Bob", "type": "Person"}),
|
||||
]
|
||||
edges_data = [
|
||||
("1", "2", "KNOWS", {"relationship_name": "knows"}),
|
||||
]
|
||||
|
||||
mock_adapter.get_nodeset_subgraph = AsyncMock(return_value=(nodes_data, edges_data))
|
||||
|
||||
await graph.project_graph_from_db(
|
||||
adapter=mock_adapter,
|
||||
node_properties_to_project=["name", "type"],
|
||||
edge_properties_to_project=["relationship_name"],
|
||||
node_type="Person",
|
||||
node_name=["Alice"],
|
||||
)
|
||||
|
||||
assert len(graph.nodes) == 2
|
||||
assert graph.get_node("1") is not None
|
||||
assert len(graph.edges) == 1
|
||||
mock_adapter.get_nodeset_subgraph.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter):
|
||||
"""Test projecting empty graph raises EntityNotFoundError."""
|
||||
graph = setup_graph
|
||||
|
||||
mock_adapter.get_graph_data = AsyncMock(return_value=([], []))
|
||||
|
||||
with pytest.raises(EntityNotFoundError, match="Empty graph projected from the database."):
|
||||
await graph.project_graph_from_db(
|
||||
adapter=mock_adapter,
|
||||
node_properties_to_project=["name"],
|
||||
edge_properties_to_project=[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter):
|
||||
"""Test that edges referencing missing nodes raise error."""
|
||||
graph = setup_graph
|
||||
|
||||
nodes_data = [
|
||||
("1", {"name": "Node1"}),
|
||||
]
|
||||
edges_data = [
|
||||
("1", "999", "CONNECTS_TO", {"relationship_name": "connects"}),
|
||||
]
|
||||
|
||||
mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
|
||||
|
||||
with pytest.raises(EntityNotFoundError, match="Edge references nonexistent nodes"):
|
||||
await graph.project_graph_from_db(
|
||||
adapter=mock_adapter,
|
||||
node_properties_to_project=["name"],
|
||||
edge_properties_to_project=["relationship_name"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_to_graph_nodes(setup_graph):
|
||||
"""Test mapping vector distances to graph nodes."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1", {"name": "Node1"})
|
||||
node2 = Node("2", {"name": "Node2"})
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
node_distances = {
|
||||
"Entity_name": [
|
||||
MockScoredResult("1", 0.95),
|
||||
MockScoredResult("2", 0.87),
|
||||
]
|
||||
}
|
||||
|
||||
await graph.map_vector_distances_to_graph_nodes(node_distances)
|
||||
|
||||
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
|
||||
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_partial_node_coverage(setup_graph):
|
||||
"""Test mapping vector distances when only some nodes have results."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1", {"name": "Node1"})
|
||||
node2 = Node("2", {"name": "Node2"})
|
||||
node3 = Node("3", {"name": "Node3"})
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
graph.add_node(node3)
|
||||
|
||||
node_distances = {
|
||||
"Entity_name": [
|
||||
MockScoredResult("1", 0.95),
|
||||
MockScoredResult("2", 0.87),
|
||||
]
|
||||
}
|
||||
|
||||
await graph.map_vector_distances_to_graph_nodes(node_distances)
|
||||
|
||||
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
|
||||
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
|
||||
assert graph.get_node("3").attributes.get("vector_distance") == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_multiple_categories(setup_graph):
|
||||
"""Test mapping vector distances from multiple collection categories."""
|
||||
graph = setup_graph
|
||||
|
||||
# Create nodes
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
node3 = Node("3")
|
||||
node4 = Node("4")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
graph.add_node(node3)
|
||||
graph.add_node(node4)
|
||||
|
||||
node_distances = {
|
||||
"Entity_name": [
|
||||
MockScoredResult("1", 0.95),
|
||||
MockScoredResult("2", 0.87),
|
||||
],
|
||||
"TextSummary_text": [
|
||||
MockScoredResult("3", 0.92),
|
||||
],
|
||||
}
|
||||
|
||||
await graph.map_vector_distances_to_graph_nodes(node_distances)
|
||||
|
||||
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
|
||||
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
|
||||
assert graph.get_node("3").attributes.get("vector_distance") == 0.92
|
||||
assert graph.get_node("4").attributes.get("vector_distance") == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, mock_vector_engine):
|
||||
"""Test mapping vector distances to edges when edge_distances provided."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
edge = Edge(
|
||||
node1,
|
||||
node2,
|
||||
attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
|
||||
)
|
||||
graph.add_edge(edge)
|
||||
|
||||
edge_distances = [
|
||||
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.92
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_to_graph_edges_search(setup_graph, mock_vector_engine):
|
||||
"""Test mapping edge distances when searching for them."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
edge = Edge(
|
||||
node1,
|
||||
node2,
|
||||
attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
|
||||
)
|
||||
graph.add_edge(edge)
|
||||
|
||||
mock_vector_engine.search.return_value = [
|
||||
MockScoredResult("e1", 0.88, payload={"text": "CONNECTS_TO"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=None,
|
||||
)
|
||||
|
||||
mock_vector_engine.search.assert_called_once()
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.88
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vector_engine):
|
||||
"""Test mapping edge distances when only some edges have results."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
node3 = Node("3")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
graph.add_node(node3)
|
||||
|
||||
edge1 = Edge(node1, node2, attributes={"edge_text": "CONNECTS_TO"})
|
||||
edge2 = Edge(node2, node3, attributes={"edge_text": "DEPENDS_ON"})
|
||||
graph.add_edge(edge1)
|
||||
graph.add_edge(edge2)
|
||||
|
||||
edge_distances = [
|
||||
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.92
|
||||
assert graph.edges[1].attributes.get("vector_distance") == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_edges_fallback_to_relationship_type(
|
||||
setup_graph, mock_vector_engine
|
||||
):
|
||||
"""Test that edge mapping falls back to relationship_type when edge_text is missing."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
edge = Edge(
|
||||
node1,
|
||||
node2,
|
||||
attributes={"relationship_type": "KNOWS"},
|
||||
)
|
||||
graph.add_edge(edge)
|
||||
|
||||
edge_distances = [
|
||||
MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.85
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_engine):
|
||||
"""Test edge mapping when no edges match the distance results."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
edge = Edge(
|
||||
node1,
|
||||
node2,
|
||||
attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
|
||||
)
|
||||
graph.add_edge(edge)
|
||||
|
||||
edge_distances = [
|
||||
MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_invalid_query_vector(setup_graph, mock_vector_engine):
|
||||
"""Test that invalid query vector raises error."""
|
||||
graph = setup_graph
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to generate query embedding"):
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[],
|
||||
edge_distances=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_top_triplet_importances(setup_graph):
|
||||
"""Test calculating top triplet importances by score."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
node3 = Node("3")
|
||||
node4 = Node("4")
|
||||
|
||||
node1.add_attribute("vector_distance", 0.9)
|
||||
node2.add_attribute("vector_distance", 0.8)
|
||||
node3.add_attribute("vector_distance", 0.7)
|
||||
node4.add_attribute("vector_distance", 0.6)
|
||||
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
graph.add_node(node3)
|
||||
graph.add_node(node4)
|
||||
|
||||
edge1 = Edge(node1, node2)
|
||||
edge2 = Edge(node2, node3)
|
||||
edge3 = Edge(node3, node4)
|
||||
|
||||
edge1.add_attribute("vector_distance", 0.85)
|
||||
edge2.add_attribute("vector_distance", 0.75)
|
||||
edge3.add_attribute("vector_distance", 0.65)
|
||||
|
||||
graph.add_edge(edge1)
|
||||
graph.add_edge(edge2)
|
||||
graph.add_edge(edge3)
|
||||
|
||||
top_triplets = await graph.calculate_top_triplet_importances(k=2)
|
||||
|
||||
assert len(top_triplets) == 2
|
||||
|
||||
assert top_triplets[0] == edge3
|
||||
assert top_triplets[1] == edge2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_top_triplet_importances_default_distances(setup_graph):
|
||||
"""Test calculating importances when nodes/edges have no vector distances."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
edge = Edge(node1, node2)
|
||||
graph.add_edge(edge)
|
||||
|
||||
top_triplets = await graph.calculate_top_triplet_importances(k=1)
|
||||
|
||||
assert len(top_triplets) == 1
|
||||
assert top_triplets[0] == edge
|
||||
|
|
|
|||
|
|
@ -0,0 +1,582 @@
|
|||
import pytest
from unittest.mock import AsyncMock, patch

from cognee.modules.retrieval.utils.brute_force_triplet_search import (
    brute_force_triplet_search,
    get_memory_fragment,
)
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError


class MockScoredResult:
    """Mock class for vector search results."""

    def __init__(self, id, score, payload=None):
        self.id = id
        self.score = score
        self.payload = payload or {}


@pytest.mark.asyncio
async def test_brute_force_triplet_search_empty_query():
    """Test that empty query raises ValueError."""
    with pytest.raises(ValueError, match="The query must be a non-empty string."):
        await brute_force_triplet_search(query="")


@pytest.mark.asyncio
async def test_brute_force_triplet_search_none_query():
    """Test that None query raises ValueError."""
    with pytest.raises(ValueError, match="The query must be a non-empty string."):
        await brute_force_triplet_search(query=None)


@pytest.mark.asyncio
async def test_brute_force_triplet_search_negative_top_k():
    """Test that negative top_k raises ValueError."""
    with pytest.raises(ValueError, match="top_k must be a positive integer."):
        await brute_force_triplet_search(query="test query", top_k=-1)


@pytest.mark.asyncio
async def test_brute_force_triplet_search_zero_top_k():
    """Test that zero top_k raises ValueError."""
    with pytest.raises(ValueError, match="top_k must be a positive integer."):
        await brute_force_triplet_search(query="test query", top_k=0)


@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_limit_global_search():
    """Test that wide_search_top_k is applied as the per-collection search limit for global search (node_name=None)."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(
            query="test",
            node_name=None,  # Global search
            wide_search_top_k=75,
        )

    for call in mock_vector_engine.search.call_args_list:
        assert call[1]["limit"] == 75


@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_limit_filtered_search():
    """Test that the per-collection search limit is None for filtered search (node_name provided)."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(
            query="test",
            node_name=["Node1"],
            wide_search_top_k=50,
        )

    for call in mock_vector_engine.search.call_args_list:
        assert call[1]["limit"] is None


@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_default():
    """Test that wide_search_top_k defaults to 100."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    for call in mock_vector_engine.search.call_args_list:
        assert call[1]["limit"] == 100
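

# Illustrative only: the limit behaviour exercised by the three tests above, assuming the
# implementation widens each per-collection vector search for global queries and leaves
# filtered queries unbounded (the node_name filter already narrows the candidate set).
# The helper name is made up for this sketch, not taken from the real module.
def _illustrative_search_limit(node_name, wide_search_top_k=100):
    return wide_search_top_k if node_name is None else None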


@pytest.mark.asyncio
async def test_brute_force_triplet_search_default_collections():
    """Test that default collections are used when none provided."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(query="test")

    expected_collections = [
        "Entity_name",
        "TextSummary_text",
        "EntityType_name",
        "DocumentChunk_text",
    ]

    call_collections = [
        call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
    ]
    assert call_collections == expected_collections


@pytest.mark.asyncio
async def test_brute_force_triplet_search_custom_collections():
    """Test that custom collections are used when provided."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])

    custom_collections = ["CustomCol1", "CustomCol2"]

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(query="test", collections=custom_collections)

    call_collections = [
        call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
    ]
    assert call_collections == custom_collections


@pytest.mark.asyncio
async def test_brute_force_triplet_search_all_collections_empty():
    """Test that empty list is returned when all collections return no results."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[])

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        results = await brute_force_triplet_search(query="test")
        assert results == []


# Tests for query embedding


@pytest.mark.asyncio
async def test_brute_force_triplet_search_embeds_query():
    """Test that query is embedded before searching."""
    query_text = "test query"
    expected_vector = [0.1, 0.2, 0.3]

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector])
    mock_vector_engine.search = AsyncMock(return_value=[])

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
        return_value=mock_vector_engine,
    ):
        await brute_force_triplet_search(query=query_text)

    mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text])

    for call in mock_vector_engine.search.call_args_list:
        assert call[1]["query_vector"] == expected_vector
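

# Illustrative only: the call pattern the assertions above imply, assuming the query is
# embedded exactly once and the same vector is reused for every per-collection search.
# The helper name is made up; the keyword arguments mirror those asserted in these tests.
async def _illustrative_wide_search(vector_engine, query, collections, limit):
    [query_vector] = await vector_engine.embedding_engine.embed_text([query])
    results_by_collection = {}
    for collection_name in collections:
        results_by_collection[collection_name] = await vector_engine.search(
            collection_name=collection_name,
            query_vector=query_vector,
            limit=limit,
        )
    return results_by_collection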


@pytest.mark.asyncio
async def test_brute_force_triplet_search_extracts_node_ids_global_search():
    """Test that node IDs are extracted from search results for global search."""
    scored_results = [
        MockScoredResult("node1", 0.95),
        MockScoredResult("node2", 0.87),
        MockScoredResult("node3", 0.92),
    ]

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=scored_results)

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}


@pytest.mark.asyncio
async def test_brute_force_triplet_search_reuses_provided_fragment():
    """Test that provided memory fragment is reused instead of creating a new one."""
    provided_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment"
        ) as mock_get_fragment,
    ):
        await brute_force_triplet_search(
            query="test",
            memory_fragment=provided_fragment,
            node_name=["node"],
        )

    mock_get_fragment.assert_not_called()


@pytest.mark.asyncio
async def test_brute_force_triplet_search_creates_fragment_when_not_provided():
    """Test that memory fragment is created when not provided."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment,
    ):
        await brute_force_triplet_search(query="test", node_name=["node"])

    mock_get_fragment.assert_called_once()


@pytest.mark.asyncio
async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation():
    """Test that custom top_k is passed to importance calculation."""
    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ),
    ):
        custom_top_k = 15
        await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"])

    mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k)
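

# Illustrative usage only: mirrors the keyword arguments exercised by the tests above
# (a populated graph and vector store are assumed to exist; the query, node name, and
# collection choice are placeholders, not values from the real project).
async def _illustrative_usage():
    return await brute_force_triplet_search(
        query="What is connected to X?",
        top_k=15,
        node_name=["X"],
        collections=["Entity_name", "DocumentChunk_text"],
        wide_search_top_k=100,
    )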


@pytest.mark.asyncio
async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found():
    """Test that get_memory_fragment returns empty graph when entity not found."""
    mock_graph_engine = AsyncMock()
    mock_graph_engine.project_graph_from_db = AsyncMock(
        side_effect=EntityNotFoundError("Entity not found")
    )

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
        return_value=mock_graph_engine,
    ):
        fragment = await get_memory_fragment()

    assert isinstance(fragment, CogneeGraph)
    assert len(fragment.nodes) == 0


@pytest.mark.asyncio
async def test_get_memory_fragment_returns_empty_graph_on_error():
    """Test that get_memory_fragment returns empty graph on generic error."""
    mock_graph_engine = AsyncMock()
    mock_graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error"))

    with patch(
        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
        return_value=mock_graph_engine,
    ):
        fragment = await get_memory_fragment()

    assert isinstance(fragment, CogneeGraph)
    assert len(fragment.nodes) == 0
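

# Illustrative only: the fallback behaviour both tests above rely on, assuming
# get_memory_fragment swallows projection failures and returns an empty CogneeGraph
# rather than propagating the exception. The function body and call shape here are
# assumptions for illustration, not the real implementation.
async def _illustrative_get_memory_fragment(graph_engine):
    fragment = CogneeGraph()
    try:
        # The real projection call takes more arguments; only the error path matters here.
        await graph_engine.project_graph_from_db()
    except Exception:
        return CogneeGraph()  # empty graph on any projection failure
    return fragment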


@pytest.mark.asyncio
async def test_brute_force_triplet_search_deduplicates_node_ids():
    """Test that duplicate node IDs across collections are deduplicated."""

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return [
                MockScoredResult("node1", 0.95),
                MockScoredResult("node2", 0.87),
            ]
        elif collection_name == "TextSummary_text":
            return [
                MockScoredResult("node1", 0.90),
                MockScoredResult("node3", 0.92),
            ]
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
    assert len(call_kwargs["relevant_ids_to_filter"]) == 3


@pytest.mark.asyncio
async def test_brute_force_triplet_search_excludes_edge_collection():
    """Test that EdgeType_relationship_name collection is excluded from ID extraction."""

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return [MockScoredResult("node1", 0.95)]
        elif collection_name == "EdgeType_relationship_name":
            return [MockScoredResult("edge1", 0.88)]
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(
            query="test",
            node_name=None,
            collections=["Entity_name", "EdgeType_relationship_name"],
        )

    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert call_kwargs["relevant_ids_to_filter"] == ["node1"]


@pytest.mark.asyncio
async def test_brute_force_triplet_search_skips_nodes_without_ids():
    """Test that nodes without an ID attribute are skipped."""

    class ScoredResultNoId:
        """Mock result without id attribute."""

        def __init__(self, score):
            self.score = score

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return [
                MockScoredResult("node1", 0.95),
                ScoredResultNoId(0.90),
                MockScoredResult("node2", 0.87),
            ]
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
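

# Illustrative only: an ID-extraction rule consistent with the deduplication,
# edge-collection, and missing-ID tests above. The helper name and the shape of
# results_by_collection are assumptions made for this sketch, not the real helpers.
def _illustrative_collect_relevant_ids(results_by_collection):
    relevant_ids = set()
    for collection_name, results in results_by_collection.items():
        if collection_name == "EdgeType_relationship_name":
            continue  # edge-type hits do not contribute node IDs
        for result in results:
            node_id = getattr(result, "id", None)
            if node_id is not None:
                relevant_ids.add(node_id)  # set membership deduplicates across collections
    return list(relevant_ids)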


@pytest.mark.asyncio
async def test_brute_force_triplet_search_handles_tuple_results():
    """Test that both list and tuple results are handled correctly."""

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return (
                MockScoredResult("node1", 0.95),
                MockScoredResult("node2", 0.87),
            )
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}


@pytest.mark.asyncio
async def test_brute_force_triplet_search_mixed_empty_collections():
    """Test ID extraction with mixed empty and non-empty collections."""

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return [MockScoredResult("node1", 0.95)]
        elif collection_name == "TextSummary_text":
            return []
        elif collection_name == "EntityType_name":
            return [MockScoredResult("node2", 0.92)]
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

    call_kwargs = mock_get_fragment_fn.call_args[1]
    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}