cognee/cognee/infrastructure/databases/graph/graph_db_interface.py
hajdul88 508165e883
feature: Introduces wide subgraph search in graph completion and improves QA speed (#1736)

This PR introduces wide vector and graph structure filtering
capabilities. With these changes, the graph completion retriever and all
retrievers that inherit from it will now filter relevant vector elements
and subgraphs based on the query. This improvement significantly
increases search speed for large graphs while maintaining—and in some
cases slightly improving—accuracy.

Changes in This PR:

- Introduced new wide_search_top_k parameter: Controls the initial search
  space size.

- Added graph adapter level filtering method: Enables relevant subgraph
  filtering while maintaining backward compatibility. For community or
  custom graph adapters that don't implement this method, the system
  gracefully falls back to the original search behavior.

- Updated modal dashboard and evaluation framework: Fixed compatibility
  issues.

- Added comprehensive unit tests: Introduced unit tests for
  brute_force_triplet_search (previously untested) and expanded the
  CogneeGraph test suite.

- Integration tests: Existing integration tests verify end-to-end search
  functionality (no changes required).
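
The backward-compatible fallback described in the list above could look roughly like the sketch below. It is only an illustration: the adapter classes, the `get_subgraph_by_ids` method name, and the `project_graph` helper are hypothetical and are not taken from the cognee codebase.

```python
import asyncio
from typing import Any, Dict, List, Optional, Tuple

Node = Tuple[str, Dict[str, Any]]
Edge = Tuple[str, str, str, Dict[str, Any]]


class LegacyAdapter:
    """Hypothetical adapter that only supports full-graph retrieval."""

    def __init__(self, nodes: List[Node], edges: List[Edge]):
        self.nodes, self.edges = nodes, edges

    async def get_graph_data(self) -> Tuple[List[Node], List[Edge]]:
        return self.nodes, self.edges


class FilteringAdapter(LegacyAdapter):
    """Hypothetical adapter that also supports ID-based subgraph filtering."""

    async def get_subgraph_by_ids(self, node_ids: List[str]):  # hypothetical name
        wanted = set(node_ids)
        nodes = [n for n in self.nodes if n[0] in wanted]
        edges = [e for e in self.edges if e[0] in wanted and e[1] in wanted]
        return nodes, edges


async def project_graph(adapter, relevant_ids: Optional[List[str]] = None):
    """Use adapter-level filtering when available, else load the full graph."""
    filter_method = getattr(adapter, "get_subgraph_by_ids", None)
    if relevant_ids is not None and callable(filter_method):
        return await filter_method(relevant_ids)
    # Adapter does not implement filtering: original (full graph) behavior.
    return await adapter.get_graph_data()


nodes = [("a", {}), ("b", {}), ("c", {})]
edges = [("a", "b", "knows", {}), ("b", "c", "knows", {})]
full = asyncio.run(project_graph(LegacyAdapter(nodes, edges), ["a", "b"]))
filtered = asyncio.run(project_graph(FilteringAdapter(nodes, edges), ["a", "b"]))
```

Probing for the method with `getattr` keeps third-party adapters working unchanged, at the cost of silently skipping the optimization for them.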

Acceptance Criteria and Testing

To verify the new search behavior, run search queries with different
wide_search_top_k values while logging is enabled:
- None: Triggers a full graph search (default behavior)
- 1: Projects a minimal subgraph (demonstrates maximum filtering)
- Custom values: Test intermediate levels of filtering
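
As a rough mental model of the three settings above, the following self-contained sketch ranks nodes by vector distance, keeps the closest wide_search_top_k, and projects only the edges that touch a kept node. The function names, the "lower distance is better" convention, and the edge-selection rule are assumptions made for illustration; the actual retriever may differ.

```python
from typing import Dict, List, Optional, Set, Tuple

Edge = Tuple[str, str, str]  # (source_id, target_id, relationship_name)


def select_candidates(
    distances: Dict[str, float], wide_search_top_k: Optional[int]
) -> Set[str]:
    """Return the node IDs kept by the wide search (assumes lower = closer)."""
    ranked = sorted(distances, key=distances.get)
    if wide_search_top_k is None:
        return set(ranked)  # no limit: the full graph stays in play
    return set(ranked[:wide_search_top_k])


def project_subgraph(edges: List[Edge], candidates: Set[str]) -> List[Edge]:
    """Keep edges touching at least one candidate node (assumed rule)."""
    return [e for e in edges if e[0] in candidates or e[1] in candidates]


distances = {"a": 0.1, "b": 0.4, "c": 0.9}
edges = [("a", "b", "related_to"), ("b", "c", "related_to")]

full_graph = project_subgraph(edges, select_candidates(distances, None))
minimal = project_subgraph(edges, select_candidates(distances, 1))
```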

Internal Testing and results:
Performance and accuracy benchmarks are available upon request. The
implementation demonstrates measurable improvements in query latency for
large graphs without sacrificing result quality.

## Type of Change
- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [x] Performance improvement
- [ ] Other (please specify):

## Pre-submission Checklist
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: Pavel Zorin <pazonec@yandex.ru>
2025-11-26 15:18:53 +01:00

415 lines
13 KiB
Python

import inspect
from functools import wraps
from abc import abstractmethod, ABC
from datetime import datetime, timezone
from typing import Optional, Dict, Any, List, Tuple, Type, Union
from uuid import NAMESPACE_OID, UUID, uuid5
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger
from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine

logger = get_logger()

# Type aliases for better readability
NodeData = Dict[str, Any]
EdgeData = Tuple[
    str, str, str, Dict[str, Any]
]  # (source_id, target_id, relationship_name, properties)
Node = Tuple[str, NodeData]  # (node_id, properties)


def record_graph_changes(func):
    """
    Decorator to record graph changes in the relationship database.

    Parameters:
    -----------

        - func: The asynchronous function to wrap, which likely modifies graph data.

    Returns:
    --------

        Returns the wrapped function that manages database relationships.
    """

    @wraps(func)
    async def wrapper(self, *args, **kwargs):
        """
        Wraps the given asynchronous function to handle database relationships.

        Tracks the caller's function and class name for context. When the wrapped function is
        called, it manages database relationships for nodes or edges by adding entries to a
        ledger and committing the changes to the database session. Errors during relationship
        addition or session commit are logged and will not disrupt the execution of the wrapped
        function.

        Parameters:
        -----------

            - *args: Positional arguments passed to the wrapped function.
            - **kwargs: Keyword arguments passed to the wrapped function.

        Returns:
        --------

            Returns the result of the wrapped function call.
        """
        db_engine = get_relational_engine()

        # Walk up the stack to the first frame that is not this wrapper.
        caller_frame = None
        frame = inspect.currentframe()
        while frame:
            if frame.f_back and frame.f_back.f_code.co_name != "wrapper":
                caller_frame = frame.f_back
                break
            frame = frame.f_back

        if caller_frame is not None:
            caller_name = caller_frame.f_code.co_name
            caller_class = (
                caller_frame.f_locals.get("self", None).__class__.__name__
                if caller_frame.f_locals.get("self", None)
                else None
            )
        else:
            # No caller frame was found; avoid referencing an unbound variable.
            caller_name = "unknown"
            caller_class = None
        creator = f"{caller_class}.{caller_name}" if caller_class else caller_name

        result = await func(self, *args, **kwargs)

        async with db_engine.get_async_session() as session:
            if func.__name__ == "add_nodes":
                nodes: List[DataPoint] = args[0]
                relationship_ledgers = []
                for node in nodes:
                    node_id = UUID(str(node.id))
                    relationship_ledgers.append(
                        GraphRelationshipLedger(
                            id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
                            source_node_id=node_id,
                            destination_node_id=node_id,
                            creator_function=f"{creator}.node",
                            node_label=getattr(node, "name", None) or str(node.id),
                        )
                    )
                try:
                    session.add_all(relationship_ledgers)
                    await session.flush()
                except Exception as e:
                    logger.debug(f"Error adding relationship: {e}")
                    await session.rollback()
            elif func.__name__ == "add_edges":
                edges = args[0]
                relationship_ledgers = []
                for edge in edges:
                    source_id = UUID(str(edge[0]))
                    target_id = UUID(str(edge[1]))
                    rel_type = str(edge[2])
                    relationship_ledgers.append(
                        GraphRelationshipLedger(
                            id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
                            source_node_id=source_id,
                            destination_node_id=target_id,
                            creator_function=f"{creator}.{rel_type}",
                        )
                    )
                try:
                    session.add_all(relationship_ledgers)
                    await session.flush()
                except Exception as e:
                    logger.debug(f"Error adding relationship: {e}")
                    await session.rollback()

            try:
                await session.commit()
            except Exception as e:
                logger.debug(f"Error committing session: {e}")

        return result

    return wrapper


class GraphDBInterface(ABC):
    """
    Define an interface for graph database operations to be implemented by concrete classes.

    Public methods include:
    - is_empty
    - query
    - add_node
    - add_nodes
    - delete_node
    - delete_nodes
    - get_node
    - get_nodes
    - add_edge
    - add_edges
    - delete_graph
    - get_graph_data
    - get_graph_metrics
    - has_edge
    - has_edges
    - get_edges
    - get_neighbors
    - get_nodeset_subgraph
    - get_connections
    - get_filtered_graph_data
    """

    @abstractmethod
    async def is_empty(self) -> bool:
        logger.warning("is_empty() is not implemented")
        return True

    @abstractmethod
    async def query(self, query: str, params: dict) -> List[Any]:
        """
        Execute a raw database query and return the results.

        Parameters:
        -----------

            - query (str): The query string to execute against the database.
            - params (dict): A dictionary of parameters to be used in the query.
        """
        raise NotImplementedError

    @abstractmethod
    async def add_node(
        self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Add a single node with specified properties to the graph.

        Parameters:
        -----------

            - node (Union[DataPoint, str]): Either a DataPoint object or a string identifier for the node being added.
            - properties (Optional[Dict[str, Any]]): A dictionary of properties associated with the node.
              Required when node is a string, ignored when node is a DataPoint.
        """
        raise NotImplementedError

    @abstractmethod
    @record_graph_changes
    async def add_nodes(self, nodes: Union[List[Node], List[DataPoint]]) -> None:
        """
        Add multiple nodes to the graph in a single operation.

        Parameters:
        -----------

            - nodes (Union[List[Node], List[DataPoint]]): A list of Node objects or DataPoint objects to be added to the graph.
        """
        raise NotImplementedError

    @abstractmethod
    async def delete_node(self, node_id: str) -> None:
        """
        Delete a specified node from the graph by its ID.

        Parameters:
        -----------

            - node_id (str): Unique identifier for the node to delete.
        """
        raise NotImplementedError

    @abstractmethod
    async def delete_nodes(self, node_ids: List[str]) -> None:
        """
        Delete multiple nodes from the graph by their identifiers.

        Parameters:
        -----------

            - node_ids (List[str]): A list of unique identifiers for the nodes to delete.
        """
        raise NotImplementedError

    @abstractmethod
    async def get_node(self, node_id: str) -> Optional[NodeData]:
        """
        Retrieve a single node from the graph using its ID.

        Parameters:
        -----------

            - node_id (str): Unique identifier of the node to retrieve.
        """
        raise NotImplementedError

    @abstractmethod
    async def get_nodes(self, node_ids: List[str]) -> List[NodeData]:
        """
        Retrieve multiple nodes from the graph using their IDs.

        Parameters:
        -----------

            - node_ids (List[str]): A list of unique identifiers for the nodes to retrieve.
        """
        raise NotImplementedError

    @abstractmethod
    async def add_edge(
        self,
        source_id: str,
        target_id: str,
        relationship_name: str,
        properties: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Create a new edge between two nodes in the graph.

        Parameters:
        -----------

            - source_id (str): The unique identifier of the source node.
            - target_id (str): The unique identifier of the target node.
            - relationship_name (str): The name of the relationship to be established by the
              edge.
            - properties (Optional[Dict[str, Any]]): Optional dictionary of properties
              associated with the edge. (default None)
        """
        raise NotImplementedError

    @abstractmethod
    @record_graph_changes
    async def add_edges(
        self, edges: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]]
    ) -> None:
        """
        Add multiple edges to the graph in a single operation.

        Parameters:
        -----------

            - edges (Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]]): A list of EdgeData objects or tuples representing edges to be added.
        """
        raise NotImplementedError

    @abstractmethod
    async def delete_graph(self) -> None:
        """
        Remove the entire graph, including all nodes and edges.
        """
        raise NotImplementedError

    @abstractmethod
    async def get_graph_data(self) -> Tuple[List[Node], List[EdgeData]]:
        """
        Retrieve all nodes and edges within the graph.
        """
        raise NotImplementedError

    @abstractmethod
    async def get_graph_metrics(self, include_optional: bool = False) -> Dict[str, Any]:
        """
        Fetch metrics and statistics of the graph, possibly including optional details.

        Parameters:
        -----------

            - include_optional (bool): Flag indicating whether to include optional metrics or
              not. (default False)
        """
        raise NotImplementedError

    @abstractmethod
    async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool:
        """
        Verify if an edge exists between two specified nodes.

        Parameters:
        -----------

            - source_id (str): Unique identifier of the source node.
            - target_id (str): Unique identifier of the target node.
            - relationship_name (str): Name of the relationship to verify.
        """
        raise NotImplementedError

    @abstractmethod
    async def has_edges(self, edges: List[EdgeData]) -> List[EdgeData]:
        """
        Determine the existence of multiple edges in the graph.

        Parameters:
        -----------

            - edges (List[EdgeData]): A list of EdgeData objects to check for existence in the
              graph.
        """
        raise NotImplementedError

    @abstractmethod
    async def get_edges(self, node_id: str) -> List[EdgeData]:
        """
        Retrieve all edges that are connected to the specified node.

        Parameters:
        -----------

            - node_id (str): Unique identifier of the node whose edges are to be retrieved.
        """
        raise NotImplementedError

    @abstractmethod
    async def get_neighbors(self, node_id: str) -> List[NodeData]:
        """
        Get all neighboring nodes connected to the specified node.

        Parameters:
        -----------

            - node_id (str): Unique identifier of the node for which to retrieve neighbors.
        """
        raise NotImplementedError

    @abstractmethod
    async def get_nodeset_subgraph(
        self, node_type: Type[Any], node_name: List[str]
    ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
        """
        Fetch a subgraph consisting of a specific set of nodes and their relationships.

        Parameters:
        -----------

            - node_type (Type[Any]): The type of nodes to include in the subgraph.
            - node_name (List[str]): A list of names of the nodes to include in the subgraph.
        """
        raise NotImplementedError

    @abstractmethod
    async def get_connections(
        self, node_id: Union[str, UUID]
    ) -> List[Tuple[NodeData, Dict[str, Any], NodeData]]:
        """
        Get all nodes connected to a specified node and their relationship details.

        Parameters:
        -----------

            - node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections.
        """
        raise NotImplementedError

    @abstractmethod
    async def get_filtered_graph_data(
        self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
    ) -> Tuple[List[Node], List[EdgeData]]:
        """
        Retrieve nodes and edges filtered by the provided attribute criteria.

        Parameters:
        -----------

            - attribute_filters: A list of dictionaries where keys are attribute names and values
              are lists of attribute values to filter by.
        """
        raise NotImplementedError
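
To make the `attribute_filters` shape accepted by `get_filtered_graph_data` more concrete, here is a minimal in-memory sketch using the same `Node` and `EdgeData` tuple layouts as the aliases at the top of this file. The matching semantics are an assumption, since the interface does not specify them: a node is kept if it satisfies every attribute of at least one filter dictionary, and an edge is kept only when both endpoints are kept. `filter_graph` is a hypothetical helper, not part of the interface.

```python
from typing import Any, Dict, List, Tuple, Union

NodeData = Dict[str, Any]
Node = Tuple[str, NodeData]  # (node_id, properties)
EdgeData = Tuple[str, str, str, Dict[str, Any]]  # (source, target, name, properties)
AttributeFilters = List[Dict[str, List[Union[str, int]]]]


def filter_graph(
    nodes: List[Node], edges: List[EdgeData], attribute_filters: AttributeFilters
) -> Tuple[List[Node], List[EdgeData]]:
    """Filter an in-memory graph by attribute values (assumed semantics)."""

    def matches(props: NodeData) -> bool:
        # OR across filter dicts, AND across the attributes inside one dict.
        return any(
            all(props.get(attr) in allowed for attr, allowed in flt.items())
            for flt in attribute_filters
        )

    kept_nodes = [node for node in nodes if matches(node[1])]
    kept_ids = {node[0] for node in kept_nodes}
    kept_edges = [e for e in edges if e[0] in kept_ids and e[1] in kept_ids]
    return kept_nodes, kept_edges


nodes = [
    ("1", {"type": "Entity", "name": "Alice"}),
    ("2", {"type": "Entity", "name": "Bob"}),
    ("3", {"type": "Document", "name": "Doc"}),
]
edges = [("1", "2", "knows", {}), ("1", "3", "mentioned_in", {})]

entity_nodes, entity_edges = filter_graph(nodes, edges, [{"type": ["Entity"]}])
```

Under these assumptions an empty `attribute_filters` list matches nothing; a real adapter may treat that case differently.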