fix: Fixes collection search limit in brute force triplet search (#814)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
parent
34b95b687c
commit
a78fec3a91
4 changed files with 39 additions and 36 deletions
8
.github/workflows/test_memgraph.yml
vendored
8
.github/workflows/test_memgraph.yml
vendored
|
|
@ -1,9 +1,9 @@
|
||||||
name: test | memgraph
|
name: test | memgraph
|
||||||
|
|
||||||
on:
|
# on:
|
||||||
workflow_dispatch:
|
# workflow_dispatch:
|
||||||
pull_request:
|
# pull_request:
|
||||||
types: [labeled, synchronize]
|
# types: [labeled, synchronize]
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ from cognee.modules.storage.utils import JSONEncoder
|
||||||
|
|
||||||
logger = get_logger("MemgraphAdapter", level=ERROR)
|
logger = get_logger("MemgraphAdapter", level=ERROR)
|
||||||
|
|
||||||
|
|
||||||
class MemgraphAdapter(GraphDBInterface):
|
class MemgraphAdapter(GraphDBInterface):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -34,7 +35,7 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
async def get_session(self) -> AsyncSession:
|
async def get_session(self) -> AsyncSession:
|
||||||
async with self.driver.session() as session:
|
async with self.driver.session() as session:
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
async def query(
|
async def query(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
|
@ -48,7 +49,7 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
except Neo4jError as error:
|
except Neo4jError as error:
|
||||||
logger.error("Memgraph query error: %s", error, exc_info=True)
|
logger.error("Memgraph query error: %s", error, exc_info=True)
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
results = await self.query(
|
results = await self.query(
|
||||||
"""
|
"""
|
||||||
|
|
@ -59,7 +60,7 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
{"node_id": node_id},
|
{"node_id": node_id},
|
||||||
)
|
)
|
||||||
return results[0]["node_exists"] if len(results) > 0 else False
|
return results[0]["node_exists"] if len(results) > 0 else False
|
||||||
|
|
||||||
async def add_node(self, node: DataPoint):
|
async def add_node(self, node: DataPoint):
|
||||||
serialized_properties = self.serialize_properties(node.model_dump())
|
serialized_properties = self.serialize_properties(node.model_dump())
|
||||||
|
|
||||||
|
|
@ -102,7 +103,7 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
results = await self.extract_nodes([node_id])
|
results = await self.extract_nodes([node_id])
|
||||||
|
|
||||||
return results[0] if len(results) > 0 else None
|
return results[0] if len(results) > 0 else None
|
||||||
|
|
||||||
async def extract_nodes(self, node_ids: List[str]):
|
async def extract_nodes(self, node_ids: List[str]):
|
||||||
query = """
|
query = """
|
||||||
UNWIND $node_ids AS id
|
UNWIND $node_ids AS id
|
||||||
|
|
@ -114,15 +115,15 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
results = await self.query(query, params)
|
results = await self.query(query, params)
|
||||||
|
|
||||||
return [result["node"] for result in results]
|
return [result["node"] for result in results]
|
||||||
|
|
||||||
async def delete_node(self, node_id: str):
|
async def delete_node(self, node_id: str):
|
||||||
sanitized_id = node_id.replace(":", "_")
|
sanitized_id = node_id.replace(":", "_")
|
||||||
|
|
||||||
query = "MATCH (node: {{id: $node_id}}) DETACH DELETE node"
|
query = "MATCH (node: {{id: $node_id}}) DETACH DELETE node"
|
||||||
params = {"node_id": sanitized_id}
|
params = {"node_id": sanitized_id}
|
||||||
|
|
||||||
return await self.query(query, params)
|
return await self.query(query, params)
|
||||||
|
|
||||||
async def delete_nodes(self, node_ids: list[str]) -> None:
|
async def delete_nodes(self, node_ids: list[str]) -> None:
|
||||||
query = """
|
query = """
|
||||||
UNWIND $node_ids AS id
|
UNWIND $node_ids AS id
|
||||||
|
|
@ -132,7 +133,7 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
params = {"node_ids": node_ids}
|
params = {"node_ids": node_ids}
|
||||||
|
|
||||||
return await self.query(query, params)
|
return await self.query(query, params)
|
||||||
|
|
||||||
async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
|
async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
|
||||||
query = """
|
query = """
|
||||||
MATCH (from_node)-[relationship]->(to_node)
|
MATCH (from_node)-[relationship]->(to_node)
|
||||||
|
|
@ -145,10 +146,10 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
"to_node_id": str(to_node),
|
"to_node_id": str(to_node),
|
||||||
"edge_label": edge_label,
|
"edge_label": edge_label,
|
||||||
}
|
}
|
||||||
|
|
||||||
records = await self.query(query, params)
|
records = await self.query(query, params)
|
||||||
return records[0]["edge_exists"] if records else False
|
return records[0]["edge_exists"] if records else False
|
||||||
|
|
||||||
async def has_edges(self, edges):
|
async def has_edges(self, edges):
|
||||||
query = """
|
query = """
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
|
|
@ -174,7 +175,7 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
except Neo4jError as error:
|
except Neo4jError as error:
|
||||||
logger.error("Memgraph query error: %s", error, exc_info=True)
|
logger.error("Memgraph query error: %s", error, exc_info=True)
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
async def add_edge(
|
async def add_edge(
|
||||||
self,
|
self,
|
||||||
from_node: UUID,
|
from_node: UUID,
|
||||||
|
|
@ -203,7 +204,7 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
}
|
}
|
||||||
|
|
||||||
return await self.query(query, params)
|
return await self.query(query, params)
|
||||||
|
|
||||||
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
|
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
|
||||||
query = """
|
query = """
|
||||||
UNWIND $edges AS edge
|
UNWIND $edges AS edge
|
||||||
|
|
@ -217,7 +218,7 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
target_node_id: edge.to_node
|
target_node_id: edge.to_node
|
||||||
},
|
},
|
||||||
edge.properties,
|
edge.properties,
|
||||||
to_node,
|
to_node,
|
||||||
{}
|
{}
|
||||||
) YIELD rel
|
) YIELD rel
|
||||||
RETURN rel"""
|
RETURN rel"""
|
||||||
|
|
@ -242,7 +243,7 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
except Neo4jError as error:
|
except Neo4jError as error:
|
||||||
logger.error("Memgraph query error: %s", error, exc_info=True)
|
logger.error("Memgraph query error: %s", error, exc_info=True)
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
async def get_edges(self, node_id: str):
|
async def get_edges(self, node_id: str):
|
||||||
query = """
|
query = """
|
||||||
MATCH (n {id: $node_id})-[r]-(m)
|
MATCH (n {id: $node_id})-[r]-(m)
|
||||||
|
|
@ -255,7 +256,7 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
(result["n"]["id"], result["m"]["id"], {"relationship_name": result["r"][1]})
|
(result["n"]["id"], result["m"]["id"], {"relationship_name": result["r"][1]})
|
||||||
for result in results
|
for result in results
|
||||||
]
|
]
|
||||||
|
|
||||||
async def get_disconnected_nodes(self) -> list[str]:
|
async def get_disconnected_nodes(self) -> list[str]:
|
||||||
query = """
|
query = """
|
||||||
// Step 1: Collect all nodes
|
// Step 1: Collect all nodes
|
||||||
|
|
@ -290,7 +291,7 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
|
|
||||||
results = await self.query(query)
|
results = await self.query(query)
|
||||||
return results[0]["ids"] if len(results) > 0 else []
|
return results[0]["ids"] if len(results) > 0 else []
|
||||||
|
|
||||||
async def get_predecessors(self, node_id: str, edge_label: str = None) -> list[str]:
|
async def get_predecessors(self, node_id: str, edge_label: str = None) -> list[str]:
|
||||||
if edge_label is not None:
|
if edge_label is not None:
|
||||||
query = """
|
query = """
|
||||||
|
|
@ -323,7 +324,7 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
|
|
||||||
return [result["predecessor"] for result in results]
|
return [result["predecessor"] for result in results]
|
||||||
|
|
||||||
async def get_successors(self, node_id: str, edge_label: str = None) -> list[str]:
|
async def get_successors(self, node_id: str, edge_label: str = None) -> list[str]:
|
||||||
if edge_label is not None:
|
if edge_label is not None:
|
||||||
query = """
|
query = """
|
||||||
|
|
@ -356,14 +357,14 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
|
|
||||||
return [result["successor"] for result in results]
|
return [result["successor"] for result in results]
|
||||||
|
|
||||||
async def get_neighbours(self, node_id: str) -> List[Dict[str, Any]]:
|
async def get_neighbours(self, node_id: str) -> List[Dict[str, Any]]:
|
||||||
predecessors, successors = await asyncio.gather(
|
predecessors, successors = await asyncio.gather(
|
||||||
self.get_predecessors(node_id), self.get_successors(node_id)
|
self.get_predecessors(node_id), self.get_successors(node_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
return predecessors + successors
|
return predecessors + successors
|
||||||
|
|
||||||
async def get_connections(self, node_id: UUID) -> list:
|
async def get_connections(self, node_id: UUID) -> list:
|
||||||
predecessors_query = """
|
predecessors_query = """
|
||||||
MATCH (node)<-[relation]-(neighbour)
|
MATCH (node)<-[relation]-(neighbour)
|
||||||
|
|
@ -392,7 +393,7 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
connections.append((neighbour[0], {"relationship_name": neighbour[1]}, neighbour[2]))
|
connections.append((neighbour[0], {"relationship_name": neighbour[1]}, neighbour[2]))
|
||||||
|
|
||||||
return connections
|
return connections
|
||||||
|
|
||||||
async def remove_connection_to_predecessors_of(
|
async def remove_connection_to_predecessors_of(
|
||||||
self, node_ids: list[str], edge_label: str
|
self, node_ids: list[str], edge_label: str
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -406,7 +407,7 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
params = {"node_ids": node_ids, "edge_label": edge_label}
|
params = {"node_ids": node_ids, "edge_label": edge_label}
|
||||||
|
|
||||||
return await self.query(query, params)
|
return await self.query(query, params)
|
||||||
|
|
||||||
async def remove_connection_to_successors_of(
|
async def remove_connection_to_successors_of(
|
||||||
self, node_ids: list[str], edge_label: str
|
self, node_ids: list[str], edge_label: str
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -419,13 +420,13 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
params = {"node_ids": node_ids}
|
params = {"node_ids": node_ids}
|
||||||
|
|
||||||
return await self.query(query, params)
|
return await self.query(query, params)
|
||||||
|
|
||||||
async def delete_graph(self):
|
async def delete_graph(self):
|
||||||
query = """MATCH (node)
|
query = """MATCH (node)
|
||||||
DETACH DELETE node;"""
|
DETACH DELETE node;"""
|
||||||
|
|
||||||
return await self.query(query)
|
return await self.query(query)
|
||||||
|
|
||||||
def serialize_properties(self, properties=dict()):
|
def serialize_properties(self, properties=dict()):
|
||||||
serialized_properties = {}
|
serialized_properties = {}
|
||||||
|
|
||||||
|
|
@ -441,7 +442,7 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
serialized_properties[property_key] = property_value
|
serialized_properties[property_key] = property_value
|
||||||
|
|
||||||
return serialized_properties
|
return serialized_properties
|
||||||
|
|
||||||
async def get_model_independent_graph_data(self):
|
async def get_model_independent_graph_data(self):
|
||||||
query_nodes = "MATCH (n) RETURN collect(n) AS nodes"
|
query_nodes = "MATCH (n) RETURN collect(n) AS nodes"
|
||||||
nodes = await self.query(query_nodes)
|
nodes = await self.query(query_nodes)
|
||||||
|
|
@ -450,7 +451,7 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
edges = await self.query(query_edges)
|
edges = await self.query(query_edges)
|
||||||
|
|
||||||
return (nodes, edges)
|
return (nodes, edges)
|
||||||
|
|
||||||
async def get_graph_data(self):
|
async def get_graph_data(self):
|
||||||
query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
|
query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
|
||||||
|
|
||||||
|
|
@ -480,7 +481,7 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
]
|
]
|
||||||
|
|
||||||
return (nodes, edges)
|
return (nodes, edges)
|
||||||
|
|
||||||
async def get_filtered_graph_data(self, attribute_filters):
|
async def get_filtered_graph_data(self, attribute_filters):
|
||||||
"""
|
"""
|
||||||
Fetches nodes and relationships filtered by specified attribute values.
|
Fetches nodes and relationships filtered by specified attribute values.
|
||||||
|
|
@ -536,7 +537,7 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
return (nodes, edges)
|
return (nodes, edges)
|
||||||
|
|
||||||
async def get_node_labels_string(self):
|
async def get_node_labels_string(self):
|
||||||
node_labels_query = f"""
|
node_labels_query = """
|
||||||
MATCH (n)
|
MATCH (n)
|
||||||
WITH DISTINCT labels(n) AS labelList
|
WITH DISTINCT labels(n) AS labelList
|
||||||
UNWIND labelList AS label
|
UNWIND labelList AS label
|
||||||
|
|
@ -552,7 +553,9 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
return node_labels_str
|
return node_labels_str
|
||||||
|
|
||||||
async def get_relationship_labels_string(self):
|
async def get_relationship_labels_string(self):
|
||||||
relationship_types_query = "MATCH ()-[r]->() RETURN collect(DISTINCT type(r)) AS relationships;"
|
relationship_types_query = (
|
||||||
|
"MATCH ()-[r]->() RETURN collect(DISTINCT type(r)) AS relationships;"
|
||||||
|
)
|
||||||
relationship_types_result = await self.query(relationship_types_query)
|
relationship_types_result = await self.query(relationship_types_query)
|
||||||
relationship_types = (
|
relationship_types = (
|
||||||
relationship_types_result[0]["relationships"] if relationship_types_result else []
|
relationship_types_result[0]["relationships"] if relationship_types_result else []
|
||||||
|
|
@ -643,7 +646,7 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
WITH n, degree, COUNT(n2) AS triangle_count
|
WITH n, degree, COUNT(n2) AS triangle_count
|
||||||
|
|
||||||
// Step 4: Compute local clustering coefficient
|
// Step 4: Compute local clustering coefficient
|
||||||
WITH n, degree,
|
WITH n, degree,
|
||||||
CASE WHEN degree <= 1 THEN 0.0
|
CASE WHEN degree <= 1 THEN 0.0
|
||||||
ELSE (1.0 * triangle_count) / (degree * (degree - 1) / 2.0)
|
ELSE (1.0 * triangle_count) / (degree * (degree - 1) / 2.0)
|
||||||
END AS local_cc
|
END AS local_cc
|
||||||
|
|
@ -684,4 +687,4 @@ class MemgraphAdapter(GraphDBInterface):
|
||||||
"diameter": -1,
|
"diameter": -1,
|
||||||
"avg_shortest_path_length": -1,
|
"avg_shortest_path_length": -1,
|
||||||
"avg_clustering": -1,
|
"avg_clustering": -1,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -146,7 +146,7 @@ async def brute_force_search(
|
||||||
async def search_in_collection(collection_name: str):
|
async def search_in_collection(collection_name: str):
|
||||||
try:
|
try:
|
||||||
return await vector_engine.search(
|
return await vector_engine.search(
|
||||||
collection_name=collection_name, query_text=query, limit=top_k
|
collection_name=collection_name, query_text=query, limit=0
|
||||||
)
|
)
|
||||||
except CollectionNotFoundError:
|
except CollectionNotFoundError:
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -95,7 +95,7 @@ async def main():
|
||||||
|
|
||||||
await cognee.prune.prune_system(metadata=True)
|
await cognee.prune.prune_system(metadata=True)
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
nodes, edges = await graph_engine.get_graph_data()
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
assert len(nodes) == 0 and len(edges) == 0, "Memgraph graph database is not empty"
|
assert len(nodes) == 0 and len(edges) == 0, "Memgraph graph database is not empty"
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue