fix memgraph_impl.py according to test_graph_storage.py

This commit is contained in:
DavIvek 2025-06-27 15:35:20 +02:00
parent 9aaa7d2dd3
commit c0a3638d01
2 changed files with 52 additions and 45 deletions

View file

@ -614,7 +614,12 @@ class MemgraphStorage(BaseGraphStorage):
UNWIND $chunk_ids AS chunk_id UNWIND $chunk_ids AS chunk_id
MATCH (a:base)-[r]-(b:base) MATCH (a:base)-[r]-(b:base)
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep) WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
// Ensure we only return each unique edge once by ordering the source and target
WITH a, b, r,
CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source,
CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target
RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties
""" """
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
edges = [] edges = []
@ -650,10 +655,10 @@ class MemgraphStorage(BaseGraphStorage):
result = KnowledgeGraph() result = KnowledgeGraph()
seen_nodes = set() seen_nodes = set()
seen_edges = set() seen_edges = set()
try:
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
try:
if node_label == "*": if node_label == "*":
count_query = "MATCH (n) RETURN count(n) as total" count_query = "MATCH (n) RETURN count(n) as total"
count_result = None count_result = None
@ -736,6 +741,11 @@ class MemgraphStorage(BaseGraphStorage):
logger.debug(f"No record found for node {node_label}") logger.debug(f"No record found for node {node_label}")
return result return result
finally:
if result_set:
await result_set.consume()
if record:
for node_info in record["node_info"]: for node_info in record["node_info"]:
node = node_info["node"] node = node_info["node"]
node_id = node.id node_id = node.id
@ -771,10 +781,6 @@ class MemgraphStorage(BaseGraphStorage):
return result return result
finally:
if result_set:
await result_set.consume()
except Exception as e: except Exception as e:
logger.error(f"Error getting knowledge graph: {str(e)}") logger.error(f"Error getting knowledge graph: {str(e)}")
return result return result

View file

@ -9,6 +9,7 @@
- NetworkXStorage - NetworkXStorage
- Neo4JStorage - Neo4JStorage
- PGGraphStorage - PGGraphStorage
- MemgraphStorage
""" """
import asyncio import asyncio