fix memgraph_impl.py according to test_graph_storage.py

This commit is contained in:
DavIvek 2025-06-27 15:35:20 +02:00
parent 9aaa7d2dd3
commit c0a3638d01
2 changed files with 52 additions and 45 deletions

View file

@ -594,8 +594,8 @@ class MemgraphStorage(BaseGraphStorage):
node_dict = dict(node)
node_dict["id"] = node_dict.get("entity_id")
nodes.append(node_dict)
await result.consume()
return nodes
await result.consume()
return nodes
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""Get all edges that are associated with the given chunk_ids.
@ -614,7 +614,12 @@ class MemgraphStorage(BaseGraphStorage):
UNWIND $chunk_ids AS chunk_id
MATCH (a:base)-[r]-(b:base)
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
// Ensure we only return each unique edge once by ordering the source and target
WITH a, b, r,
CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source,
CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target
RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties
"""
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
edges = []
@ -650,10 +655,10 @@ class MemgraphStorage(BaseGraphStorage):
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
try:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
if node_label == "*":
count_query = "MATCH (n) RETURN count(n) as total"
count_result = None
@ -736,45 +741,46 @@ class MemgraphStorage(BaseGraphStorage):
logger.debug(f"No record found for node {node_label}")
return result
for node_info in record["node_info"]:
node = node_info["node"]
node_id = node.id
if node_id not in seen_nodes:
seen_nodes.add(node_id)
result.nodes.append(
KnowledgeGraphNode(
id=f"{node_id}",
labels=[node.get("entity_id")],
properties=dict(node),
)
)
for rel in record["relationships"]:
edge_id = rel.id
if edge_id not in seen_edges:
seen_edges.add(edge_id)
start = rel.start_node
end = rel.end_node
result.edges.append(
KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{start.id}",
target=f"{end.id}",
properties=dict(rel),
)
)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
return result
finally:
if result_set:
await result_set.consume()
except Exception as e:
logger.error(f"Error getting knowledge graph: {str(e)}")
return result
if record:
for node_info in record["node_info"]:
node = node_info["node"]
node_id = node.id
if node_id not in seen_nodes:
seen_nodes.add(node_id)
result.nodes.append(
KnowledgeGraphNode(
id=f"{node_id}",
labels=[node.get("entity_id")],
properties=dict(node),
)
)
for rel in record["relationships"]:
edge_id = rel.id
if edge_id not in seen_edges:
seen_edges.add(edge_id)
start = rel.start_node
end = rel.end_node
result.edges.append(
KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{start.id}",
target=f"{end.id}",
properties=dict(rel),
)
)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
return result
except Exception as e:
logger.error(f"Error getting knowledge graph: {str(e)}")
return result

View file

@ -9,6 +9,7 @@
- NetworkXStorage
- Neo4JStorage
- PGGraphStorage
- MemgraphStorage
"""
import asyncio