Feat: Adds subgraph retriever to graph based completion searches (#874)

<!-- .github/pull_request_template.md -->

## Description
Adds a subgraph retriever to graph-based completion searches

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
hajdul88 2025-05-27 11:40:47 +02:00 committed by GitHub
parent 834d959b11
commit 965033e161
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 302 additions and 31 deletions

View file

@ -1,4 +1,4 @@
from typing import Union from typing import Union, Optional, List, Type
from cognee.modules.users.models import User from cognee.modules.users.models import User
from cognee.modules.search.types import SearchType from cognee.modules.search.types import SearchType
@ -13,6 +13,8 @@ async def search(
datasets: Union[list[str], str, None] = None, datasets: Union[list[str], str, None] = None,
system_prompt_path: str = "answer_simple_question.txt", system_prompt_path: str = "answer_simple_question.txt",
top_k: int = 10, top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> list: ) -> list:
# We use lists from now on for datasets # We use lists from now on for datasets
if isinstance(datasets, str): if isinstance(datasets, str):
@ -22,12 +24,14 @@ async def search(
user = await get_default_user() user = await get_default_user()
filtered_search_results = await search_function( filtered_search_results = await search_function(
query_text, query_text=query_text,
query_type, query_type=query_type,
datasets, datasets=datasets,
user, user=user,
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
top_k=top_k, top_k=top_k,
node_type=node_type,
node_name=node_name,
) )
return filtered_search_results return filtered_search_results

View file

@ -37,3 +37,17 @@ class EntityAlreadyExistsError(CogneeApiError):
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_409_CONFLICT,
): ):
super().__init__(message, name, status_code) super().__init__(message, name, status_code)
class NodesetFilterNotSupportedError(CogneeApiError):
    """Raised when the active graph database does not support nodeset filtering."""

    def __init__(
        self,
        message: str = "The nodeset filter is not supported in the current graph database.",
        name: str = "NodeSetFilterNotSupportedError",
        status_code=status.HTTP_404_NOT_FOUND,
    ):
        # Delegate to CogneeApiError like the sibling exception classes
        # (e.g. EntityAlreadyExistsError above) instead of assigning
        # message/name/status_code by hand, so construction behavior stays
        # consistent across the exception hierarchy.
        # NOTE(review): 404 reads oddly for an unsupported capability —
        # 400 or 501 may be more accurate; confirm before changing.
        super().__init__(message, name, status_code)

View file

@ -2,7 +2,7 @@ import inspect
from functools import wraps from functools import wraps
from abc import abstractmethod, ABC from abc import abstractmethod, ABC
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Optional, Dict, Any, List, Tuple from typing import Optional, Dict, Any, List, Tuple, Type
from uuid import NAMESPACE_OID, UUID, uuid5 from uuid import NAMESPACE_OID, UUID, uuid5
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
@ -183,6 +183,13 @@ class GraphDBInterface(ABC):
"""Get all neighboring nodes.""" """Get all neighboring nodes."""
raise NotImplementedError raise NotImplementedError
@abstractmethod
async def get_nodeset_subgraph(
    self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
    """Return the subgraph induced by a named set of nodes plus their direct neighbors.

    Args:
        node_type: Type whose class name selects the node label/type to match.
        node_name: Names of the nodes that seed the subgraph.

    Returns:
        A ``(nodes, edges)`` tuple: ``nodes`` is a list of
        ``(node_id, properties)`` pairs; ``edges`` is a list of
        ``(source_id, target_id, relationship_name, properties)`` tuples.
    """
    raise NotImplementedError
@abstractmethod @abstractmethod
async def get_connections( async def get_connections(
self, node_id: str self, node_id: str

View file

@ -1,11 +1,12 @@
"""Adapter for Kuzu graph database.""" """Adapter for Kuzu graph database."""
from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
import json import json
import os import os
import shutil import shutil
import asyncio import asyncio
from typing import Dict, Any, List, Union, Optional, Tuple from typing import Dict, Any, List, Union, Optional, Tuple, Type
from datetime import datetime, timezone from datetime import datetime, timezone
from uuid import UUID from uuid import UUID
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@ -728,6 +729,65 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to get graph data: {e}") logger.error(f"Failed to get graph data: {e}")
raise raise
async def get_nodeset_subgraph(
    self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]:
    """Return the subgraph made of the named nodes and their direct neighbors.

    Args:
        node_type: Type whose ``__name__`` is matched against ``n.type``.
        node_name: Node names that seed the subgraph.

    Returns:
        ``(nodes, edges)`` where ``nodes`` is a list of ``(id, properties)``
        pairs and ``edges`` is a list of ``(source_id, target_id,
        relationship_name, properties)`` tuples. Both lists are empty when
        no seed node matches.
    """
    label = node_type.__name__

    # Resolve the ids of the seed ("primary") nodes by type and name.
    primary_query = """
        UNWIND $names AS wantedName
        MATCH (n:Node)
        WHERE n.type = $label AND n.name = wantedName
        RETURN DISTINCT n.id
    """
    primary_rows = await self.query(primary_query, {"names": node_name, "label": label})
    primary_ids = [row[0] for row in primary_rows]
    if not primary_ids:
        return [], []

    # Expand one hop in either direction to collect direct neighbors.
    neighbor_query = """
        MATCH (n:Node)-[:EDGE]-(nbr:Node)
        WHERE n.id IN $ids
        RETURN DISTINCT nbr.id
    """
    nbr_rows = await self.query(neighbor_query, {"ids": primary_ids})
    neighbor_ids = [row[0] for row in nbr_rows]
    all_ids = list({*primary_ids, *neighbor_ids})

    # Fetch the node payloads; `properties` is stored as a JSON string.
    nodes_query = """
        MATCH (n:Node)
        WHERE n.id IN $ids
        RETURN n.id, n.name, n.type, n.properties
    """
    node_rows = await self.query(nodes_query, {"ids": all_ids})
    nodes: List[Tuple[str, dict]] = []
    for node_id, name, typ, props in node_rows:
        data = {"id": node_id, "name": name, "type": typ}
        if props:
            try:
                data.update(json.loads(props))
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse JSON props for node {node_id}")
        nodes.append((node_id, data))

    # Directed match on purpose: an undirected pattern would yield every
    # relationship twice (once per traversal direction) and report half of
    # them with source and target swapped.
    edges_query = """
        MATCH (a:Node)-[r:EDGE]->(b:Node)
        WHERE a.id IN $ids AND b.id IN $ids
        RETURN a.id, b.id, r.relationship_name, r.properties
    """
    edge_rows = await self.query(edges_query, {"ids": all_ids})
    edges: List[Tuple[str, str, str, dict]] = []
    for from_id, to_id, rel_type, props in edge_rows:
        data = {}
        if props:
            try:
                data = json.loads(props)
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}")
        edges.append((from_id, to_id, rel_type, data))

    return nodes, edges
async def get_filtered_graph_data( async def get_filtered_graph_data(
self, attribute_filters: List[Dict[str, List[Union[str, int]]]] self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
): ):

View file

@ -4,7 +4,7 @@ import json
from cognee.shared.logging_utils import get_logger, ERROR from cognee.shared.logging_utils import get_logger, ERROR
import asyncio import asyncio
from textwrap import dedent from textwrap import dedent
from typing import Optional, Any, List, Dict from typing import Optional, Any, List, Dict, Type, Tuple
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from uuid import UUID from uuid import UUID
from neo4j import AsyncSession from neo4j import AsyncSession
@ -13,6 +13,7 @@ from neo4j.exceptions import Neo4jError
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.modules.storage.utils import JSONEncoder from cognee.modules.storage.utils import JSONEncoder
from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError
logger = get_logger("MemgraphAdapter", level=ERROR) logger = get_logger("MemgraphAdapter", level=ERROR)
@ -482,6 +483,12 @@ class MemgraphAdapter(GraphDBInterface):
return (nodes, edges) return (nodes, edges)
async def get_nodeset_subgraph(
    self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
    """Nodeset subgraph retrieval is not implemented for this adapter."""
    raise NodesetFilterNotSupportedError()
async def get_filtered_graph_data(self, attribute_filters): async def get_filtered_graph_data(self, attribute_filters):
""" """
Fetches nodes and relationships filtered by specified attribute values. Fetches nodes and relationships filtered by specified attribute values.

View file

@ -6,7 +6,7 @@ import json
from cognee.shared.logging_utils import get_logger, ERROR from cognee.shared.logging_utils import get_logger, ERROR
import asyncio import asyncio
from textwrap import dedent from textwrap import dedent
from typing import Optional, Any, List, Dict from typing import Optional, Any, List, Dict, Type, Tuple
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from uuid import UUID from uuid import UUID
from neo4j import AsyncSession from neo4j import AsyncSession
@ -517,6 +517,54 @@ class Neo4jAdapter(GraphDBInterface):
return (nodes, edges) return (nodes, edges)
async def get_nodeset_subgraph(
    self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
    """Fetch the subgraph induced by named nodes of one label plus their neighbors.

    Args:
        node_type: Type whose ``__name__`` is used as the Neo4j node label.
        node_name: Node names that seed the subgraph.

    Returns:
        ``(nodes, edges)`` where ``nodes`` is a list of ``(id, properties)``
        pairs and ``edges`` is a list of ``(source_id, target_id,
        relationship_type, properties)`` tuples. Empty lists when the query
        returns no rows.
    """
    # NOTE(review): the label is interpolated into the query text (backtick
    # quoted). node_type.__name__ comes from project code rather than user
    # input, so injection exposure looks limited — confirm callers never pass
    # user-controlled types.
    label = node_type.__name__
    # Single round trip: collect the seed nodes by name, expand one hop to
    # their neighbors (OPTIONAL so isolated seeds survive), then return all
    # relationships whose BOTH endpoints are inside the collected node set.
    # collect(DISTINCT ...) dedupes both nodes and relationships.
    query = f"""
    UNWIND $names AS wantedName
    MATCH (n:`{label}`)
    WHERE n.name = wantedName
    WITH collect(DISTINCT n) AS primary
    UNWIND primary AS p
    OPTIONAL MATCH (p)--(nbr)
    WITH primary, collect(DISTINCT nbr) AS nbrs
    WITH primary + nbrs AS nodelist
    UNWIND nodelist AS node
    WITH collect(DISTINCT node) AS nodes
    MATCH (a)-[r]-(b)
    WHERE a IN nodes AND b IN nodes
    WITH nodes, collect(DISTINCT r) AS rels
    RETURN
    [n IN nodes |
    {{ id: n.id,
    properties: properties(n) }}] AS rawNodes,
    [r IN rels |
    {{ type: type(r),
    properties: properties(r) }}] AS rawRels
    """
    result = await self.query(query, {"names": node_name})
    if not result:
        return [], []
    # The query returns exactly one row carrying both projected lists.
    raw_nodes = result[0]["rawNodes"]
    raw_rels = result[0]["rawRels"]
    # Node id is read from the stored `id` property, not the internal Neo4j id.
    nodes = [(n["properties"]["id"], n["properties"]) for n in raw_nodes]
    # NOTE(review): assumes every relationship carries `source_node_id` and
    # `target_node_id` properties; a relationship written without them would
    # raise KeyError here — confirm the graph writer always sets both.
    edges = [
        (
            r["properties"]["source_node_id"],
            r["properties"]["target_node_id"],
            r["type"],
            r["properties"],
        )
        for r in raw_rels
    ]
    return nodes, edges
async def get_filtered_graph_data(self, attribute_filters): async def get_filtered_graph_data(self, attribute_filters):
""" """
Fetches nodes and relationships filtered by specified attribute values. Fetches nodes and relationships filtered by specified attribute values.

View file

@ -4,8 +4,10 @@ from datetime import datetime, timezone
import os import os
import json import json
import asyncio import asyncio
from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from typing import Dict, Any, List, Union from typing import Dict, Any, List, Union, Type, Tuple
from uuid import UUID from uuid import UUID
import aiofiles import aiofiles
import aiofiles.os as aiofiles_os import aiofiles.os as aiofiles_os
@ -396,6 +398,12 @@ class NetworkXAdapter(GraphDBInterface):
logger.error("Failed to delete graph: %s", error) logger.error("Failed to delete graph: %s", error)
raise error raise error
async def get_nodeset_subgraph(
    self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
    """Nodeset subgraph retrieval is not implemented for this adapter."""
    raise NodesetFilterNotSupportedError()
async def get_filtered_graph_data( async def get_filtered_graph_data(
self, attribute_filters: List[Dict[str, List[Union[str, int]]]] self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
): ):

View file

@ -5,4 +5,3 @@ class NodeSet(DataPoint):
"""NodeSet data point.""" """NodeSet data point."""
name: str name: str
metadata: dict = {"index_fields": ["name"]}

View file

@ -1,5 +1,5 @@
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from typing import List, Dict, Union from typing import List, Dict, Union, Optional, Type
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
@ -61,22 +61,33 @@ class CogneeGraph(CogneeAbstractGraph):
node_dimension=1, node_dimension=1,
edge_dimension=1, edge_dimension=1,
memory_fragment_filter=[], memory_fragment_filter=[],
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> None: ) -> None:
if node_dimension < 1 or edge_dimension < 1: if node_dimension < 1 or edge_dimension < 1:
raise InvalidValueError(message="Dimensions must be positive integers") raise InvalidValueError(message="Dimensions must be positive integers")
try: try:
if len(memory_fragment_filter) == 0: if node_type is not None and node_name is not None:
nodes_data, edges_data = await adapter.get_nodeset_subgraph(
node_type=node_type, node_name=node_name
)
if not nodes_data or not edges_data:
raise EntityNotFoundError(
message="Nodeset does not exist, or empty nodetes projected from the database."
)
elif len(memory_fragment_filter) == 0:
nodes_data, edges_data = await adapter.get_graph_data() nodes_data, edges_data = await adapter.get_graph_data()
if not nodes_data or not edges_data:
raise EntityNotFoundError(message="Empty graph projected from the database.")
else: else:
nodes_data, edges_data = await adapter.get_filtered_graph_data( nodes_data, edges_data = await adapter.get_filtered_graph_data(
attribute_filters=memory_fragment_filter attribute_filters=memory_fragment_filter
) )
if not nodes_data: if not nodes_data or not edges_data:
raise EntityNotFoundError(message="No node data retrieved from the database.") raise EntityNotFoundError(
if not edges_data: message="Empty filtered graph projected from the database."
raise EntityNotFoundError(message="No edge data retrieved from the database.") )
for node_id, properties in nodes_data: for node_id, properties in nodes_data:
node_attributes = {key: properties.get(key) for key in node_properties_to_project} node_attributes = {key: properties.get(key) for key in node_properties_to_project}

View file

@ -95,6 +95,7 @@ def expand_with_nodes_and_edges(
name=ont_node_name, name=ont_node_name,
description=ont_node_name, description=ont_node_name,
ontology_valid=True, ontology_valid=True,
belongs_to_set=data_chunk.belongs_to_set,
) )
for source, relation, target in ontology_entity_type_edges: for source, relation, target in ontology_entity_type_edges:
@ -144,6 +145,7 @@ def expand_with_nodes_and_edges(
is_a=type_node, is_a=type_node,
description=node.description, description=node.description,
ontology_valid=ontology_validated_source_ent, ontology_valid=ontology_validated_source_ent,
belongs_to_set=data_chunk.belongs_to_set,
) )
added_nodes_map[entity_node_key] = entity_node added_nodes_map[entity_node_key] = entity_node
@ -174,6 +176,7 @@ def expand_with_nodes_and_edges(
name=ont_node_name, name=ont_node_name,
description=ont_node_name, description=ont_node_name,
ontology_valid=True, ontology_valid=True,
belongs_to_set=data_chunk.belongs_to_set,
) )
for source, relation, target in ontology_entity_edges: for source, relation, target in ontology_entity_edges:

View file

@ -1,4 +1,4 @@
from typing import Any, Optional, List from typing import Any, Optional, List, Type
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@ -14,11 +14,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
user_prompt_path: str = "graph_context_for_question.txt", user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt", system_prompt_path: str = "answer_simple_question.txt",
top_k: Optional[int] = 5, top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
): ):
super().__init__( super().__init__(
user_prompt_path=user_prompt_path, user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
top_k=top_k, top_k=top_k,
node_type=node_type,
node_name=node_name,
) )
async def get_completion( async def get_completion(

View file

@ -1,4 +1,4 @@
from typing import Any, Optional, List from typing import Any, Optional, List, Type
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@ -18,11 +18,15 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
followup_system_prompt_path: str = "cot_followup_system_prompt.txt", followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
followup_user_prompt_path: str = "cot_followup_user_prompt.txt", followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
top_k: Optional[int] = 5, top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
): ):
super().__init__( super().__init__(
user_prompt_path=user_prompt_path, user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
top_k=top_k, top_k=top_k,
node_type=node_type,
node_name=node_name,
) )
self.validation_system_prompt_path = validation_system_prompt_path self.validation_system_prompt_path = validation_system_prompt_path
self.validation_user_prompt_path = validation_user_prompt_path self.validation_user_prompt_path = validation_user_prompt_path

View file

@ -1,4 +1,4 @@
from typing import Any, Optional from typing import Any, Optional, Type, List
from collections import Counter from collections import Counter
import string import string
@ -18,11 +18,15 @@ class GraphCompletionRetriever(BaseRetriever):
user_prompt_path: str = "graph_context_for_question.txt", user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt", system_prompt_path: str = "answer_simple_question.txt",
top_k: Optional[int] = 5, top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
): ):
"""Initialize retriever with prompt paths and search parameters.""" """Initialize retriever with prompt paths and search parameters."""
self.user_prompt_path = user_prompt_path self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path self.system_prompt_path = system_prompt_path
self.top_k = top_k if top_k is not None else 5 self.top_k = top_k if top_k is not None else 5
self.node_type = node_type
self.node_name = node_name
def _get_nodes(self, retrieved_edges: list) -> dict: def _get_nodes(self, retrieved_edges: list) -> dict:
"""Creates a dictionary of nodes with their names and content.""" """Creates a dictionary of nodes with their names and content."""
@ -68,7 +72,11 @@ class GraphCompletionRetriever(BaseRetriever):
vector_index_collections.append(f"{subclass.__name__}_{field_name}") vector_index_collections.append(f"{subclass.__name__}_{field_name}")
found_triplets = await brute_force_triplet_search( found_triplets = await brute_force_triplet_search(
query, top_k=self.top_k, collections=vector_index_collections or None query,
top_k=self.top_k,
collections=vector_index_collections or None,
node_type=self.node_type,
node_name=self.node_name,
) )
return found_triplets return found_triplets

View file

@ -1,4 +1,4 @@
from typing import Optional from typing import Optional, Type, List
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import summarize_text from cognee.modules.retrieval.utils.completion import summarize_text
@ -13,12 +13,16 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
system_prompt_path: str = "answer_simple_question.txt", system_prompt_path: str = "answer_simple_question.txt",
summarize_prompt_path: str = "summarize_search_results.txt", summarize_prompt_path: str = "summarize_search_results.txt",
top_k: Optional[int] = 5, top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
): ):
"""Initialize retriever with default prompt paths and search parameters.""" """Initialize retriever with default prompt paths and search parameters."""
super().__init__( super().__init__(
user_prompt_path=user_prompt_path, user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
top_k=top_k, top_k=top_k,
node_type=node_type,
node_name=node_name,
) )
self.summarize_prompt_path = summarize_prompt_path self.summarize_prompt_path = summarize_prompt_path

View file

@ -1,5 +1,5 @@
import asyncio import asyncio
from typing import List, Optional from typing import List, Optional, Type
from cognee.shared.logging_utils import get_logger, ERROR from cognee.shared.logging_utils import get_logger, ERROR
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
@ -55,6 +55,8 @@ def format_triplets(edges):
async def get_memory_fragment( async def get_memory_fragment(
properties_to_project: Optional[List[str]] = None, properties_to_project: Optional[List[str]] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> CogneeGraph: ) -> CogneeGraph:
"""Creates and initializes a CogneeGraph memory fragment with optional property projections.""" """Creates and initializes a CogneeGraph memory fragment with optional property projections."""
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
@ -68,6 +70,8 @@ async def get_memory_fragment(
graph_engine, graph_engine,
node_properties_to_project=properties_to_project, node_properties_to_project=properties_to_project,
edge_properties_to_project=["relationship_name"], edge_properties_to_project=["relationship_name"],
node_type=node_type,
node_name=node_name,
) )
except EntityNotFoundError: except EntityNotFoundError:
pass pass
@ -82,6 +86,8 @@ async def brute_force_triplet_search(
collections: List[str] = None, collections: List[str] = None,
properties_to_project: List[str] = None, properties_to_project: List[str] = None,
memory_fragment: Optional[CogneeGraph] = None, memory_fragment: Optional[CogneeGraph] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> list: ) -> list:
if user is None: if user is None:
user = await get_default_user() user = await get_default_user()
@ -93,6 +99,8 @@ async def brute_force_triplet_search(
collections=collections, collections=collections,
properties_to_project=properties_to_project, properties_to_project=properties_to_project,
memory_fragment=memory_fragment, memory_fragment=memory_fragment,
node_type=node_type,
node_name=node_name,
) )
return retrieved_results return retrieved_results
@ -104,6 +112,8 @@ async def brute_force_search(
collections: List[str] = None, collections: List[str] = None,
properties_to_project: List[str] = None, properties_to_project: List[str] = None,
memory_fragment: Optional[CogneeGraph] = None, memory_fragment: Optional[CogneeGraph] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> list: ) -> list:
""" """
Performs a brute force search to retrieve the top triplets from the graph. Performs a brute force search to retrieve the top triplets from the graph.
@ -125,7 +135,9 @@ async def brute_force_search(
raise ValueError("top_k must be a positive integer.") raise ValueError("top_k must be a positive integer.")
if memory_fragment is None: if memory_fragment is None:
memory_fragment = await get_memory_fragment(properties_to_project) memory_fragment = await get_memory_fragment(
properties_to_project, node_type=node_type, node_name=node_name
)
if collections is None: if collections is None:
collections = [ collections = [

View file

@ -1,5 +1,5 @@
import json import json
from typing import Callable from typing import Callable, Optional, List, Type
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine.utils import parse_id from cognee.infrastructure.engine.utils import parse_id
@ -33,12 +33,20 @@ async def search(
user: User, user: User,
system_prompt_path="answer_simple_question.txt", system_prompt_path="answer_simple_question.txt",
top_k: int = 10, top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
): ):
query = await log_query(query_text, query_type.value, user.id) query = await log_query(query_text, query_type.value, user.id)
own_document_ids = await get_document_ids_for_user(user.id, datasets) own_document_ids = await get_document_ids_for_user(user.id, datasets)
search_results = await specific_search( search_results = await specific_search(
query_type, query_text, user, system_prompt_path=system_prompt_path, top_k=top_k query_type,
query_text,
user,
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
) )
filtered_search_results = [] filtered_search_results = []
@ -61,29 +69,39 @@ async def specific_search(
user: User, user: User,
system_prompt_path="answer_simple_question.txt", system_prompt_path="answer_simple_question.txt",
top_k: int = 10, top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> list: ) -> list:
search_tasks: dict[SearchType, Callable] = { search_tasks: dict[SearchType, Callable] = {
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion, SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion, SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion,
SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion, SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion,
SearchType.RAG_COMPLETION: CompletionRetriever( SearchType.RAG_COMPLETION: CompletionRetriever(
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path, top_k=top_k
top_k=top_k,
).get_completion, ).get_completion,
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever( SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
top_k=top_k, top_k=top_k,
node_type=node_type,
node_name=node_name,
).get_completion, ).get_completion,
SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever( SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
top_k=top_k, top_k=top_k,
node_type=node_type,
node_name=node_name,
).get_completion, ).get_completion,
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever( SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever(
system_prompt_path=system_prompt_path, system_prompt_path=system_prompt_path,
top_k=top_k, top_k=top_k,
node_type=node_type,
node_name=node_name,
).get_completion, ).get_completion,
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever( SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path, top_k=top_k system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
).get_completion, ).get_completion,
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion, SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
SearchType.CYPHER: CypherSearchRetriever().get_completion, SearchType.CYPHER: CypherSearchRetriever().get_completion,

View file

@ -2,6 +2,9 @@ import os
import shutil import shutil
import cognee import cognee
import pathlib import pathlib
from cognee.modules.engine.models import NodeSet
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType from cognee.modules.search.types import SearchType
from cognee.modules.search.operations import get_history from cognee.modules.search.operations import get_history
@ -84,6 +87,30 @@ async def main():
history = await get_history(user.id) history = await get_history(user.id)
assert len(history) == 6, "Search history is not correct." assert len(history) == 6, "Search history is not correct."
nodeset_text = "Neo4j is a graph database that supports cypher."
await cognee.add([nodeset_text], dataset_name, node_set=["first"])
await cognee.cognify([dataset_name])
context_nonempty = await GraphCompletionRetriever(
node_type=NodeSet,
node_name=["first"],
).get_context("What is in the context?")
context_empty = await GraphCompletionRetriever(
node_type=NodeSet,
node_name=["nonexistent"],
).get_context("What is in the context?")
assert isinstance(context_nonempty, str) and context_nonempty != "", (
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
)
assert context_empty == "", (
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
)
await cognee.prune.prune_data() await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted" assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -1,10 +1,12 @@
import os import os
import pathlib import pathlib
import cognee import cognee
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.search.operations import get_history from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType from cognee.modules.search.types import SearchType
from cognee.modules.engine.models import NodeSet
logger = get_logger() logger = get_logger()
@ -89,6 +91,30 @@ async def main():
assert len(history) == 6, "Search history is not correct." assert len(history) == 6, "Search history is not correct."
nodeset_text = "Neo4j is a graph database that supports cypher."
await cognee.add([nodeset_text], dataset_name, node_set=["first"])
await cognee.cognify([dataset_name])
context_nonempty = await GraphCompletionRetriever(
node_type=NodeSet,
node_name=["first"],
).get_context("What is in the context?")
context_empty = await GraphCompletionRetriever(
node_type=NodeSet,
node_name=["nonexistent"],
).get_context("What is in the context?")
assert isinstance(context_nonempty, str) and context_nonempty != "", (
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
)
assert context_empty == "", (
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
)
await cognee.prune.prune_data() await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted" assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -3,6 +3,7 @@ import uuid
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from pylint.checkers.utils import node_type
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.modules.search.methods.search import search, specific_search from cognee.modules.search.methods.search import search, specific_search
@ -68,7 +69,13 @@ async def test_search(
mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id) mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id)
mock_get_document_ids.assert_called_once_with(mock_user.id, datasets) mock_get_document_ids.assert_called_once_with(mock_user.id, datasets)
mock_specific_search.assert_called_once_with( mock_specific_search.assert_called_once_with(
query_type, query_text, mock_user, system_prompt_path="answer_simple_question.txt", top_k=10 query_type,
query_text,
mock_user,
system_prompt_path="answer_simple_question.txt",
top_k=10,
node_type=None,
node_name=None,
) )
# Only the first two results should be included (doc_id3 is filtered out) # Only the first two results should be included (doc_id3 is filtered out)