Feat: Adds subgraph retriever to graph based completion searches (#874)

<!-- .github/pull_request_template.md -->

## Description
Adds subgraph retriever to graph based completion searches

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
hajdul88 2025-05-27 11:40:47 +02:00 committed by GitHub
parent 834d959b11
commit 965033e161
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 302 additions and 31 deletions

View file

@ -1,4 +1,4 @@
from typing import Union
from typing import Union, Optional, List, Type
from cognee.modules.users.models import User
from cognee.modules.search.types import SearchType
@ -13,6 +13,8 @@ async def search(
datasets: Union[list[str], str, None] = None,
system_prompt_path: str = "answer_simple_question.txt",
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> list:
# We use lists from now on for datasets
if isinstance(datasets, str):
@ -22,12 +24,14 @@ async def search(
user = await get_default_user()
filtered_search_results = await search_function(
query_text,
query_type,
datasets,
user,
query_text=query_text,
query_type=query_type,
datasets=datasets,
user=user,
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
)
return filtered_search_results

View file

@ -37,3 +37,17 @@ class EntityAlreadyExistsError(CogneeApiError):
status_code=status.HTTP_409_CONFLICT,
):
super().__init__(message, name, status_code)
class NodesetFilterNotSupportedError(CogneeApiError):
    """Raised when the active graph database cannot filter by nodeset.

    Backends without nodeset-subgraph support (e.g. NetworkX, Memgraph
    adapters) raise this from ``get_nodeset_subgraph``.
    """

    def __init__(
        self,
        message: str = "The nodeset filter is not supported in the current graph database.",
        name: str = "NodesetFilterNotSupportedError",
        status_code=status.HTTP_404_NOT_FOUND,
    ):
        # Delegate to CogneeApiError like the sibling exceptions
        # (e.g. EntityAlreadyExistsError) instead of assigning
        # message/name/status_code manually, so logging/serialization in the
        # base class stays consistent across the exception hierarchy.
        # NOTE(review): 404 reads oddly for an unsupported capability — 400 or
        # 501 may fit better; confirm with API error-handling conventions.
        super().__init__(message, name, status_code)

View file

@ -2,7 +2,7 @@ import inspect
from functools import wraps
from abc import abstractmethod, ABC
from datetime import datetime, timezone
from typing import Optional, Dict, Any, List, Tuple
from typing import Optional, Dict, Any, List, Tuple, Type
from uuid import NAMESPACE_OID, UUID, uuid5
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint
@ -183,6 +183,13 @@ class GraphDBInterface(ABC):
"""Get all neighboring nodes."""
raise NotImplementedError
@abstractmethod
async def get_nodeset_subgraph(
    self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
    """Return the subgraph spanned by the named nodes and their neighbors.

    Implementations match nodes whose stored type corresponds to
    ``node_type.__name__`` and whose name is in ``node_name``, then include
    direct neighbors and the edges between members of that set.

    Args:
        node_type: Class whose ``__name__`` identifies the node type/label.
        node_name: Names of the primary nodes to look up.

    Returns:
        ``(nodes, edges)`` where nodes are ``(id, properties)`` pairs and
        edges are ``(source_id, target_id, relationship_name, properties)``
        tuples.

    Raises:
        NotImplementedError: Always; subclasses must override. Backends that
            cannot support nodeset filtering raise
            ``NodesetFilterNotSupportedError`` instead.
    """
    raise NotImplementedError
@abstractmethod
async def get_connections(
self, node_id: str

View file

@ -1,11 +1,12 @@
"""Adapter for Kuzu graph database."""
from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError
from cognee.shared.logging_utils import get_logger
import json
import os
import shutil
import asyncio
from typing import Dict, Any, List, Union, Optional, Tuple
from typing import Dict, Any, List, Union, Optional, Tuple, Type
from datetime import datetime, timezone
from uuid import UUID
from contextlib import asynccontextmanager
@ -728,6 +729,65 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to get graph data: {e}")
raise
async def get_nodeset_subgraph(
    self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]:
    """Return the subgraph of the named nodes plus their one-hop neighbors.

    Matches nodes whose ``type`` property equals ``node_type.__name__`` and
    whose ``name`` is in ``node_name``, expands by one hop in either edge
    direction, and returns all nodes in that set together with every edge
    whose both endpoints are inside it.

    Args:
        node_type: Class whose ``__name__`` is compared against each node's
            stored ``type`` property.
        node_name: Names of the primary nodes to look up.

    Returns:
        ``(nodes, edges)``: nodes as ``(id, properties)`` tuples, edges as
        ``(source_id, target_id, relationship_name, properties)`` tuples.
        Both lists are empty when no primary node matches.
    """
    label = node_type.__name__
    # Step 1: resolve ids of the "primary" nodes (type + name match).
    primary_query = """
    UNWIND $names AS wantedName
    MATCH (n:Node)
    WHERE n.type = $label AND n.name = wantedName
    RETURN DISTINCT n.id
    """
    primary_rows = await self.query(primary_query, {"names": node_name, "label": label})
    primary_ids = [row[0] for row in primary_rows]
    if not primary_ids:
        # No primary match: short-circuit instead of issuing further queries.
        return [], []
    # Step 2: one-hop neighbors of the primary nodes (undirected match).
    neighbor_query = """
    MATCH (n:Node)-[:EDGE]-(nbr:Node)
    WHERE n.id IN $ids
    RETURN DISTINCT nbr.id
    """
    nbr_rows = await self.query(neighbor_query, {"ids": primary_ids})
    neighbor_ids = [row[0] for row in nbr_rows]
    # Deduplicate primary + neighbor ids into the final node set
    # (set-union, so result order is unspecified).
    all_ids = list({*primary_ids, *neighbor_ids})
    # Step 3: materialize node payloads for the whole set.
    nodes_query = """
    MATCH (n:Node)
    WHERE n.id IN $ids
    RETURN n.id, n.name, n.type, n.properties
    """
    node_rows = await self.query(nodes_query, {"ids": all_ids})
    nodes: List[Tuple[str, dict]] = []
    for node_id, name, typ, props in node_rows:
        data = {"id": node_id, "name": name, "type": typ}
        if props:
            # `properties` is stored as a JSON string; merge it into the dict.
            # Malformed JSON is tolerated: keep the base fields and warn.
            try:
                data.update(json.loads(props))
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse JSON props for node {node_id}")
        nodes.append((node_id, data))
    # Step 4: every edge with both endpoints inside the node set.
    edges_query = """
    MATCH (a:Node)-[r:EDGE]-(b:Node)
    WHERE a.id IN $ids AND b.id IN $ids
    RETURN a.id, b.id, r.relationship_name, r.properties
    """
    edge_rows = await self.query(edges_query, {"ids": all_ids})
    edges: List[Tuple[str, str, str, dict]] = []
    for from_id, to_id, rel_type, props in edge_rows:
        data = {}
        if props:
            # Same JSON-string convention as node properties.
            try:
                data = json.loads(props)
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}")
        edges.append((from_id, to_id, rel_type, data))
    return nodes, edges
async def get_filtered_graph_data(
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
):

View file

@ -4,7 +4,7 @@ import json
from cognee.shared.logging_utils import get_logger, ERROR
import asyncio
from textwrap import dedent
from typing import Optional, Any, List, Dict
from typing import Optional, Any, List, Dict, Type, Tuple
from contextlib import asynccontextmanager
from uuid import UUID
from neo4j import AsyncSession
@ -13,6 +13,7 @@ from neo4j.exceptions import Neo4jError
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.modules.storage.utils import JSONEncoder
from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError
logger = get_logger("MemgraphAdapter", level=ERROR)
@ -482,6 +483,12 @@ class MemgraphAdapter(GraphDBInterface):
return (nodes, edges)
async def get_nodeset_subgraph(
    self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
    """Nodeset subgraph retrieval is not implemented for Memgraph.

    Always raises ``NodesetFilterNotSupportedError`` so callers can detect
    that nodeset-filtered search is unavailable on this backend.
    """
    raise NodesetFilterNotSupportedError
async def get_filtered_graph_data(self, attribute_filters):
"""
Fetches nodes and relationships filtered by specified attribute values.

View file

@ -6,7 +6,7 @@ import json
from cognee.shared.logging_utils import get_logger, ERROR
import asyncio
from textwrap import dedent
from typing import Optional, Any, List, Dict
from typing import Optional, Any, List, Dict, Type, Tuple
from contextlib import asynccontextmanager
from uuid import UUID
from neo4j import AsyncSession
@ -517,6 +517,54 @@ class Neo4jAdapter(GraphDBInterface):
return (nodes, edges)
async def get_nodeset_subgraph(
    self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
    """Return the subgraph of the named nodes plus their one-hop neighbors.

    Runs a single Cypher round-trip that collects nodes labeled
    ``node_type.__name__`` whose ``name`` is in ``node_name``, adds their
    direct neighbors, and returns every node in that set together with every
    relationship whose both endpoints fall inside it.

    Args:
        node_type: Class whose ``__name__`` is used as the Neo4j node label.
        node_name: Names of the primary nodes to look up.

    Returns:
        ``(nodes, edges)``: nodes as ``(id, properties)`` tuples, edges as
        ``(source_id, target_id, relationship_type, properties)`` tuples.
        Empty lists when the query yields no row.
    """
    # The label is interpolated into the query text (parameters cannot be
    # used for labels in Cypher); doubled braces escape the f-string so the
    # map literals survive formatting.
    label = node_type.__name__
    query = f"""
    UNWIND $names AS wantedName
    MATCH (n:`{label}`)
    WHERE n.name = wantedName
    WITH collect(DISTINCT n) AS primary
    UNWIND primary AS p
    OPTIONAL MATCH (p)--(nbr)
    WITH primary, collect(DISTINCT nbr) AS nbrs
    WITH primary + nbrs AS nodelist
    UNWIND nodelist AS node
    WITH collect(DISTINCT node) AS nodes
    MATCH (a)-[r]-(b)
    WHERE a IN nodes AND b IN nodes
    WITH nodes, collect(DISTINCT r) AS rels
    RETURN
      [n IN nodes |
         {{ id: n.id,
           properties: properties(n) }}] AS rawNodes,
      [r IN rels |
         {{ type: type(r),
           properties: properties(r) }}] AS rawRels
    """
    result = await self.query(query, {"names": node_name})
    if not result:
        # No row returned at all: nothing matched.
        return [], []
    raw_nodes = result[0]["rawNodes"]
    raw_rels = result[0]["rawRels"]
    # NOTE(review): assumes every node carries an `id` property and every
    # relationship carries `source_node_id`/`target_node_id` properties — a
    # KeyError here means the graph was written by a different code path;
    # confirm against the adapter's node/edge creation methods.
    nodes = [(n["properties"]["id"], n["properties"]) for n in raw_nodes]
    edges = [
        (
            r["properties"]["source_node_id"],
            r["properties"]["target_node_id"],
            r["type"],
            r["properties"],
        )
        for r in raw_rels
    ]
    return nodes, edges
async def get_filtered_graph_data(self, attribute_filters):
"""
Fetches nodes and relationships filtered by specified attribute values.

View file

@ -4,8 +4,10 @@ from datetime import datetime, timezone
import os
import json
import asyncio
from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError
from cognee.shared.logging_utils import get_logger
from typing import Dict, Any, List, Union
from typing import Dict, Any, List, Union, Type, Tuple
from uuid import UUID
import aiofiles
import aiofiles.os as aiofiles_os
@ -396,6 +398,12 @@ class NetworkXAdapter(GraphDBInterface):
logger.error("Failed to delete graph: %s", error)
raise error
async def get_nodeset_subgraph(
    self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
    """Nodeset subgraph retrieval is not implemented for NetworkX.

    Always raises ``NodesetFilterNotSupportedError`` so callers can detect
    that nodeset-filtered search is unavailable on this backend.
    """
    raise NodesetFilterNotSupportedError
async def get_filtered_graph_data(
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
):

View file

@ -5,4 +5,3 @@ class NodeSet(DataPoint):
"""NodeSet data point."""
name: str
metadata: dict = {"index_fields": ["name"]}

View file

@ -1,5 +1,5 @@
from cognee.shared.logging_utils import get_logger
from typing import List, Dict, Union
from typing import List, Dict, Union, Optional, Type
from cognee.exceptions import InvalidValueError
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
@ -61,22 +61,33 @@ class CogneeGraph(CogneeAbstractGraph):
node_dimension=1,
edge_dimension=1,
memory_fragment_filter=[],
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> None:
if node_dimension < 1 or edge_dimension < 1:
raise InvalidValueError(message="Dimensions must be positive integers")
try:
if len(memory_fragment_filter) == 0:
if node_type is not None and node_name is not None:
nodes_data, edges_data = await adapter.get_nodeset_subgraph(
node_type=node_type, node_name=node_name
)
if not nodes_data or not edges_data:
raise EntityNotFoundError(
message="Nodeset does not exist, or empty nodetes projected from the database."
)
elif len(memory_fragment_filter) == 0:
nodes_data, edges_data = await adapter.get_graph_data()
if not nodes_data or not edges_data:
raise EntityNotFoundError(message="Empty graph projected from the database.")
else:
nodes_data, edges_data = await adapter.get_filtered_graph_data(
attribute_filters=memory_fragment_filter
)
if not nodes_data:
raise EntityNotFoundError(message="No node data retrieved from the database.")
if not edges_data:
raise EntityNotFoundError(message="No edge data retrieved from the database.")
if not nodes_data or not edges_data:
raise EntityNotFoundError(
message="Empty filtered graph projected from the database."
)
for node_id, properties in nodes_data:
node_attributes = {key: properties.get(key) for key in node_properties_to_project}

View file

@ -95,6 +95,7 @@ def expand_with_nodes_and_edges(
name=ont_node_name,
description=ont_node_name,
ontology_valid=True,
belongs_to_set=data_chunk.belongs_to_set,
)
for source, relation, target in ontology_entity_type_edges:
@ -144,6 +145,7 @@ def expand_with_nodes_and_edges(
is_a=type_node,
description=node.description,
ontology_valid=ontology_validated_source_ent,
belongs_to_set=data_chunk.belongs_to_set,
)
added_nodes_map[entity_node_key] = entity_node
@ -174,6 +176,7 @@ def expand_with_nodes_and_edges(
name=ont_node_name,
description=ont_node_name,
ontology_valid=True,
belongs_to_set=data_chunk.belongs_to_set,
)
for source, relation, target in ontology_entity_edges:

View file

@ -1,4 +1,4 @@
from typing import Any, Optional, List
from typing import Any, Optional, List, Type
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@ -14,11 +14,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
):
super().__init__(
user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
)
async def get_completion(

View file

@ -1,4 +1,4 @@
from typing import Any, Optional, List
from typing import Any, Optional, List, Type
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@ -18,11 +18,15 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
):
super().__init__(
user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
)
self.validation_system_prompt_path = validation_system_prompt_path
self.validation_user_prompt_path = validation_user_prompt_path

View file

@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any, Optional, Type, List
from collections import Counter
import string
@ -18,11 +18,15 @@ class GraphCompletionRetriever(BaseRetriever):
user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
):
"""Initialize retriever with prompt paths and search parameters."""
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
self.top_k = top_k if top_k is not None else 5
self.node_type = node_type
self.node_name = node_name
def _get_nodes(self, retrieved_edges: list) -> dict:
"""Creates a dictionary of nodes with their names and content."""
@ -68,7 +72,11 @@ class GraphCompletionRetriever(BaseRetriever):
vector_index_collections.append(f"{subclass.__name__}_{field_name}")
found_triplets = await brute_force_triplet_search(
query, top_k=self.top_k, collections=vector_index_collections or None
query,
top_k=self.top_k,
collections=vector_index_collections or None,
node_type=self.node_type,
node_name=self.node_name,
)
return found_triplets

View file

@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Type, List
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import summarize_text
@ -13,12 +13,16 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
system_prompt_path: str = "answer_simple_question.txt",
summarize_prompt_path: str = "summarize_search_results.txt",
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
):
"""Initialize retriever with default prompt paths and search parameters."""
super().__init__(
user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
)
self.summarize_prompt_path = summarize_prompt_path

View file

@ -1,5 +1,5 @@
import asyncio
from typing import List, Optional
from typing import List, Optional, Type
from cognee.shared.logging_utils import get_logger, ERROR
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
@ -55,6 +55,8 @@ def format_triplets(edges):
async def get_memory_fragment(
properties_to_project: Optional[List[str]] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> CogneeGraph:
"""Creates and initializes a CogneeGraph memory fragment with optional property projections."""
graph_engine = await get_graph_engine()
@ -68,6 +70,8 @@ async def get_memory_fragment(
graph_engine,
node_properties_to_project=properties_to_project,
edge_properties_to_project=["relationship_name"],
node_type=node_type,
node_name=node_name,
)
except EntityNotFoundError:
pass
@ -82,6 +86,8 @@ async def brute_force_triplet_search(
collections: List[str] = None,
properties_to_project: List[str] = None,
memory_fragment: Optional[CogneeGraph] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> list:
if user is None:
user = await get_default_user()
@ -93,6 +99,8 @@ async def brute_force_triplet_search(
collections=collections,
properties_to_project=properties_to_project,
memory_fragment=memory_fragment,
node_type=node_type,
node_name=node_name,
)
return retrieved_results
@ -104,6 +112,8 @@ async def brute_force_search(
collections: List[str] = None,
properties_to_project: List[str] = None,
memory_fragment: Optional[CogneeGraph] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> list:
"""
Performs a brute force search to retrieve the top triplets from the graph.
@ -125,7 +135,9 @@ async def brute_force_search(
raise ValueError("top_k must be a positive integer.")
if memory_fragment is None:
memory_fragment = await get_memory_fragment(properties_to_project)
memory_fragment = await get_memory_fragment(
properties_to_project, node_type=node_type, node_name=node_name
)
if collections is None:
collections = [

View file

@ -1,5 +1,5 @@
import json
from typing import Callable
from typing import Callable, Optional, List, Type
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine.utils import parse_id
@ -33,12 +33,20 @@ async def search(
user: User,
system_prompt_path="answer_simple_question.txt",
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
):
query = await log_query(query_text, query_type.value, user.id)
own_document_ids = await get_document_ids_for_user(user.id, datasets)
search_results = await specific_search(
query_type, query_text, user, system_prompt_path=system_prompt_path, top_k=top_k
query_type,
query_text,
user,
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
)
filtered_search_results = []
@ -61,29 +69,39 @@ async def specific_search(
user: User,
system_prompt_path="answer_simple_question.txt",
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> list:
search_tasks: dict[SearchType, Callable] = {
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion,
SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion,
SearchType.RAG_COMPLETION: CompletionRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
system_prompt_path=system_prompt_path, top_k=top_k
).get_completion,
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
).get_completion,
SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
).get_completion,
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
).get_completion,
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path, top_k=top_k
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
).get_completion,
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
SearchType.CYPHER: CypherSearchRetriever().get_completion,

View file

@ -2,6 +2,9 @@ import os
import shutil
import cognee
import pathlib
from cognee.modules.engine.models import NodeSet
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.modules.search.operations import get_history
@ -84,6 +87,30 @@ async def main():
history = await get_history(user.id)
assert len(history) == 6, "Search history is not correct."
nodeset_text = "Neo4j is a graph database that supports cypher."
await cognee.add([nodeset_text], dataset_name, node_set=["first"])
await cognee.cognify([dataset_name])
context_nonempty = await GraphCompletionRetriever(
node_type=NodeSet,
node_name=["first"],
).get_context("What is in the context?")
context_empty = await GraphCompletionRetriever(
node_type=NodeSet,
node_name=["nonexistent"],
).get_context("What is in the context?")
assert isinstance(context_nonempty, str) and context_nonempty != "", (
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
)
assert context_empty == "", (
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
)
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -1,10 +1,12 @@
import os
import pathlib
import cognee
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.modules.engine.models import NodeSet
logger = get_logger()
@ -89,6 +91,30 @@ async def main():
assert len(history) == 6, "Search history is not correct."
nodeset_text = "Neo4j is a graph database that supports cypher."
await cognee.add([nodeset_text], dataset_name, node_set=["first"])
await cognee.cognify([dataset_name])
context_nonempty = await GraphCompletionRetriever(
node_type=NodeSet,
node_name=["first"],
).get_context("What is in the context?")
context_empty = await GraphCompletionRetriever(
node_type=NodeSet,
node_name=["nonexistent"],
).get_context("What is in the context?")
assert isinstance(context_nonempty, str) and context_nonempty != "", (
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
)
assert context_empty == "", (
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
)
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -3,6 +3,7 @@ import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pylint.checkers.utils import node_type
from cognee.exceptions import InvalidValueError
from cognee.modules.search.methods.search import search, specific_search
@ -68,7 +69,13 @@ async def test_search(
mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id)
mock_get_document_ids.assert_called_once_with(mock_user.id, datasets)
mock_specific_search.assert_called_once_with(
query_type, query_text, mock_user, system_prompt_path="answer_simple_question.txt", top_k=10
query_type,
query_text,
mock_user,
system_prompt_path="answer_simple_question.txt",
top_k=10,
node_type=None,
node_name=None,
)
# Only the first two results should be included (doc_id3 is filtered out)