feat: Adds graph node filtering by feature
This commit is contained in:
parent
d9eec77f18
commit
0101d43b8d
3 changed files with 90 additions and 4 deletions
|
|
@ -2,7 +2,7 @@
|
|||
import logging
|
||||
import asyncio
|
||||
from textwrap import dedent
|
||||
from typing import Optional, Any, List, Dict
|
||||
from typing import Optional, Any, List, Dict, Union
|
||||
from contextlib import asynccontextmanager
|
||||
from uuid import UUID
|
||||
from neo4j import AsyncSession
|
||||
|
|
@ -432,3 +432,49 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
) for record in result]
|
||||
|
||||
return (nodes, edges)
|
||||
|
||||
async def get_filtered_graph_data(self, attribute_filters: List[Dict[str, List[Union[str, int]]]]):
    """
    Fetches nodes and relationships filtered by specified attribute values.

    Only the first dictionary in ``attribute_filters`` is consulted. A node
    matches when, for every attribute in that dictionary, the node's value is
    one of the listed values. An edge is returned only when both of its
    endpoints match. An empty list or an empty dictionary applies no filter.

    Args:
        attribute_filters (list of dict): A list of dictionaries where keys are attributes and values are lists of values to filter on.
            Example: [{"community": ["1", "2"]}]

    Returns:
        tuple: A tuple containing two lists:
            - Nodes: List of tuples (node_id, node_properties).
            - Edges: List of tuples (source_id, target_id, relationship_type, edge_properties).
    """
    def cypher_literal(value):
        # Render a filter value as a Cypher literal. Strings are quoted with
        # backslashes and single quotes escaped so a value can never close
        # the literal early; everything else keeps its plain str() form.
        if isinstance(value, str):
            escaped = value.replace("\\", "\\\\").replace("'", "\\'")
            return f"'{escaped}'"
        return str(value)

    def property_ref(alias, attribute):
        # Backtick-quote the property name (doubling embedded backticks) so
        # names with spaces or punctuation stay a single identifier.
        quoted = str(attribute).replace("`", "``")
        return f"{alias}.`{quoted}`"

    filters = attribute_filters[0] if attribute_filters else {}
    conditions = [
        (attribute, ", ".join(cypher_literal(value) for value in values))
        for attribute, values in filters.items()
    ]

    def where_for(*aliases):
        # Build the clause per alias instead of text-replacing "n." with
        # "m.", which would also rewrite filter values containing "n.".
        clauses = [
            f"{property_ref(alias, attribute)} IN [{rendered}]"
            for alias in aliases
            for attribute, rendered in conditions
        ]
        if not clauses:
            return ""
        return "WHERE " + " AND ".join(clauses)

    node_where = where_for("n")
    query_nodes = f"""
    MATCH (n)
    {node_where}
    RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties
    """
    result_nodes = await self.query(query_nodes)

    nodes = [(record["id"], record["properties"]) for record in result_nodes]

    edge_where = where_for("n", "m")
    query_edges = f"""
    MATCH (n)-[r]->(m)
    {edge_where}
    RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
    """
    result_edges = await self.query(query_edges)

    edges = [
        (record["source"], record["target"], record["type"], record["properties"])
        for record in result_edges
    ]

    return (nodes, edges)
|
||||
|
|
@ -6,7 +6,7 @@ import json
|
|||
import asyncio
|
||||
import logging
|
||||
from re import A
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, Any, List, Union
|
||||
from uuid import UUID
|
||||
import aiofiles
|
||||
import aiofiles.os as aiofiles_os
|
||||
|
|
@ -301,3 +301,39 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
logger.info("Graph deleted successfully.")
|
||||
except Exception as error:
|
||||
logger.error("Failed to delete graph: %s", error)
|
||||
|
||||
async def get_filtered_graph_data(self, attribute_filters: List[Dict[str, List[Union[str, int]]]]):
    """
    Fetches nodes and relationships filtered by specified attribute values.

    Only the first dictionary in ``attribute_filters`` is consulted. A node
    matches when, for every attribute in that dictionary, the node's value is
    one of the listed values. An empty list or an empty dictionary applies no
    filter, so every node and edge is returned.

    Args:
        attribute_filters (list of dict): A list of dictionaries where keys are attributes and values are lists of values to filter on.
            Example: [{"community": ["1", "2"]}]

    Returns:
        tuple: A tuple containing two lists:
            - Nodes: List of tuples (node_id, node_properties).
            - Edges: List of tuples (source_id, target_id, relationship_type, edge_properties).
    """
    # Guard the empty list so it behaves like "no filter" rather than
    # raising IndexError.
    filters = list(attribute_filters[0].items()) if attribute_filters else []

    def matches(properties):
        # A node passes when each filtered attribute holds an allowed value.
        return all(properties.get(attribute) in allowed for attribute, allowed in filters)

    # Filter nodes
    filtered_nodes = [
        (node, data) for node, data in self.graph.nodes(data=True)
        if matches(data)
    ]

    # Evaluate each node once; edges are then checked by id membership
    # instead of re-running the predicate for both endpoints of every edge.
    matching_ids = {node for node, _ in filtered_nodes}

    # Keep edges where both source and target nodes satisfy the filters
    filtered_edges = [
        (source, target, data.get('relationship_type', 'UNKNOWN'), data)
        for source, target, data in self.graph.edges(data=True)
        if source in matching_ids and target in matching_ids
    ]

    return filtered_nodes, filtered_edges
|
||||
|
|
@ -52,13 +52,17 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
edge_properties_to_project: List[str],
|
||||
directed = True,
|
||||
node_dimension = 1,
|
||||
edge_dimension = 1) -> None:
|
||||
edge_dimension = 1,
|
||||
memory_fragment_filter = List[Dict[str, List[Union[str, int]]]]) -> None:
|
||||
|
||||
if node_dimension < 1 or edge_dimension < 1:
|
||||
raise ValueError("Dimensions must be positive integers")
|
||||
|
||||
try:
|
||||
nodes_data, edges_data = await adapter.get_graph_data()
|
||||
if len(memory_fragment_filter) == 0:
|
||||
nodes_data, edges_data = await adapter.get_graph_data()
|
||||
else:
|
||||
nodes_data, edges_data = await adapter.get_filtered_graph_data(attribute_filters = memory_fragment_filter)
|
||||
|
||||
if not nodes_data:
|
||||
raise ValueError("No node data retrieved from the database.")
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue