<!-- .github/pull_request_template.md -->

This PR introduces wide vector and graph structure filtering capabilities. With these changes, the graph completion retriever and all retrievers that inherit from it will now filter relevant vector elements and subgraphs based on the query. This improvement significantly increases search speed for large graphs while maintaining—and in some cases slightly improving—accuracy.

Changes in This PR:

- Introduced a new `wide_search_top_k` parameter: controls the initial search space size.
- Added a graph-adapter-level filtering method: enables relevant subgraph filtering while maintaining backward compatibility. For community or custom graph adapters that don't implement this method, the system gracefully falls back to the original search behavior.
- Updated modal dashboard and evaluation framework: fixed compatibility issues.
- Added comprehensive unit tests: introduced unit tests for `brute_force_triplet_search` (previously untested) and expanded the `CogneeGraph` test suite.
- Integration tests: existing integration tests verify end-to-end search functionality (no changes required).

Acceptance Criteria and Testing

To verify the new search behavior, run search queries with different `wide_search_top_k` values while logging is enabled:

- `None`: triggers a full graph search (default behavior).
- `1`: projects a minimal subgraph (demonstrates maximum filtering).
- Custom values: test intermediate levels of filtering.

Internal Testing and Results: performance and accuracy benchmarks are available upon request. The implementation demonstrates measurable improvements in query latency for large graphs without sacrificing result quality.
## Type of Change

<!-- Please check the relevant option -->

- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [x] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)

None

## Pre-submission Checklist

<!-- Please check all boxes that apply before submitting your PR -->

- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the issue/feature**
- [x] My code follows the project's coding standards and style guidelines
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: Pavel Zorin <pazonec@yandex.ru>
415 lines
13 KiB
Python
415 lines
13 KiB
Python
import inspect
|
|
from functools import wraps
|
|
from abc import abstractmethod, ABC
|
|
from datetime import datetime, timezone
|
|
from typing import Optional, Dict, Any, List, Tuple, Type, Union
|
|
from uuid import NAMESPACE_OID, UUID, uuid5
|
|
from cognee.shared.logging_utils import get_logger
|
|
from cognee.infrastructure.engine import DataPoint
|
|
from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger
|
|
from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine
|
|
|
|
# Module-level logger; used by record_graph_changes and GraphDBInterface.is_empty.
logger = get_logger()
|
|
|
|
# Shared type aliases for the graph payload shapes used in the signatures of this module.

# Property mapping attached to a single node.
NodeData = Dict[str, Any]

# One edge, laid out as (source_id, target_id, relationship_name, properties).
EdgeData = Tuple[str, str, str, Dict[str, Any]]

# One node, laid out as (node_id, properties).
Node = Tuple[str, NodeData]
|
|
|
|
|
|
# Maps the name of a recorded operation to the parameter that carries its batch, so the
# batch can be found whether the caller passed it positionally or by keyword.
_LEDGER_ARGUMENT_NAMES = {"add_nodes": "nodes", "add_edges": "edges"}


def _describe_caller(start_frame):
    """
    Build a ``Class.function`` (or bare ``function``) label for the code that invoked a wrapper.

    Walks outward from ``start_frame`` and picks the first calling frame whose function is
    not itself named ``wrapper``, so that stacked decorator wrappers are skipped.

    Parameters:
    -----------

        - start_frame: Frame of the running wrapper, as returned by
          ``inspect.currentframe()`` (which may be None on interpreters without frame
          support).

    Returns:
    --------

        The caller label, or "unknown" when no suitable calling frame exists.
    """
    caller_frame = None
    frame = start_frame
    while frame:
        if frame.f_back and frame.f_back.f_code.co_name != "wrapper":
            caller_frame = frame.f_back
            break
        frame = frame.f_back

    if caller_frame is None:
        # Previously this case surfaced as an UnboundLocalError in the wrapper.
        return "unknown"

    caller_name = caller_frame.f_code.co_name
    # Compare against None instead of relying on truthiness: evaluating the truth value of
    # an arbitrary object may call its __bool__/__len__, and a falsy instance is still a
    # valid caller.
    caller_self = caller_frame.f_locals.get("self", None)
    if caller_self is None:
        return caller_name
    return f"{caller_self.__class__.__name__}.{caller_name}"


def _build_ledger_entries(operation, items, creator):
    """
    Create one GraphRelationshipLedger row per node or edge in ``items``.

    Parameters:
    -----------

        - operation (str): Name of the recorded function, "add_nodes" or "add_edges".
        - items: The batch handed to the recorded function. Nodes are read through their
          ``id`` (and optional ``name``) attributes; edges are indexed as
          (source_id, target_id, relationship_name, ...).
        - creator (str): Caller label stored in ``creator_function``.

    Returns:
    --------

        A list of unsaved GraphRelationshipLedger instances.

    Raises:
    -------

        AttributeError, ValueError, TypeError or IndexError when an item does not have the
        expected shape or its identifiers are not valid UUIDs.
    """
    # Function-scope import: the module only imports uuid5. Ids used to be
    # uuid5(<current timestamp>), so two entries built within the same clock tick received
    # the same id. Random ids do not depend on clock resolution.
    # NOTE(review): `id` is presumably the ledger's primary key - confirm on the model.
    from uuid import uuid4

    entries = []
    if operation == "add_nodes":
        for node in items:
            node_id = UUID(str(node.id))
            entries.append(
                GraphRelationshipLedger(
                    id=uuid4(),
                    source_node_id=node_id,
                    destination_node_id=node_id,
                    creator_function=f"{creator}.node",
                    node_label=getattr(node, "name", None) or str(node.id),
                )
            )
    else:
        for edge in items:
            source_id = UUID(str(edge[0]))
            target_id = UUID(str(edge[1]))
            rel_type = str(edge[2])
            entries.append(
                GraphRelationshipLedger(
                    id=uuid4(),
                    source_node_id=source_id,
                    destination_node_id=target_id,
                    creator_function=f"{creator}.{rel_type}",
                )
            )
    return entries


def record_graph_changes(func):
    """
    Decorator to record graph changes in the relationship database.

    Only functions named ``add_nodes`` or ``add_edges`` produce ledger rows; for any other
    name the wrapper still opens a session and commits it, but adds nothing.

    Parameters:
    -----------

        - func: The asynchronous method to wrap. It is awaited as
          ``func(self, *args, **kwargs)``.

    Returns:
    --------

        The wrapped coroutine function. It returns whatever ``func`` returns.
    """

    @wraps(func)
    async def wrapper(self, *args, **kwargs):
        """
        Await the wrapped function, then write ledger rows describing what was added.

        The caller's function (and class, when the caller has a ``self`` local) is recorded
        as the creator of each row. Failures while building, flushing or committing the
        ledger rows are logged at debug level and never replace the wrapped function's
        result. Exceptions raised by the wrapped function itself, or while obtaining the
        relational engine or session, propagate unchanged.

        Parameters:
        -----------

            - *args: Positional arguments passed to the wrapped function.
            - **kwargs: Keyword arguments passed to the wrapped function.

        Returns:
        --------

            Returns the result of the wrapped function call.
        """
        db_engine = get_relational_engine()
        # The frame is passed straight to the helper so this coroutine never stores a
        # reference to its own frame in a local (which would create a reference cycle).
        creator = _describe_caller(inspect.currentframe())

        result = await func(self, *args, **kwargs)

        async with db_engine.get_async_session() as session:
            argument_name = _LEDGER_ARGUMENT_NAMES.get(func.__name__)
            if argument_name is not None:
                try:
                    # Accept the batch positionally or by keyword; args[0] alone raised
                    # IndexError for keyword calls after the graph write had succeeded.
                    items = args[0] if args else kwargs.get(argument_name)
                    if items is None:
                        items = []
                    # Building the rows is inside the try so that malformed items are
                    # handled like any other ledger failure.
                    session.add_all(_build_ledger_entries(func.__name__, items, creator))
                    await session.flush()
                except Exception as e:
                    logger.debug(f"Error adding relationship: {e}")
                    await session.rollback()

            try:
                await session.commit()
            except Exception as e:
                logger.debug(f"Error committing session: {e}")

        return result

    return wrapper
|
|
|
|
|
|
class GraphDBInterface(ABC):
    """
    Abstract contract that every graph database adapter has to fulfil.

    Every method below is declared with ``@abstractmethod``, so a subclass can only be
    instantiated once it overrides all of them.

    Public methods include:
    - is_empty
    - query
    - add_node
    - add_nodes
    - delete_node
    - delete_nodes
    - get_node
    - get_nodes
    - add_edge
    - add_edges
    - delete_graph
    - get_graph_data
    - get_graph_metrics
    - has_edge
    - has_edges
    - get_edges
    - get_neighbors
    - get_nodeset_subgraph
    - get_connections
    - get_filtered_graph_data
    """

    @abstractmethod
    async def is_empty(self) -> bool:
        """
        Report whether the graph is empty.

        Unlike the other stubs, this one has a fallback body: it logs a warning and
        answers True. Because the method is abstract, that body only runs when a subclass
        delegates to it explicitly (e.g. through ``super()``).
        """
        logger.warning("is_empty() is not implemented")
        return True

    @abstractmethod
    async def query(self, query: str, params: dict) -> List[Any]:
        """
        Run a raw query against the database and hand back its results.

        Parameters:
        -----------

            - query (str): Query text to execute.
            - params (dict): Parameters to be used in the query.
        """
        raise NotImplementedError()

    @abstractmethod
    async def add_node(
        self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Insert one node into the graph.

        Parameters:
        -----------

            - node (Union[DataPoint, str]): A DataPoint object, or a string identifier of
              the node to add.
            - properties (Optional[Dict[str, Any]]): Properties of the node. Required when
              node is a string, ignored when node is a DataPoint.
        """
        raise NotImplementedError()

    @abstractmethod
    @record_graph_changes
    async def add_nodes(self, nodes: Union[List[Node], List[DataPoint]]) -> None:
        """
        Insert several nodes into the graph in one operation.

        The stub is wrapped with ``record_graph_changes``. An override in a subclass
        replaces the wrapped stub, so it is only recorded if it applies the decorator
        itself.

        Parameters:
        -----------

            - nodes (Union[List[Node], List[DataPoint]]): Node tuples or DataPoint objects
              to add.
        """
        raise NotImplementedError()

    @abstractmethod
    async def delete_node(self, node_id: str) -> None:
        """
        Remove one node, addressed by its ID, from the graph.

        Parameters:
        -----------

            - node_id (str): Unique identifier of the node to delete.
        """
        raise NotImplementedError()

    @abstractmethod
    async def delete_nodes(self, node_ids: List[str]) -> None:
        """
        Remove several nodes, addressed by their IDs, from the graph.

        Parameters:
        -----------

            - node_ids (List[str]): Unique identifiers of the nodes to delete.
        """
        raise NotImplementedError()

    @abstractmethod
    async def get_node(self, node_id: str) -> Optional[NodeData]:
        """
        Look up one node by its ID.

        Parameters:
        -----------

            - node_id (str): Unique identifier of the node to fetch.
        """
        raise NotImplementedError()

    @abstractmethod
    async def get_nodes(self, node_ids: List[str]) -> List[NodeData]:
        """
        Look up several nodes by their IDs.

        Parameters:
        -----------

            - node_ids (List[str]): Unique identifiers of the nodes to fetch.
        """
        raise NotImplementedError()

    @abstractmethod
    async def add_edge(
        self,
        source_id: str,
        target_id: str,
        relationship_name: str,
        properties: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Connect two nodes with a new edge.

        Parameters:
        -----------

            - source_id (str): Unique identifier of the node the edge starts from.
            - target_id (str): Unique identifier of the node the edge points to.
            - relationship_name (str): Name of the relationship the edge establishes.
            - properties (Optional[Dict[str, Any]]): Optional properties stored on the
              edge. (default None)
        """
        raise NotImplementedError()

    @abstractmethod
    @record_graph_changes
    async def add_edges(
        self, edges: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]]
    ) -> None:
        """
        Insert several edges into the graph in one operation.

        The stub is wrapped with ``record_graph_changes``. An override in a subclass
        replaces the wrapped stub, so it is only recorded if it applies the decorator
        itself.

        Parameters:
        -----------

            - edges (Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]]):
              EdgeData tuples, or equivalent tuples whose properties entry may be None.
        """
        raise NotImplementedError()

    @abstractmethod
    async def delete_graph(self) -> None:
        """
        Drop the whole graph: every node and every edge.
        """
        raise NotImplementedError()

    @abstractmethod
    async def get_graph_data(self) -> Tuple[List[Node], List[EdgeData]]:
        """
        Return every node and every edge of the graph.
        """
        raise NotImplementedError()

    @abstractmethod
    async def get_graph_metrics(self, include_optional: bool = False) -> Dict[str, Any]:
        """
        Return metrics and statistics describing the graph.

        Parameters:
        -----------

            - include_optional (bool): Whether optional metrics should be included as
              well. (default False)
        """
        raise NotImplementedError()

    @abstractmethod
    async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool:
        """
        Check whether a given edge exists between two nodes.

        Parameters:
        -----------

            - source_id (str): Unique identifier of the source node.
            - target_id (str): Unique identifier of the target node.
            - relationship_name (str): Name of the relationship to look for.
        """
        raise NotImplementedError()

    @abstractmethod
    async def has_edges(self, edges: List[EdgeData]) -> List[EdgeData]:
        """
        Check several edges for existence in the graph.

        Parameters:
        -----------

            - edges (List[EdgeData]): The edges to check.
        """
        raise NotImplementedError()

    @abstractmethod
    async def get_edges(self, node_id: str) -> List[EdgeData]:
        """
        Return all edges connected to one node.

        Parameters:
        -----------

            - node_id (str): Unique identifier of the node whose edges are requested.
        """
        raise NotImplementedError()

    @abstractmethod
    async def get_neighbors(self, node_id: str) -> List[NodeData]:
        """
        Return all nodes that neighbor the given node.

        Parameters:
        -----------

            - node_id (str): Unique identifier of the node whose neighbors are requested.
        """
        raise NotImplementedError()

    @abstractmethod
    async def get_nodeset_subgraph(
        self, node_type: Type[Any], node_name: List[str]
    ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
        """
        Return the subgraph made of a specific set of nodes and their relationships.

        Parameters:
        -----------

            - node_type (Type[Any]): Type of the nodes to include in the subgraph.
            - node_name (List[str]): Names of the nodes to include in the subgraph.
        """
        raise NotImplementedError()

    @abstractmethod
    async def get_connections(
        self, node_id: Union[str, UUID]
    ) -> List[Tuple[NodeData, Dict[str, Any], NodeData]]:
        """
        Return the nodes connected to a given node together with the relationship details.

        Parameters:
        -----------

            - node_id (Union[str, UUID]): Unique identifier of the node whose connections
              are requested.
        """
        raise NotImplementedError()

    @abstractmethod
    async def get_filtered_graph_data(
        self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
    ) -> Tuple[List[Node], List[EdgeData]]:
        """
        Return the nodes and edges that match the given attribute criteria.

        Parameters:
        -----------

            - attribute_filters: A list of dictionaries mapping an attribute name to the
              list of attribute values to filter by.
        """
        raise NotImplementedError()
|