feat: Adds graph node filtering by feature
This commit is contained in:
parent
d9eec77f18
commit
0101d43b8d
3 changed files with 90 additions and 4 deletions
|
|
@ -2,7 +2,7 @@
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import Optional, Any, List, Dict
|
from typing import Optional, Any, List, Dict, Union
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from neo4j import AsyncSession
|
from neo4j import AsyncSession
|
||||||
|
|
@ -432,3 +432,49 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
) for record in result]
|
) for record in result]
|
||||||
|
|
||||||
return (nodes, edges)
|
return (nodes, edges)
|
||||||
|
|
||||||
|
async def get_filtered_graph_data(self, attribute_filters):
    """
    Fetches nodes and relationships filtered by specified attribute values.

    Args:
        attribute_filters (list of dict): A list of dictionaries where keys are attributes
            and values are lists of values to filter on. Only the first dictionary in the
            list is consulted. An empty list or an empty dictionary applies no filtering.
            Example: [{"community": ["1", "2"]}]

    Returns:
        tuple: A tuple containing two lists:
            - nodes: tuples of (node_id, node_properties).
            - edges: tuples of (source_id, target_id, relationship_type, edge_properties);
              an edge is returned only when both of its endpoints satisfy the filters.
    """
    filters = attribute_filters[0] if attribute_filters else {}

    def to_cypher_literal(value):
        # Escape backslashes first, then quotes, so a value such as "o'neil"
        # stays a single Cypher string literal instead of terminating it early.
        if isinstance(value, str):
            escaped = value.replace("\\", "\\\\").replace("'", "\\'")
            return f"'{escaped}'"
        return str(value)

    def conditions_for(alias):
        # Build the predicate separately for each alias. Deriving the target-node
        # predicate by textual replacement of "n." would also rewrite any "n."
        # that happens to occur inside a quoted filter value.
        # NOTE(review): attribute names are interpolated verbatim; they must come
        # from trusted code, not from end-user input.
        parts = []
        for attribute, values in filters.items():
            literals = ", ".join(to_cypher_literal(value) for value in values)
            parts.append(f"{alias}.{attribute} IN [{literals}]")
        return " AND ".join(parts)

    source_conditions = conditions_for("n")
    target_conditions = conditions_for("m")

    # Omit the WHERE keyword entirely when there is nothing to filter on;
    # an empty "WHERE" is not valid Cypher.
    node_where = f"WHERE {source_conditions}" if source_conditions else ""
    edge_conditions = " AND ".join(
        part for part in (source_conditions, target_conditions) if part
    )
    edge_where = f"WHERE {edge_conditions}" if edge_conditions else ""

    query_nodes = f"""
    MATCH (n)
    {node_where}
    RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties
    """
    result_nodes = await self.query(query_nodes)

    nodes = [(
        record["id"],
        record["properties"],
    ) for record in result_nodes]

    query_edges = f"""
    MATCH (n)-[r]->(m)
    {edge_where}
    RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
    """
    result_edges = await self.query(query_edges)

    edges = [(
        record["source"],
        record["target"],
        record["type"],
        record["properties"],
    ) for record in result_edges]

    return (nodes, edges)
|
||||||
|
|
@ -6,7 +6,7 @@ import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from re import A
|
from re import A
|
||||||
from typing import Dict, Any, List
|
from typing import Dict, Any, List, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import aiofiles.os as aiofiles_os
|
import aiofiles.os as aiofiles_os
|
||||||
|
|
@ -301,3 +301,39 @@ class NetworkXAdapter(GraphDBInterface):
|
||||||
logger.info("Graph deleted successfully.")
|
logger.info("Graph deleted successfully.")
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
logger.error("Failed to delete graph: %s", error)
|
logger.error("Failed to delete graph: %s", error)
|
||||||
|
|
||||||
|
async def get_filtered_graph_data(self, attribute_filters: List[Dict[str, List[Union[str, int]]]]):
    """
    Fetches nodes and relationships filtered by specified attribute values.

    Args:
        attribute_filters (list of dict): A list of dictionaries where keys are attributes
            and values are lists of values to filter on. Only the first dictionary in the
            list is consulted. An empty list or an empty dictionary keeps every node and edge.
            Example: [{"community": ["1", "2"]}]

    Returns:
        tuple: A tuple containing two lists:
            - Nodes: List of tuples (node_id, node_properties).
            - Edges: List of tuples (source_id, target_id, relationship_type, edge_properties).
              An edge is kept only when both of its endpoints pass the filters; edges
              without a 'relationship_type' attribute are reported as 'UNKNOWN'.
    """
    criteria = list(attribute_filters[0].items()) if attribute_filters else []

    def satisfies_filters(properties):
        # A node passes when every filtered attribute holds one of the allowed values.
        return all(properties.get(attribute) in allowed for attribute, allowed in criteria)

    filtered_nodes = [
        (node, data)
        for node, data in self.graph.nodes(data=True)
        if satisfies_filters(data)
    ]

    # Evaluate the predicate once per node and reuse the outcome for the edges,
    # instead of re-checking both endpoints' attributes for every edge.
    kept_node_ids = {node for node, _ in filtered_nodes}

    filtered_edges = [
        (source, target, data.get('relationship_type', 'UNKNOWN'), data)
        for source, target, data in self.graph.edges(data=True)
        if source in kept_node_ids and target in kept_node_ids
    ]

    return filtered_nodes, filtered_edges
|
||||||
|
|
@ -52,13 +52,17 @@ class CogneeGraph(CogneeAbstractGraph):
|
||||||
edge_properties_to_project: List[str],
|
edge_properties_to_project: List[str],
|
||||||
directed = True,
|
directed = True,
|
||||||
node_dimension = 1,
|
node_dimension = 1,
|
||||||
edge_dimension = 1) -> None:
|
edge_dimension = 1,
|
||||||
|
memory_fragment_filter = List[Dict[str, List[Union[str, int]]]]) -> None:
|
||||||
|
|
||||||
if node_dimension < 1 or edge_dimension < 1:
|
if node_dimension < 1 or edge_dimension < 1:
|
||||||
raise ValueError("Dimensions must be positive integers")
|
raise ValueError("Dimensions must be positive integers")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
nodes_data, edges_data = await adapter.get_graph_data()
|
if len(memory_fragment_filter) == 0:
|
||||||
|
nodes_data, edges_data = await adapter.get_graph_data()
|
||||||
|
else:
|
||||||
|
nodes_data, edges_data = await adapter.get_filtered_graph_data(attribute_filters = memory_fragment_filter)
|
||||||
|
|
||||||
if not nodes_data:
|
if not nodes_data:
|
||||||
raise ValueError("No node data retrieved from the database.")
|
raise ValueError("No node data retrieved from the database.")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue