feat: Adds graph node filtering by feature

This commit is contained in:
hajdul88 2024-11-20 15:13:38 +01:00
parent d9eec77f18
commit 0101d43b8d
3 changed files with 90 additions and 4 deletions

View file

@ -2,7 +2,7 @@
import logging
import asyncio
from textwrap import dedent
from typing import Optional, Any, List, Dict
from typing import Optional, Any, List, Dict, Union
from contextlib import asynccontextmanager
from uuid import UUID
from neo4j import AsyncSession
@ -432,3 +432,49 @@ class Neo4jAdapter(GraphDBInterface):
) for record in result]
return (nodes, edges)
async def get_filtered_graph_data(self, attribute_filters):
    """
    Fetches nodes and relationships filtered by specified attribute values.

    Only the first dictionary in ``attribute_filters`` is applied. Its
    conditions are combined with AND, and an edge is returned only when both
    its source and its target node satisfy every condition.

    Args:
        attribute_filters (list of dict): A list of dictionaries where keys are attributes and values are lists of values to filter on.
            Example: [{"community": ["1", "2"]}]
            An empty list or an empty dictionary applies no filtering.

    Returns:
        tuple: A tuple containing two lists:
            - nodes: list of tuples (node_id, node_properties).
            - edges: list of tuples (source_id, target_id, relationship_type, edge_properties).
    """

    def to_cypher_literal(value):
        # Escape backslashes and single quotes so a string value cannot
        # terminate the Cypher literal early (broken query / injection).
        if isinstance(value, str):
            escaped = value.replace("\\", "\\\\").replace("'", "\\'")
            return f"'{escaped}'"
        return str(value)

    def build_predicate(alias, filters):
        # The predicate is built per alias instead of text-replacing "n." with
        # "m." afterwards, which also rewrote filter values containing "n."
        # (for example 'john.doe' became 'johm.doe').
        clauses = []
        for attribute, values in filters.items():
            # Backtick-quote the property name; embedded backticks are doubled.
            property_name = str(attribute).replace("`", "``")
            literals = ", ".join(to_cypher_literal(value) for value in values)
            clauses.append(f"{alias}.`{property_name}` IN [{literals}]")
        return " AND ".join(clauses)

    filters = attribute_filters[0] if attribute_filters else {}
    source_predicate = build_predicate("n", filters)
    target_predicate = build_predicate("m", filters)

    # With no conditions the WHERE clause is omitted entirely; an empty
    # "WHERE" is not valid Cypher.
    node_where = f"WHERE {source_predicate}" if source_predicate else ""
    edge_where = (
        f"WHERE {source_predicate} AND {target_predicate}" if source_predicate else ""
    )

    query_nodes = f"""
    MATCH (n)
    {node_where}
    RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties
    """
    result_nodes = await self.query(query_nodes)

    nodes = [(
        record["id"],
        record["properties"],
    ) for record in result_nodes]

    query_edges = f"""
    MATCH (n)-[r]->(m)
    {edge_where}
    RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
    """
    result_edges = await self.query(query_edges)

    edges = [(
        record["source"],
        record["target"],
        record["type"],
        record["properties"],
    ) for record in result_edges]

    return (nodes, edges)

View file

@ -6,7 +6,7 @@ import json
import asyncio
import logging
from re import A
from typing import Dict, Any, List
from typing import Dict, Any, List, Union
from uuid import UUID
import aiofiles
import aiofiles.os as aiofiles_os
@ -301,3 +301,39 @@ class NetworkXAdapter(GraphDBInterface):
logger.info("Graph deleted successfully.")
except Exception as error:
logger.error("Failed to delete graph: %s", error)
async def get_filtered_graph_data(self, attribute_filters: List[Dict[str, List[Union[str, int]]]]):
    """
    Fetches nodes and relationships filtered by specified attribute values.

    Only the first dictionary in ``attribute_filters`` is applied. A node
    matches when, for every attribute in that dictionary, the node's value is
    one of the listed values. An edge is kept only when both its source and
    its target node match.

    Args:
        attribute_filters (list of dict): A list of dictionaries where keys are attributes and values are lists of values to filter on.
            Example: [{"community": ["1", "2"]}]
            An empty list or an empty dictionary applies no filtering.

    Returns:
        tuple: A tuple containing two lists:
            - Nodes: List of tuples (node_id, node_properties).
            - Edges: List of tuples (source_id, target_id, relationship_type, edge_properties).
    """
    # An empty filter list means "no filtering" instead of raising IndexError.
    conditions = list(attribute_filters[0].items()) if attribute_filters else []

    def matches(properties):
        return all(properties.get(attr) in values for attr, values in conditions)

    # Filter nodes
    filtered_nodes = [
        (node, data) for node, data in self.graph.nodes(data=True) if matches(data)
    ]

    # Reuse the node result so the predicate is not re-evaluated twice per edge.
    matching_ids = {node for node, _ in filtered_nodes}

    # Filter edges where both source and target nodes satisfy the filters
    filtered_edges = [
        (source, target, data.get('relationship_type', 'UNKNOWN'), data)
        for source, target, data in self.graph.edges(data=True)
        if source in matching_ids and target in matching_ids
    ]

    return filtered_nodes, filtered_edges

View file

@ -52,13 +52,17 @@ class CogneeGraph(CogneeAbstractGraph):
edge_properties_to_project: List[str],
directed = True,
node_dimension = 1,
edge_dimension = 1) -> None:
edge_dimension = 1,
memory_fragment_filter = List[Dict[str, List[Union[str, int]]]]) -> None:
if node_dimension < 1 or edge_dimension < 1:
raise ValueError("Dimensions must be positive integers")
try:
nodes_data, edges_data = await adapter.get_graph_data()
if len(memory_fragment_filter) == 0:
nodes_data, edges_data = await adapter.get_graph_data()
else:
nodes_data, edges_data = await adapter.get_filtered_graph_data(attribute_filters = memory_fragment_filter)
if not nodes_data:
raise ValueError("No node data retrieved from the database.")