feature: Introduces wide subgraph search in graph completion and improves QA speed (#1736)
<!-- .github/pull_request_template.md --> This PR introduces wide vector and graph structure filtering capabilities. With these changes, the graph completion retriever and all retrievers that inherit from it will now filter relevant vector elements and subgraphs based on the query. This improvement significantly increases search speed for large graphs while maintaining—and in some cases slightly improving—accuracy. Changes in This PR: -Introduced new wide_search_top_k parameter: Controls the initial search space size -Added graph adapter level filtering method: Enables relevant subgraph filtering while maintaining backward compatibility. For community or custom graph adapters that don't implement this method, the system gracefully falls back to the original search behavior. -Updated modal dashboard and evaluation framework: Fixed compatibility issues. Added comprehensive unit tests: Introduced unit tests for brute_force_triplet_search (previously untested) and expanded the CogneeGraph test suite. Integration tests: Existing integration tests verify end-to-end search functionality (no changes required). Acceptance Criteria and Testing To verify the new search behavior, run search queries with different wide_search_top_k parameters while logging is enabled: None: Triggers a full graph search (default behavior) 1: Projects a minimal subgraph (demonstrates maximum filtering) Custom values: Test intermediate levels of filtering Internal Testing and results: Performance and accuracy benchmarks are available upon request. The implementation demonstrates measurable improvements in query latency for large graphs without sacrificing result quality. ## Type of Change <!-- Please check the relevant option --> - [ ] Bug fix (non-breaking change that fixes an issue) - [ ] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [x] Code refactoring - [x] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) None ## Pre-submission Checklist <!-- Please check all boxes that apply before submitting your PR --> - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [x] My code follows the project's coding standards and style guidelines - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have added necessary documentation (if applicable) - [x] All new and existing tests pass - [x] I have searched existing PRs to ensure this change hasn't been submitted already - [x] I have linked any relevant issues in the description - [x] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --------- Co-authored-by: Pavel Zorin <pazonec@yandex.ru>
This commit is contained in:
parent
c2c64a417c
commit
508165e883
23 changed files with 1482 additions and 70 deletions
|
|
@ -31,6 +31,8 @@ async def search(
|
|||
only_context: bool = False,
|
||||
use_combined_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> Union[List[SearchResult], CombinedSearchResult]:
|
||||
"""
|
||||
Search and query the knowledge graph for insights, information, and connections.
|
||||
|
|
@ -200,6 +202,8 @@ async def search(
|
|||
only_context=only_context,
|
||||
use_combined_context=use_combined_context,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
return filtered_search_results
|
||||
|
|
|
|||
29
cognee/eval_framework/Dockerfile
Normal file
29
cognee/eval_framework/Dockerfile
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
FROM python:3.11-slim
|
||||
|
||||
# Set environment variables
|
||||
ENV PIP_NO_CACHE_DIR=true
|
||||
ENV PATH="${PATH}:/root/.poetry/bin"
|
||||
ENV PYTHONPATH=/app
|
||||
ENV SKIP_MIGRATIONS=true
|
||||
|
||||
# System dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
libpq-dev \
|
||||
git \
|
||||
curl \
|
||||
build-essential \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY pyproject.toml poetry.lock README.md /app/
|
||||
|
||||
RUN pip install poetry
|
||||
|
||||
RUN poetry config virtualenvs.create false
|
||||
|
||||
RUN poetry install --extras distributed --extras evals --extras deepeval --no-root
|
||||
|
||||
COPY cognee/ /app/cognee
|
||||
COPY distributed/ /app/distributed
|
||||
|
|
@ -35,6 +35,16 @@ class AnswerGeneratorExecutor:
|
|||
retrieval_context = await retriever.get_context(query_text)
|
||||
search_results = await retriever.get_completion(query_text, retrieval_context)
|
||||
|
||||
############
|
||||
#:TODO This is a quick fix until we don't structure retriever results properly but lets not leave it like this...this is needed now due to the changed combined retriever structure..
|
||||
if isinstance(retrieval_context, list):
|
||||
retrieval_context = await retriever.convert_retrieved_objects_to_context(
|
||||
triplets=retrieval_context
|
||||
)
|
||||
|
||||
if isinstance(search_results, str):
|
||||
search_results = [search_results]
|
||||
#############
|
||||
answer = {
|
||||
"question": query_text,
|
||||
"answer": search_results[0],
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ async def create_and_insert_answers_table(questions_payload):
|
|||
|
||||
|
||||
async def run_question_answering(
|
||||
params: dict, system_prompt="answer_simple_question.txt", top_k: Optional[int] = None
|
||||
params: dict, system_prompt="answer_simple_question_benchmark.txt", top_k: Optional[int] = None
|
||||
) -> List[dict]:
|
||||
if params.get("answering_questions"):
|
||||
logger.info("Question answering started...")
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ class EvalConfig(BaseSettings):
|
|||
|
||||
# Question answering params
|
||||
answering_questions: bool = True
|
||||
qa_engine: str = "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
|
||||
qa_engine: str = "cognee_graph_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
|
||||
|
||||
# Evaluation params
|
||||
evaluating_answers: bool = True
|
||||
|
|
@ -25,7 +25,7 @@ class EvalConfig(BaseSettings):
|
|||
"EM",
|
||||
"f1",
|
||||
] # Use only 'correctness' for DirectLLM
|
||||
deepeval_model: str = "gpt-5-mini"
|
||||
deepeval_model: str = "gpt-4o-mini"
|
||||
|
||||
# Metrics params
|
||||
calculate_metrics: bool = True
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import modal
|
|||
import os
|
||||
import asyncio
|
||||
import datetime
|
||||
import hashlib
|
||||
import json
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.eval_framework.eval_config import EvalConfig
|
||||
|
|
@ -10,6 +9,9 @@ from cognee.eval_framework.corpus_builder.run_corpus_builder import run_corpus_b
|
|||
from cognee.eval_framework.answer_generation.run_question_answering_module import (
|
||||
run_question_answering,
|
||||
)
|
||||
import pathlib
|
||||
from os import path
|
||||
from modal import Image
|
||||
from cognee.eval_framework.evaluation.run_evaluation_module import run_evaluation
|
||||
from cognee.eval_framework.metrics_dashboard import create_dashboard
|
||||
|
||||
|
|
@ -38,22 +40,19 @@ def read_and_combine_metrics(eval_params: dict) -> dict:
|
|||
|
||||
app = modal.App("modal-run-eval")
|
||||
|
||||
image = (
|
||||
modal.Image.from_dockerfile(path="Dockerfile_modal", force_build=False)
|
||||
.copy_local_file("pyproject.toml", "pyproject.toml")
|
||||
.copy_local_file("poetry.lock", "poetry.lock")
|
||||
.env(
|
||||
{
|
||||
"ENV": os.getenv("ENV"),
|
||||
"LLM_API_KEY": os.getenv("LLM_API_KEY"),
|
||||
"OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
|
||||
}
|
||||
)
|
||||
.pip_install("protobuf", "h2", "deepeval", "gdown", "plotly")
|
||||
image = Image.from_dockerfile(
|
||||
path=pathlib.Path(path.join(path.dirname(__file__), "Dockerfile")).resolve(),
|
||||
force_build=False,
|
||||
).add_local_python_source("cognee")
|
||||
|
||||
|
||||
@app.function(
|
||||
image=image,
|
||||
max_containers=10,
|
||||
timeout=86400,
|
||||
volumes={"/data": vol},
|
||||
secrets=[modal.Secret.from_name("eval_secrets")],
|
||||
)
|
||||
|
||||
|
||||
@app.function(image=image, concurrency_limit=10, timeout=86400, volumes={"/data": vol})
|
||||
async def modal_run_eval(eval_params=None):
|
||||
"""Runs evaluation pipeline and returns combined metrics results."""
|
||||
if eval_params is None:
|
||||
|
|
@ -105,18 +104,7 @@ async def main():
|
|||
configs = [
|
||||
EvalConfig(
|
||||
task_getter_type="Default",
|
||||
number_of_samples_in_corpus=10,
|
||||
benchmark="HotPotQA",
|
||||
qa_engine="cognee_graph_completion",
|
||||
building_corpus_from_scratch=True,
|
||||
answering_questions=True,
|
||||
evaluating_answers=True,
|
||||
calculate_metrics=True,
|
||||
dashboard=True,
|
||||
),
|
||||
EvalConfig(
|
||||
task_getter_type="Default",
|
||||
number_of_samples_in_corpus=10,
|
||||
number_of_samples_in_corpus=25,
|
||||
benchmark="TwoWikiMultiHop",
|
||||
qa_engine="cognee_graph_completion",
|
||||
building_corpus_from_scratch=True,
|
||||
|
|
@ -127,7 +115,7 @@ async def main():
|
|||
),
|
||||
EvalConfig(
|
||||
task_getter_type="Default",
|
||||
number_of_samples_in_corpus=10,
|
||||
number_of_samples_in_corpus=25,
|
||||
benchmark="Musique",
|
||||
qa_engine="cognee_graph_completion",
|
||||
building_corpus_from_scratch=True,
|
||||
|
|
|
|||
|
|
@ -398,3 +398,18 @@ class GraphDBInterface(ABC):
|
|||
- node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_filtered_graph_data(
|
||||
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
|
||||
) -> Tuple[List[Node], List[EdgeData]]:
|
||||
"""
|
||||
Retrieve nodes and edges filtered by the provided attribute criteria.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- attribute_filters: A list of dictionaries where keys are attribute names and values
|
||||
are lists of attribute values to filter by.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from contextlib import asynccontextmanager
|
|||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict, Any, List, Union, Optional, Tuple, Type
|
||||
|
||||
from cognee.exceptions import CogneeValidationError
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.utils.run_sync import run_sync
|
||||
from cognee.infrastructure.files.storage import get_file_storage
|
||||
|
|
@ -1186,6 +1187,11 @@ class KuzuAdapter(GraphDBInterface):
|
|||
A tuple with two elements: a list of tuples of (node_id, properties) and a list of
|
||||
tuples of (source_id, target_id, relationship_name, properties).
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
nodes_query = """
|
||||
MATCH (n:Node)
|
||||
|
|
@ -1249,6 +1255,11 @@ class KuzuAdapter(GraphDBInterface):
|
|||
},
|
||||
)
|
||||
)
|
||||
|
||||
retrieval_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds"
|
||||
)
|
||||
return formatted_nodes, formatted_edges
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get graph data: {e}")
|
||||
|
|
@ -1417,6 +1428,92 @@ class KuzuAdapter(GraphDBInterface):
|
|||
formatted_edges.append((source_id, target_id, rel_type, props))
|
||||
return formatted_nodes, formatted_edges
|
||||
|
||||
async def get_id_filtered_graph_data(self, target_ids: list[str]):
|
||||
"""
|
||||
Retrieve graph data filtered by specific node IDs, including their direct neighbors
|
||||
and only edges where one endpoint matches those IDs.
|
||||
|
||||
Returns:
|
||||
nodes: List[dict] -> Each dict includes "id" and all node properties
|
||||
edges: List[dict] -> Each dict includes "source", "target", "type", "properties"
|
||||
"""
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
if not target_ids:
|
||||
logger.warning("No target IDs provided for ID-filtered graph retrieval.")
|
||||
return [], []
|
||||
|
||||
if not all(isinstance(x, str) for x in target_ids):
|
||||
raise CogneeValidationError("target_ids must be a list of strings")
|
||||
|
||||
query = """
|
||||
MATCH (n:Node)-[r]->(m:Node)
|
||||
WHERE n.id IN $target_ids OR m.id IN $target_ids
|
||||
RETURN n.id, {
|
||||
name: n.name,
|
||||
type: n.type,
|
||||
properties: n.properties
|
||||
}, m.id, {
|
||||
name: m.name,
|
||||
type: m.type,
|
||||
properties: m.properties
|
||||
}, r.relationship_name, r.properties
|
||||
"""
|
||||
|
||||
result = await self.query(query, {"target_ids": target_ids})
|
||||
|
||||
if not result:
|
||||
logger.info("No data returned for the supplied IDs")
|
||||
return [], []
|
||||
|
||||
nodes_dict = {}
|
||||
edges = []
|
||||
|
||||
for n_id, n_props, m_id, m_props, r_type, r_props_raw in result:
|
||||
if n_props.get("properties"):
|
||||
try:
|
||||
additional_props = json.loads(n_props["properties"])
|
||||
n_props.update(additional_props)
|
||||
del n_props["properties"]
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse properties JSON for node {n_id}")
|
||||
|
||||
if m_props.get("properties"):
|
||||
try:
|
||||
additional_props = json.loads(m_props["properties"])
|
||||
m_props.update(additional_props)
|
||||
del m_props["properties"]
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse properties JSON for node {m_id}")
|
||||
|
||||
nodes_dict[n_id] = (n_id, n_props)
|
||||
nodes_dict[m_id] = (m_id, m_props)
|
||||
|
||||
edge_props = {}
|
||||
if r_props_raw:
|
||||
try:
|
||||
edge_props = json.loads(r_props_raw)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning(f"Failed to parse edge properties for {n_id}->{m_id}")
|
||||
|
||||
source_id = edge_props.get("source_node_id", n_id)
|
||||
target_id = edge_props.get("target_node_id", m_id)
|
||||
edges.append((source_id, target_id, r_type, edge_props))
|
||||
|
||||
retrieval_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
|
||||
)
|
||||
|
||||
return list(nodes_dict.values()), edges
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
|
||||
"""
|
||||
Get metrics on graph structure and connectivity.
|
||||
|
|
|
|||
|
|
@ -964,6 +964,63 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
logger.error(f"Error during graph data retrieval: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_id_filtered_graph_data(self, target_ids: list[str]):
|
||||
"""
|
||||
Retrieve graph data filtered by specific node IDs, including their direct neighbors
|
||||
and only edges where one endpoint matches those IDs.
|
||||
|
||||
This version uses a single Cypher query for efficiency.
|
||||
"""
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
if not target_ids:
|
||||
logger.warning("No target IDs provided for ID-filtered graph retrieval.")
|
||||
return [], []
|
||||
|
||||
query = """
|
||||
MATCH ()-[r]-()
|
||||
WHERE startNode(r).id IN $target_ids
|
||||
OR endNode(r).id IN $target_ids
|
||||
WITH DISTINCT r, startNode(r) AS a, endNode(r) AS b
|
||||
RETURN
|
||||
properties(a) AS n_properties,
|
||||
properties(b) AS m_properties,
|
||||
type(r) AS type,
|
||||
properties(r) AS properties
|
||||
"""
|
||||
|
||||
result = await self.query(query, {"target_ids": target_ids})
|
||||
|
||||
nodes_dict = {}
|
||||
edges = []
|
||||
|
||||
for record in result:
|
||||
n_props = record["n_properties"]
|
||||
m_props = record["m_properties"]
|
||||
r_props = record["properties"]
|
||||
r_type = record["type"]
|
||||
|
||||
nodes_dict[n_props["id"]] = (n_props["id"], n_props)
|
||||
nodes_dict[m_props["id"]] = (m_props["id"], m_props)
|
||||
|
||||
source_id = r_props.get("source_node_id", n_props["id"])
|
||||
target_id = r_props.get("target_node_id", m_props["id"])
|
||||
edges.append((source_id, target_id, r_type, r_props))
|
||||
|
||||
retrieval_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
|
||||
)
|
||||
|
||||
return list(nodes_dict.values()), edges
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_nodeset_subgraph(
|
||||
self, node_type: Type[Any], node_name: List[str]
|
||||
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
|
||||
|
|
|
|||
|
|
@ -56,6 +56,68 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
def get_edges(self) -> List[Edge]:
|
||||
return self.edges
|
||||
|
||||
async def _get_nodeset_subgraph(
|
||||
self,
|
||||
adapter,
|
||||
node_type,
|
||||
node_name,
|
||||
):
|
||||
"""Retrieve subgraph based on node type and name."""
|
||||
logger.info("Retrieving graph filtered by node type and node name (NodeSet).")
|
||||
nodes_data, edges_data = await adapter.get_nodeset_subgraph(
|
||||
node_type=node_type, node_name=node_name
|
||||
)
|
||||
if not nodes_data or not edges_data:
|
||||
raise EntityNotFoundError(
|
||||
message="Nodeset does not exist, or empty nodeset projected from the database."
|
||||
)
|
||||
return nodes_data, edges_data
|
||||
|
||||
async def _get_full_or_id_filtered_graph(
|
||||
self,
|
||||
adapter,
|
||||
relevant_ids_to_filter,
|
||||
):
|
||||
"""Retrieve full or ID-filtered graph with fallback."""
|
||||
if relevant_ids_to_filter is None:
|
||||
logger.info("Retrieving full graph.")
|
||||
nodes_data, edges_data = await adapter.get_graph_data()
|
||||
if not nodes_data or not edges_data:
|
||||
raise EntityNotFoundError(message="Empty graph projected from the database.")
|
||||
return nodes_data, edges_data
|
||||
|
||||
get_graph_data_fn = getattr(adapter, "get_id_filtered_graph_data", adapter.get_graph_data)
|
||||
if getattr(adapter.__class__, "get_id_filtered_graph_data", None):
|
||||
logger.info("Retrieving ID-filtered graph from database.")
|
||||
nodes_data, edges_data = await get_graph_data_fn(target_ids=relevant_ids_to_filter)
|
||||
else:
|
||||
logger.info("Retrieving full graph from database.")
|
||||
nodes_data, edges_data = await get_graph_data_fn()
|
||||
if hasattr(adapter, "get_id_filtered_graph_data") and (not nodes_data or not edges_data):
|
||||
logger.warning(
|
||||
"Id filtered graph returned empty, falling back to full graph retrieval."
|
||||
)
|
||||
logger.info("Retrieving full graph")
|
||||
nodes_data, edges_data = await adapter.get_graph_data()
|
||||
|
||||
if not nodes_data or not edges_data:
|
||||
raise EntityNotFoundError("Empty graph projected from the database.")
|
||||
return nodes_data, edges_data
|
||||
|
||||
async def _get_filtered_graph(
|
||||
self,
|
||||
adapter,
|
||||
memory_fragment_filter,
|
||||
):
|
||||
"""Retrieve graph filtered by attributes."""
|
||||
logger.info("Retrieving graph filtered by memory fragment")
|
||||
nodes_data, edges_data = await adapter.get_filtered_graph_data(
|
||||
attribute_filters=memory_fragment_filter
|
||||
)
|
||||
if not nodes_data or not edges_data:
|
||||
raise EntityNotFoundError(message="Empty filtered graph projected from the database.")
|
||||
return nodes_data, edges_data
|
||||
|
||||
async def project_graph_from_db(
|
||||
self,
|
||||
adapter: Union[GraphDBInterface],
|
||||
|
|
@ -67,40 +129,39 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
memory_fragment_filter=[],
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
relevant_ids_to_filter: Optional[List[str]] = None,
|
||||
triplet_distance_penalty: float = 3.5,
|
||||
) -> None:
|
||||
if node_dimension < 1 or edge_dimension < 1:
|
||||
raise InvalidDimensionsError()
|
||||
try:
|
||||
if node_type is not None and node_name not in [None, [], ""]:
|
||||
nodes_data, edges_data = await self._get_nodeset_subgraph(
|
||||
adapter, node_type, node_name
|
||||
)
|
||||
elif len(memory_fragment_filter) == 0:
|
||||
nodes_data, edges_data = await self._get_full_or_id_filtered_graph(
|
||||
adapter, relevant_ids_to_filter
|
||||
)
|
||||
else:
|
||||
nodes_data, edges_data = await self._get_filtered_graph(
|
||||
adapter, memory_fragment_filter
|
||||
)
|
||||
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Determine projection strategy
|
||||
if node_type is not None and node_name not in [None, [], ""]:
|
||||
nodes_data, edges_data = await adapter.get_nodeset_subgraph(
|
||||
node_type=node_type, node_name=node_name
|
||||
)
|
||||
if not nodes_data or not edges_data:
|
||||
raise EntityNotFoundError(
|
||||
message="Nodeset does not exist, or empty nodetes projected from the database."
|
||||
)
|
||||
elif len(memory_fragment_filter) == 0:
|
||||
nodes_data, edges_data = await adapter.get_graph_data()
|
||||
if not nodes_data or not edges_data:
|
||||
raise EntityNotFoundError(message="Empty graph projected from the database.")
|
||||
else:
|
||||
nodes_data, edges_data = await adapter.get_filtered_graph_data(
|
||||
attribute_filters=memory_fragment_filter
|
||||
)
|
||||
if not nodes_data or not edges_data:
|
||||
raise EntityNotFoundError(
|
||||
message="Empty filtered graph projected from the database."
|
||||
)
|
||||
|
||||
# Process nodes
|
||||
for node_id, properties in nodes_data:
|
||||
node_attributes = {key: properties.get(key) for key in node_properties_to_project}
|
||||
self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))
|
||||
self.add_node(
|
||||
Node(
|
||||
str(node_id),
|
||||
node_attributes,
|
||||
dimension=node_dimension,
|
||||
node_penalty=triplet_distance_penalty,
|
||||
)
|
||||
)
|
||||
|
||||
# Process edges
|
||||
for source_id, target_id, relationship_type, properties in edges_data:
|
||||
|
|
@ -118,6 +179,7 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
attributes=edge_attributes,
|
||||
directed=directed,
|
||||
dimension=edge_dimension,
|
||||
edge_penalty=triplet_distance_penalty,
|
||||
)
|
||||
self.add_edge(edge)
|
||||
|
||||
|
|
|
|||
|
|
@ -20,13 +20,17 @@ class Node:
|
|||
status: np.ndarray
|
||||
|
||||
def __init__(
|
||||
self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1
|
||||
self,
|
||||
node_id: str,
|
||||
attributes: Optional[Dict[str, Any]] = None,
|
||||
dimension: int = 1,
|
||||
node_penalty: float = 3.5,
|
||||
):
|
||||
if dimension <= 0:
|
||||
raise InvalidDimensionsError()
|
||||
self.id = node_id
|
||||
self.attributes = attributes if attributes is not None else {}
|
||||
self.attributes["vector_distance"] = float("inf")
|
||||
self.attributes["vector_distance"] = node_penalty
|
||||
self.skeleton_neighbours = []
|
||||
self.skeleton_edges = []
|
||||
self.status = np.ones(dimension, dtype=int)
|
||||
|
|
@ -105,13 +109,14 @@ class Edge:
|
|||
attributes: Optional[Dict[str, Any]] = None,
|
||||
directed: bool = True,
|
||||
dimension: int = 1,
|
||||
edge_penalty: float = 3.5,
|
||||
):
|
||||
if dimension <= 0:
|
||||
raise InvalidDimensionsError()
|
||||
self.node1 = node1
|
||||
self.node2 = node2
|
||||
self.attributes = attributes if attributes is not None else {}
|
||||
self.attributes["vector_distance"] = float("inf")
|
||||
self.attributes["vector_distance"] = edge_penalty
|
||||
self.directed = directed
|
||||
self.status = np.ones(dimension, dtype=int)
|
||||
|
||||
|
|
|
|||
|
|
@ -39,6 +39,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
):
|
||||
super().__init__(
|
||||
user_prompt_path=user_prompt_path,
|
||||
|
|
@ -48,6 +50,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
async def get_completion(
|
||||
|
|
|
|||
|
|
@ -65,6 +65,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
):
|
||||
super().__init__(
|
||||
user_prompt_path=user_prompt_path,
|
||||
|
|
@ -74,6 +76,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
self.validation_system_prompt_path = validation_system_prompt_path
|
||||
self.validation_user_prompt_path = validation_user_prompt_path
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
):
|
||||
"""Initialize retriever with prompt paths and search parameters."""
|
||||
self.save_interaction = save_interaction
|
||||
|
|
@ -54,8 +56,10 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
self.system_prompt_path = system_prompt_path
|
||||
self.system_prompt = system_prompt
|
||||
self.top_k = top_k if top_k is not None else 5
|
||||
self.wide_search_top_k = wide_search_top_k
|
||||
self.node_type = node_type
|
||||
self.node_name = node_name
|
||||
self.triplet_distance_penalty = triplet_distance_penalty
|
||||
|
||||
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
|
||||
"""
|
||||
|
|
@ -105,6 +109,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
collections=vector_index_collections or None,
|
||||
node_type=self.node_type,
|
||||
node_name=self.node_name,
|
||||
wide_search_top_k=self.wide_search_top_k,
|
||||
triplet_distance_penalty=self.triplet_distance_penalty,
|
||||
)
|
||||
|
||||
return found_triplets
|
||||
|
|
@ -141,6 +147,10 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
|
||||
return triplets
|
||||
|
||||
async def convert_retrieved_objects_to_context(self, triplets: List[Edge]):
|
||||
context = await self.resolve_edges_to_text(triplets)
|
||||
return context
|
||||
|
||||
async def get_completion(
|
||||
self,
|
||||
query: str,
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
|||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
):
|
||||
"""Initialize retriever with default prompt paths and search parameters."""
|
||||
super().__init__(
|
||||
|
|
@ -36,6 +38,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
self.summarize_prompt_path = summarize_prompt_path
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
top_k: Optional[int] = 5,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
):
|
||||
super().__init__(
|
||||
user_prompt_path=user_prompt_path,
|
||||
|
|
@ -54,6 +56,8 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
self.user_prompt_path = user_prompt_path
|
||||
self.system_prompt_path = system_prompt_path
|
||||
|
|
|
|||
|
|
@ -58,6 +58,8 @@ async def get_memory_fragment(
|
|||
properties_to_project: Optional[List[str]] = None,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
relevant_ids_to_filter: Optional[List[str]] = None,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> CogneeGraph:
|
||||
"""Creates and initializes a CogneeGraph memory fragment with optional property projections."""
|
||||
if properties_to_project is None:
|
||||
|
|
@ -74,6 +76,8 @@ async def get_memory_fragment(
|
|||
edge_properties_to_project=["relationship_name", "edge_text"],
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
relevant_ids_to_filter=relevant_ids_to_filter,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
except EntityNotFoundError:
|
||||
|
|
@ -95,6 +99,8 @@ async def brute_force_triplet_search(
|
|||
memory_fragment: Optional[CogneeGraph] = None,
|
||||
node_type: Optional[Type] = None,
|
||||
node_name: Optional[List[str]] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> List[Edge]:
|
||||
"""
|
||||
Performs a brute force search to retrieve the top triplets from the graph.
|
||||
|
|
@ -107,6 +113,8 @@ async def brute_force_triplet_search(
|
|||
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
|
||||
node_type: node type to filter
|
||||
node_name: node name to filter
|
||||
wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections
|
||||
triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection
|
||||
|
||||
Returns:
|
||||
list: The top triplet results.
|
||||
|
|
@ -116,10 +124,10 @@ async def brute_force_triplet_search(
|
|||
if top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer.")
|
||||
|
||||
if memory_fragment is None:
|
||||
memory_fragment = await get_memory_fragment(
|
||||
properties_to_project, node_type=node_type, node_name=node_name
|
||||
)
|
||||
# Setting wide search limit based on the parameters
|
||||
non_global_search = node_name is None
|
||||
|
||||
wide_search_limit = wide_search_top_k if non_global_search else None
|
||||
|
||||
if collections is None:
|
||||
collections = [
|
||||
|
|
@ -140,7 +148,7 @@ async def brute_force_triplet_search(
|
|||
async def search_in_collection(collection_name: str):
|
||||
try:
|
||||
return await vector_engine.search(
|
||||
collection_name=collection_name, query_vector=query_vector, limit=None
|
||||
collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit
|
||||
)
|
||||
except CollectionNotFoundError:
|
||||
return []
|
||||
|
|
@ -156,15 +164,38 @@ async def brute_force_triplet_search(
|
|||
return []
|
||||
|
||||
# Final statistics
|
||||
projection_time = time.time() - start_time
|
||||
vector_collection_search_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s"
|
||||
f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s"
|
||||
)
|
||||
|
||||
node_distances = {collection: result for collection, result in zip(collections, results)}
|
||||
|
||||
edge_distances = node_distances.get("EdgeType_relationship_name", None)
|
||||
|
||||
if wide_search_limit is not None:
|
||||
relevant_ids_to_filter = list(
|
||||
{
|
||||
str(getattr(scored_node, "id"))
|
||||
for collection_name, score_collection in node_distances.items()
|
||||
if collection_name != "EdgeType_relationship_name"
|
||||
and isinstance(score_collection, (list, tuple))
|
||||
for scored_node in score_collection
|
||||
if getattr(scored_node, "id", None)
|
||||
}
|
||||
)
|
||||
else:
|
||||
relevant_ids_to_filter = None
|
||||
|
||||
if memory_fragment is None:
|
||||
memory_fragment = await get_memory_fragment(
|
||||
properties_to_project=properties_to_project,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
relevant_ids_to_filter=relevant_ids_to_filter,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
|
||||
await memory_fragment.map_vector_distances_to_graph_edges(
|
||||
vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances
|
||||
|
|
|
|||
|
|
@ -37,6 +37,8 @@ async def get_search_type_tools(
|
|||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> list:
|
||||
search_tasks: dict[SearchType, List[Callable]] = {
|
||||
SearchType.SUMMARIES: [
|
||||
|
|
@ -67,6 +69,8 @@ async def get_search_type_tools(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_completion,
|
||||
GraphCompletionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
|
|
@ -75,6 +79,8 @@ async def get_search_type_tools(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_context,
|
||||
],
|
||||
SearchType.GRAPH_COMPLETION_COT: [
|
||||
|
|
@ -85,6 +91,8 @@ async def get_search_type_tools(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_completion,
|
||||
GraphCompletionCotRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
|
|
@ -93,6 +101,8 @@ async def get_search_type_tools(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_context,
|
||||
],
|
||||
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: [
|
||||
|
|
@ -103,6 +113,8 @@ async def get_search_type_tools(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_completion,
|
||||
GraphCompletionContextExtensionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
|
|
@ -111,6 +123,8 @@ async def get_search_type_tools(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_context,
|
||||
],
|
||||
SearchType.GRAPH_SUMMARY_COMPLETION: [
|
||||
|
|
@ -121,6 +135,8 @@ async def get_search_type_tools(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_completion,
|
||||
GraphSummaryCompletionRetriever(
|
||||
system_prompt_path=system_prompt_path,
|
||||
|
|
@ -129,6 +145,8 @@ async def get_search_type_tools(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
system_prompt=system_prompt,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_context,
|
||||
],
|
||||
SearchType.CODE: [
|
||||
|
|
@ -145,8 +163,16 @@ async def get_search_type_tools(
|
|||
],
|
||||
SearchType.FEEDBACK: [UserQAFeedback(last_k=last_k).add_feedback],
|
||||
SearchType.TEMPORAL: [
|
||||
TemporalRetriever(top_k=top_k).get_completion,
|
||||
TemporalRetriever(top_k=top_k).get_context,
|
||||
TemporalRetriever(
|
||||
top_k=top_k,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_completion,
|
||||
TemporalRetriever(
|
||||
top_k=top_k,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
).get_context,
|
||||
],
|
||||
SearchType.CHUNKS_LEXICAL: (
|
||||
lambda _r=JaccardChunksRetriever(top_k=top_k): [
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ async def no_access_control_search(
|
|||
last_k: Optional[int] = None,
|
||||
only_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
|
||||
search_tools = await get_search_type_tools(
|
||||
query_type=query_type,
|
||||
|
|
@ -35,6 +37,8 @@ async def no_access_control_search(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
graph_engine = await get_graph_engine()
|
||||
is_empty = await graph_engine.is_empty()
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ async def search(
|
|||
only_context: bool = False,
|
||||
use_combined_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> Union[CombinedSearchResult, List[SearchResult]]:
|
||||
"""
|
||||
|
||||
|
|
@ -90,6 +92,8 @@ async def search(
|
|||
only_context=only_context,
|
||||
use_combined_context=use_combined_context,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
else:
|
||||
search_results = [
|
||||
|
|
@ -105,6 +109,8 @@ async def search(
|
|||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -219,6 +225,8 @@ async def authorized_search(
|
|||
only_context: bool = False,
|
||||
use_combined_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> Union[
|
||||
Tuple[Any, Union[List[Edge], str], List[Dataset]],
|
||||
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
|
||||
|
|
@ -246,6 +254,8 @@ async def authorized_search(
|
|||
last_k=last_k,
|
||||
only_context=True,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
context = {}
|
||||
|
|
@ -267,6 +277,8 @@ async def authorized_search(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
search_tools = specific_search_tools
|
||||
if len(search_tools) == 2:
|
||||
|
|
@ -306,6 +318,7 @@ async def authorized_search(
|
|||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
)
|
||||
|
||||
return search_results
|
||||
|
|
@ -325,6 +338,8 @@ async def search_in_datasets_context(
|
|||
only_context: bool = False,
|
||||
context: Optional[Any] = None,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
|
||||
"""
|
||||
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
|
||||
|
|
@ -345,6 +360,8 @@ async def search_in_datasets_context(
|
|||
only_context: bool = False,
|
||||
context: Optional[Any] = None,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
|
||||
# Set database configuration in async context for each dataset user has access for
|
||||
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
||||
|
|
@ -378,6 +395,8 @@ async def search_in_datasets_context(
|
|||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
search_tools = specific_search_tools
|
||||
if len(search_tools) == 2:
|
||||
|
|
@ -413,6 +432,8 @@ async def search_in_datasets_context(
|
|||
only_context=only_context,
|
||||
context=context,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ def test_node_initialization():
|
|||
"""Test that a Node is initialized correctly."""
|
||||
node = Node("node1", {"attr1": "value1"}, dimension=2)
|
||||
assert node.id == "node1"
|
||||
assert node.attributes == {"attr1": "value1", "vector_distance": np.inf}
|
||||
assert node.attributes == {"attr1": "value1", "vector_distance": 3.5}
|
||||
assert len(node.status) == 2
|
||||
assert np.all(node.status == 1)
|
||||
|
||||
|
|
@ -96,7 +96,7 @@ def test_edge_initialization():
|
|||
edge = Edge(node1, node2, {"weight": 10}, directed=False, dimension=2)
|
||||
assert edge.node1 == node1
|
||||
assert edge.node2 == node2
|
||||
assert edge.attributes == {"vector_distance": np.inf, "weight": 10}
|
||||
assert edge.attributes == {"vector_distance": 3.5, "weight": 10}
|
||||
assert edge.directed is False
|
||||
assert len(edge.status) == 2
|
||||
assert np.all(edge.status == 1)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
|
|
@ -11,6 +12,30 @@ def setup_graph():
|
|||
return CogneeGraph()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_adapter():
|
||||
"""Fixture to create a mock adapter for database operations."""
|
||||
adapter = AsyncMock()
|
||||
return adapter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_engine():
|
||||
"""Fixture to create a mock vector engine."""
|
||||
engine = AsyncMock()
|
||||
engine.search = AsyncMock()
|
||||
return engine
|
||||
|
||||
|
||||
class MockScoredResult:
|
||||
"""Mock class for vector search results."""
|
||||
|
||||
def __init__(self, id, score, payload=None):
|
||||
self.id = id
|
||||
self.score = score
|
||||
self.payload = payload or {}
|
||||
|
||||
|
||||
def test_add_node_success(setup_graph):
|
||||
"""Test successful addition of a node."""
|
||||
graph = setup_graph
|
||||
|
|
@ -73,3 +98,433 @@ def test_get_edges_nonexistent_node(setup_graph):
|
|||
graph = setup_graph
|
||||
with pytest.raises(EntityNotFoundError, match="Node with id nonexistent does not exist."):
|
||||
graph.get_edges_from_node("nonexistent")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_graph_from_db_full_graph(setup_graph, mock_adapter):
|
||||
"""Test projecting a full graph from database."""
|
||||
graph = setup_graph
|
||||
|
||||
nodes_data = [
|
||||
("1", {"name": "Node1", "description": "First node"}),
|
||||
("2", {"name": "Node2", "description": "Second node"}),
|
||||
]
|
||||
edges_data = [
|
||||
("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}),
|
||||
]
|
||||
|
||||
mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
|
||||
|
||||
await graph.project_graph_from_db(
|
||||
adapter=mock_adapter,
|
||||
node_properties_to_project=["name", "description"],
|
||||
edge_properties_to_project=["relationship_name"],
|
||||
)
|
||||
|
||||
assert len(graph.nodes) == 2
|
||||
assert len(graph.edges) == 1
|
||||
assert graph.get_node("1") is not None
|
||||
assert graph.get_node("2") is not None
|
||||
assert graph.edges[0].node1.id == "1"
|
||||
assert graph.edges[0].node2.id == "2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_graph_from_db_id_filtered(setup_graph, mock_adapter):
|
||||
"""Test projecting an ID-filtered graph from database."""
|
||||
graph = setup_graph
|
||||
|
||||
nodes_data = [
|
||||
("1", {"name": "Node1"}),
|
||||
("2", {"name": "Node2"}),
|
||||
]
|
||||
edges_data = [
|
||||
("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}),
|
||||
]
|
||||
|
||||
mock_adapter.get_id_filtered_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
|
||||
|
||||
await graph.project_graph_from_db(
|
||||
adapter=mock_adapter,
|
||||
node_properties_to_project=["name"],
|
||||
edge_properties_to_project=["relationship_name"],
|
||||
relevant_ids_to_filter=["1", "2"],
|
||||
)
|
||||
|
||||
assert len(graph.nodes) == 2
|
||||
assert len(graph.edges) == 1
|
||||
mock_adapter.get_id_filtered_graph_data.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_graph_from_db_nodeset_subgraph(setup_graph, mock_adapter):
|
||||
"""Test projecting a nodeset subgraph filtered by node type and name."""
|
||||
graph = setup_graph
|
||||
|
||||
nodes_data = [
|
||||
("1", {"name": "Alice", "type": "Person"}),
|
||||
("2", {"name": "Bob", "type": "Person"}),
|
||||
]
|
||||
edges_data = [
|
||||
("1", "2", "KNOWS", {"relationship_name": "knows"}),
|
||||
]
|
||||
|
||||
mock_adapter.get_nodeset_subgraph = AsyncMock(return_value=(nodes_data, edges_data))
|
||||
|
||||
await graph.project_graph_from_db(
|
||||
adapter=mock_adapter,
|
||||
node_properties_to_project=["name", "type"],
|
||||
edge_properties_to_project=["relationship_name"],
|
||||
node_type="Person",
|
||||
node_name=["Alice"],
|
||||
)
|
||||
|
||||
assert len(graph.nodes) == 2
|
||||
assert graph.get_node("1") is not None
|
||||
assert len(graph.edges) == 1
|
||||
mock_adapter.get_nodeset_subgraph.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter):
|
||||
"""Test projecting empty graph raises EntityNotFoundError."""
|
||||
graph = setup_graph
|
||||
|
||||
mock_adapter.get_graph_data = AsyncMock(return_value=([], []))
|
||||
|
||||
with pytest.raises(EntityNotFoundError, match="Empty graph projected from the database."):
|
||||
await graph.project_graph_from_db(
|
||||
adapter=mock_adapter,
|
||||
node_properties_to_project=["name"],
|
||||
edge_properties_to_project=[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter):
|
||||
"""Test that edges referencing missing nodes raise error."""
|
||||
graph = setup_graph
|
||||
|
||||
nodes_data = [
|
||||
("1", {"name": "Node1"}),
|
||||
]
|
||||
edges_data = [
|
||||
("1", "999", "CONNECTS_TO", {"relationship_name": "connects"}),
|
||||
]
|
||||
|
||||
mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
|
||||
|
||||
with pytest.raises(EntityNotFoundError, match="Edge references nonexistent nodes"):
|
||||
await graph.project_graph_from_db(
|
||||
adapter=mock_adapter,
|
||||
node_properties_to_project=["name"],
|
||||
edge_properties_to_project=["relationship_name"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_to_graph_nodes(setup_graph):
|
||||
"""Test mapping vector distances to graph nodes."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1", {"name": "Node1"})
|
||||
node2 = Node("2", {"name": "Node2"})
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
node_distances = {
|
||||
"Entity_name": [
|
||||
MockScoredResult("1", 0.95),
|
||||
MockScoredResult("2", 0.87),
|
||||
]
|
||||
}
|
||||
|
||||
await graph.map_vector_distances_to_graph_nodes(node_distances)
|
||||
|
||||
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
|
||||
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_partial_node_coverage(setup_graph):
|
||||
"""Test mapping vector distances when only some nodes have results."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1", {"name": "Node1"})
|
||||
node2 = Node("2", {"name": "Node2"})
|
||||
node3 = Node("3", {"name": "Node3"})
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
graph.add_node(node3)
|
||||
|
||||
node_distances = {
|
||||
"Entity_name": [
|
||||
MockScoredResult("1", 0.95),
|
||||
MockScoredResult("2", 0.87),
|
||||
]
|
||||
}
|
||||
|
||||
await graph.map_vector_distances_to_graph_nodes(node_distances)
|
||||
|
||||
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
|
||||
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
|
||||
assert graph.get_node("3").attributes.get("vector_distance") == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_multiple_categories(setup_graph):
|
||||
"""Test mapping vector distances from multiple collection categories."""
|
||||
graph = setup_graph
|
||||
|
||||
# Create nodes
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
node3 = Node("3")
|
||||
node4 = Node("4")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
graph.add_node(node3)
|
||||
graph.add_node(node4)
|
||||
|
||||
node_distances = {
|
||||
"Entity_name": [
|
||||
MockScoredResult("1", 0.95),
|
||||
MockScoredResult("2", 0.87),
|
||||
],
|
||||
"TextSummary_text": [
|
||||
MockScoredResult("3", 0.92),
|
||||
],
|
||||
}
|
||||
|
||||
await graph.map_vector_distances_to_graph_nodes(node_distances)
|
||||
|
||||
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
|
||||
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
|
||||
assert graph.get_node("3").attributes.get("vector_distance") == 0.92
|
||||
assert graph.get_node("4").attributes.get("vector_distance") == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, mock_vector_engine):
|
||||
"""Test mapping vector distances to edges when edge_distances provided."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
edge = Edge(
|
||||
node1,
|
||||
node2,
|
||||
attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
|
||||
)
|
||||
graph.add_edge(edge)
|
||||
|
||||
edge_distances = [
|
||||
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.92
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_to_graph_edges_search(setup_graph, mock_vector_engine):
|
||||
"""Test mapping edge distances when searching for them."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
edge = Edge(
|
||||
node1,
|
||||
node2,
|
||||
attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
|
||||
)
|
||||
graph.add_edge(edge)
|
||||
|
||||
mock_vector_engine.search.return_value = [
|
||||
MockScoredResult("e1", 0.88, payload={"text": "CONNECTS_TO"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=None,
|
||||
)
|
||||
|
||||
mock_vector_engine.search.assert_called_once()
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.88
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vector_engine):
|
||||
"""Test mapping edge distances when only some edges have results."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
node3 = Node("3")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
graph.add_node(node3)
|
||||
|
||||
edge1 = Edge(node1, node2, attributes={"edge_text": "CONNECTS_TO"})
|
||||
edge2 = Edge(node2, node3, attributes={"edge_text": "DEPENDS_ON"})
|
||||
graph.add_edge(edge1)
|
||||
graph.add_edge(edge2)
|
||||
|
||||
edge_distances = [
|
||||
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.92
|
||||
assert graph.edges[1].attributes.get("vector_distance") == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_edges_fallback_to_relationship_type(
|
||||
setup_graph, mock_vector_engine
|
||||
):
|
||||
"""Test that edge mapping falls back to relationship_type when edge_text is missing."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
edge = Edge(
|
||||
node1,
|
||||
node2,
|
||||
attributes={"relationship_type": "KNOWS"},
|
||||
)
|
||||
graph.add_edge(edge)
|
||||
|
||||
edge_distances = [
|
||||
MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 0.85
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_engine):
|
||||
"""Test edge mapping when no edges match the distance results."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
edge = Edge(
|
||||
node1,
|
||||
node2,
|
||||
attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
|
||||
)
|
||||
graph.add_edge(edge)
|
||||
|
||||
edge_distances = [
|
||||
MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[0.1, 0.2, 0.3],
|
||||
edge_distances=edge_distances,
|
||||
)
|
||||
|
||||
assert graph.edges[0].attributes.get("vector_distance") == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map_vector_distances_invalid_query_vector(setup_graph, mock_vector_engine):
|
||||
"""Test that invalid query vector raises error."""
|
||||
graph = setup_graph
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to generate query embedding"):
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
vector_engine=mock_vector_engine,
|
||||
query_vector=[],
|
||||
edge_distances=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_top_triplet_importances(setup_graph):
|
||||
"""Test calculating top triplet importances by score."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
node3 = Node("3")
|
||||
node4 = Node("4")
|
||||
|
||||
node1.add_attribute("vector_distance", 0.9)
|
||||
node2.add_attribute("vector_distance", 0.8)
|
||||
node3.add_attribute("vector_distance", 0.7)
|
||||
node4.add_attribute("vector_distance", 0.6)
|
||||
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
graph.add_node(node3)
|
||||
graph.add_node(node4)
|
||||
|
||||
edge1 = Edge(node1, node2)
|
||||
edge2 = Edge(node2, node3)
|
||||
edge3 = Edge(node3, node4)
|
||||
|
||||
edge1.add_attribute("vector_distance", 0.85)
|
||||
edge2.add_attribute("vector_distance", 0.75)
|
||||
edge3.add_attribute("vector_distance", 0.65)
|
||||
|
||||
graph.add_edge(edge1)
|
||||
graph.add_edge(edge2)
|
||||
graph.add_edge(edge3)
|
||||
|
||||
top_triplets = await graph.calculate_top_triplet_importances(k=2)
|
||||
|
||||
assert len(top_triplets) == 2
|
||||
|
||||
assert top_triplets[0] == edge3
|
||||
assert top_triplets[1] == edge2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_top_triplet_importances_default_distances(setup_graph):
|
||||
"""Test calculating importances when nodes/edges have no vector distances."""
|
||||
graph = setup_graph
|
||||
|
||||
node1 = Node("1")
|
||||
node2 = Node("2")
|
||||
graph.add_node(node1)
|
||||
graph.add_node(node2)
|
||||
|
||||
edge = Edge(node1, node2)
|
||||
graph.add_edge(edge)
|
||||
|
||||
top_triplets = await graph.calculate_top_triplet_importances(k=1)
|
||||
|
||||
assert len(top_triplets) == 1
|
||||
assert top_triplets[0] == edge
|
||||
|
|
|
|||
|
|
@ -0,0 +1,582 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
|
||||
brute_force_triplet_search,
|
||||
get_memory_fragment,
|
||||
)
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
||||
|
||||
|
||||
class MockScoredResult:
|
||||
"""Mock class for vector search results."""
|
||||
|
||||
def __init__(self, id, score, payload=None):
|
||||
self.id = id
|
||||
self.score = score
|
||||
self.payload = payload or {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_empty_query():
|
||||
"""Test that empty query raises ValueError."""
|
||||
with pytest.raises(ValueError, match="The query must be a non-empty string."):
|
||||
await brute_force_triplet_search(query="")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_none_query():
|
||||
"""Test that None query raises ValueError."""
|
||||
with pytest.raises(ValueError, match="The query must be a non-empty string."):
|
||||
await brute_force_triplet_search(query=None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_negative_top_k():
|
||||
"""Test that negative top_k raises ValueError."""
|
||||
with pytest.raises(ValueError, match="top_k must be a positive integer."):
|
||||
await brute_force_triplet_search(query="test query", top_k=-1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_zero_top_k():
|
||||
"""Test that zero top_k raises ValueError."""
|
||||
with pytest.raises(ValueError, match="top_k must be a positive integer."):
|
||||
await brute_force_triplet_search(query="test query", top_k=0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_wide_search_limit_global_search():
|
||||
"""Test that wide_search_limit is applied for global search (node_name=None)."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(
|
||||
query="test",
|
||||
node_name=None, # Global search
|
||||
wide_search_top_k=75,
|
||||
)
|
||||
|
||||
for call in mock_vector_engine.search.call_args_list:
|
||||
assert call[1]["limit"] == 75
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_wide_search_limit_filtered_search():
|
||||
"""Test that wide_search_limit is None for filtered search (node_name provided)."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(
|
||||
query="test",
|
||||
node_name=["Node1"],
|
||||
wide_search_top_k=50,
|
||||
)
|
||||
|
||||
for call in mock_vector_engine.search.call_args_list:
|
||||
assert call[1]["limit"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_wide_search_default():
|
||||
"""Test that wide_search_top_k defaults to 100."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", node_name=None)
|
||||
|
||||
for call in mock_vector_engine.search.call_args_list:
|
||||
assert call[1]["limit"] == 100
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_default_collections():
|
||||
"""Test that default collections are used when none provided."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(query="test")
|
||||
|
||||
expected_collections = [
|
||||
"Entity_name",
|
||||
"TextSummary_text",
|
||||
"EntityType_name",
|
||||
"DocumentChunk_text",
|
||||
]
|
||||
|
||||
call_collections = [
|
||||
call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
|
||||
]
|
||||
assert call_collections == expected_collections
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_custom_collections():
|
||||
"""Test that custom collections are used when provided."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
custom_collections = ["CustomCol1", "CustomCol2"]
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", collections=custom_collections)
|
||||
|
||||
call_collections = [
|
||||
call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
|
||||
]
|
||||
assert call_collections == custom_collections
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_all_collections_empty():
|
||||
"""Test that empty list is returned when all collections return no results."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
results = await brute_force_triplet_search(query="test")
|
||||
assert results == []
|
||||
|
||||
|
||||
# Tests for query embedding
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_embeds_query():
|
||||
"""Test that query is embedded before searching."""
|
||||
query_text = "test query"
|
||||
expected_vector = [0.1, 0.2, 0.3]
|
||||
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[])
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
):
|
||||
await brute_force_triplet_search(query=query_text)
|
||||
|
||||
mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text])
|
||||
|
||||
for call in mock_vector_engine.search.call_args_list:
|
||||
assert call[1]["query_vector"] == expected_vector
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_extracts_node_ids_global_search():
|
||||
"""Test that node IDs are extracted from search results for global search."""
|
||||
scored_results = [
|
||||
MockScoredResult("node1", 0.95),
|
||||
MockScoredResult("node2", 0.87),
|
||||
MockScoredResult("node3", 0.92),
|
||||
]
|
||||
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=scored_results)
|
||||
|
||||
mock_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
) as mock_get_fragment_fn,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", node_name=None)
|
||||
|
||||
call_kwargs = mock_get_fragment_fn.call_args[1]
|
||||
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_reuses_provided_fragment():
|
||||
"""Test that provided memory fragment is reused instead of creating new one."""
|
||||
provided_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment"
|
||||
) as mock_get_fragment,
|
||||
):
|
||||
await brute_force_triplet_search(
|
||||
query="test",
|
||||
memory_fragment=provided_fragment,
|
||||
node_name=["node"],
|
||||
)
|
||||
|
||||
mock_get_fragment.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_creates_fragment_when_not_provided():
|
||||
"""Test that memory fragment is created when not provided."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
|
||||
|
||||
mock_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
) as mock_get_fragment,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", node_name=["node"])
|
||||
|
||||
mock_get_fragment.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation():
|
||||
"""Test that custom top_k is passed to importance calculation."""
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
|
||||
|
||||
mock_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
),
|
||||
):
|
||||
custom_top_k = 15
|
||||
await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"])
|
||||
|
||||
mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found():
|
||||
"""Test that get_memory_fragment returns empty graph when entity not found."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.project_graph_from_db = AsyncMock(
|
||||
side_effect=EntityNotFoundError("Entity not found")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
):
|
||||
fragment = await get_memory_fragment()
|
||||
|
||||
assert isinstance(fragment, CogneeGraph)
|
||||
assert len(fragment.nodes) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_memory_fragment_returns_empty_graph_on_error():
|
||||
"""Test that get_memory_fragment returns empty graph on generic error."""
|
||||
mock_graph_engine = AsyncMock()
|
||||
mock_graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error"))
|
||||
|
||||
with patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
|
||||
return_value=mock_graph_engine,
|
||||
):
|
||||
fragment = await get_memory_fragment()
|
||||
|
||||
assert isinstance(fragment, CogneeGraph)
|
||||
assert len(fragment.nodes) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_deduplicates_node_ids():
|
||||
"""Test that duplicate node IDs across collections are deduplicated."""
|
||||
|
||||
def search_side_effect(*args, **kwargs):
|
||||
collection_name = kwargs.get("collection_name")
|
||||
if collection_name == "Entity_name":
|
||||
return [
|
||||
MockScoredResult("node1", 0.95),
|
||||
MockScoredResult("node2", 0.87),
|
||||
]
|
||||
elif collection_name == "TextSummary_text":
|
||||
return [
|
||||
MockScoredResult("node1", 0.90),
|
||||
MockScoredResult("node3", 0.92),
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
|
||||
|
||||
mock_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
) as mock_get_fragment_fn,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", node_name=None)
|
||||
|
||||
call_kwargs = mock_get_fragment_fn.call_args[1]
|
||||
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
|
||||
assert len(call_kwargs["relevant_ids_to_filter"]) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_excludes_edge_collection():
|
||||
"""Test that EdgeType_relationship_name collection is excluded from ID extraction."""
|
||||
|
||||
def search_side_effect(*args, **kwargs):
|
||||
collection_name = kwargs.get("collection_name")
|
||||
if collection_name == "Entity_name":
|
||||
return [MockScoredResult("node1", 0.95)]
|
||||
elif collection_name == "EdgeType_relationship_name":
|
||||
return [MockScoredResult("edge1", 0.88)]
|
||||
else:
|
||||
return []
|
||||
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
|
||||
|
||||
mock_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
) as mock_get_fragment_fn,
|
||||
):
|
||||
await brute_force_triplet_search(
|
||||
query="test",
|
||||
node_name=None,
|
||||
collections=["Entity_name", "EdgeType_relationship_name"],
|
||||
)
|
||||
|
||||
call_kwargs = mock_get_fragment_fn.call_args[1]
|
||||
assert call_kwargs["relevant_ids_to_filter"] == ["node1"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_skips_nodes_without_ids():
|
||||
"""Test that nodes without ID attribute are skipped."""
|
||||
|
||||
class ScoredResultNoId:
|
||||
"""Mock result without id attribute."""
|
||||
|
||||
def __init__(self, score):
|
||||
self.score = score
|
||||
|
||||
def search_side_effect(*args, **kwargs):
|
||||
collection_name = kwargs.get("collection_name")
|
||||
if collection_name == "Entity_name":
|
||||
return [
|
||||
MockScoredResult("node1", 0.95),
|
||||
ScoredResultNoId(0.90),
|
||||
MockScoredResult("node2", 0.87),
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
|
||||
|
||||
mock_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
) as mock_get_fragment_fn,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", node_name=None)
|
||||
|
||||
call_kwargs = mock_get_fragment_fn.call_args[1]
|
||||
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_handles_tuple_results():
|
||||
"""Test that both list and tuple results are handled correctly."""
|
||||
|
||||
def search_side_effect(*args, **kwargs):
|
||||
collection_name = kwargs.get("collection_name")
|
||||
if collection_name == "Entity_name":
|
||||
return (
|
||||
MockScoredResult("node1", 0.95),
|
||||
MockScoredResult("node2", 0.87),
|
||||
)
|
||||
else:
|
||||
return []
|
||||
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
|
||||
|
||||
mock_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
) as mock_get_fragment_fn,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", node_name=None)
|
||||
|
||||
call_kwargs = mock_get_fragment_fn.call_args[1]
|
||||
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brute_force_triplet_search_mixed_empty_collections():
|
||||
"""Test ID extraction with mixed empty and non-empty collections."""
|
||||
|
||||
def search_side_effect(*args, **kwargs):
|
||||
collection_name = kwargs.get("collection_name")
|
||||
if collection_name == "Entity_name":
|
||||
return [MockScoredResult("node1", 0.95)]
|
||||
elif collection_name == "TextSummary_text":
|
||||
return []
|
||||
elif collection_name == "EntityType_name":
|
||||
return [MockScoredResult("node2", 0.92)]
|
||||
else:
|
||||
return []
|
||||
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine = AsyncMock()
|
||||
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
|
||||
|
||||
mock_fragment = AsyncMock(
|
||||
map_vector_distances_to_graph_nodes=AsyncMock(),
|
||||
map_vector_distances_to_graph_edges=AsyncMock(),
|
||||
calculate_top_triplet_importances=AsyncMock(return_value=[]),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
|
||||
return_value=mock_vector_engine,
|
||||
),
|
||||
patch(
|
||||
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
|
||||
return_value=mock_fragment,
|
||||
) as mock_get_fragment_fn,
|
||||
):
|
||||
await brute_force_triplet_search(query="test", node_name=None)
|
||||
|
||||
call_kwargs = mock_get_fragment_fn.call_args[1]
|
||||
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
|
||||
Loading…
Add table
Reference in a new issue