diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 1121a24d5..e6520e4e2 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -2,7 +2,7 @@ import logging import asyncio from textwrap import dedent -from typing import Optional, Any, List, Dict +from typing import Optional, Any, List, Dict, Union from contextlib import asynccontextmanager from uuid import UUID from neo4j import AsyncSession @@ -432,3 +432,49 @@ class Neo4jAdapter(GraphDBInterface): ) for record in result] return (nodes, edges) + + async def get_filtered_graph_data(self, attribute_filters): + """ + Fetches nodes and relationships filtered by specified attribute values. + + Args: + attribute_filters (list of dict): A list of dictionaries where keys are attributes and values are lists of values to filter on. + Example: [{"community": ["1", "2"]}] + + Returns: + tuple: A tuple containing two lists: nodes and edges. + """ + where_clauses = [] + for attribute, values in attribute_filters[0].items(): + values_str = ", ".join(f"'{value}'" if isinstance(value, str) else str(value) for value in values) + where_clauses.append(f"n.{attribute} IN [{values_str}]") + + where_clause = " AND ".join(where_clauses) + + query_nodes = f""" + MATCH (n) + WHERE {where_clause} + RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties + """ + result_nodes = await self.query(query_nodes) + + nodes = [( + record["id"], + record["properties"], + ) for record in result_nodes] + + query_edges = f""" + MATCH (n)-[r]->(m) + WHERE {where_clause} AND {where_clause.replace('n.', 'm.')} + RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties + """ + result_edges = await self.query(query_edges) + + edges = [( + record["source"], + record["target"], + record["type"], + record["properties"], + ) for record in result_edges] + + return (nodes, edges) \ No newline at end of file diff --git a/cognee/infrastructure/databases/graph/networkx/adapter.py b/cognee/infrastructure/databases/graph/networkx/adapter.py index a72376082..d249b6336 100644 --- a/cognee/infrastructure/databases/graph/networkx/adapter.py +++ b/cognee/infrastructure/databases/graph/networkx/adapter.py @@ -6,7 +6,7 @@ import json import asyncio import logging from re import A -from typing import Dict, Any, List +from typing import Dict, Any, List, Union from uuid import UUID import aiofiles import aiofiles.os as aiofiles_os @@ -301,3 +301,39 @@ class NetworkXAdapter(GraphDBInterface): logger.info("Graph deleted successfully.") except Exception as error: logger.error("Failed to delete graph: %s", error) + + async def get_filtered_graph_data(self, attribute_filters: List[Dict[str, List[Union[str, int]]]]): + """ + Fetches nodes and relationships filtered by specified attribute values. + + Args: + attribute_filters (list of dict): A list of dictionaries where keys are attributes and values are lists of values to filter on. + Example: [{"community": ["1", "2"]}] + + Returns: + tuple: A tuple containing two lists: + - Nodes: List of tuples (node_id, node_properties). + - Edges: List of tuples (source_id, target_id, relationship_type, edge_properties). + """ + # Create filters for nodes based on the attribute filters + where_clauses = [] + for attribute, values in attribute_filters[0].items(): + where_clauses.append((attribute, values)) + + # Filter nodes + filtered_nodes = [ + (node, data) for node, data in self.graph.nodes(data=True) + if all(data.get(attr) in values for attr, values in where_clauses) + ] + + # Filter edges where both source and target nodes satisfy the filters + filtered_edges = [ + (source, target, data.get('relationship_type', 'UNKNOWN'), data) + for source, target, data in self.graph.edges(data=True) + if ( + all(self.graph.nodes[source].get(attr) in values for attr, values in where_clauses) and + all(self.graph.nodes[target].get(attr) in values for attr, values in where_clauses) + ) + ] + + return filtered_nodes, filtered_edges \ No newline at end of file diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index d15d93b73..0b752c6cb 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -52,13 +52,17 @@ class CogneeGraph(CogneeAbstractGraph): edge_properties_to_project: List[str], directed = True, node_dimension = 1, - edge_dimension = 1) -> None: + edge_dimension = 1, + memory_fragment_filter = List[Dict[str, List[Union[str, int]]]]) -> None: if node_dimension < 1 or edge_dimension < 1: raise ValueError("Dimensions must be positive integers") try: - nodes_data, edges_data = await adapter.get_graph_data() + if len(memory_fragment_filter) == 0: + nodes_data, edges_data = await adapter.get_graph_data() + else: + nodes_data, edges_data = await adapter.get_filtered_graph_data(attribute_filters = memory_fragment_filter) if not nodes_data: raise ValueError("No node data retrieved from the database.")