fix memgraph_impl.py according to test_graph_storage.py
This commit is contained in:
parent
9aaa7d2dd3
commit
c0a3638d01
2 changed files with 52 additions and 45 deletions
|
|
@ -594,8 +594,8 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
node_dict = dict(node)
|
||||
node_dict["id"] = node_dict.get("entity_id")
|
||||
nodes.append(node_dict)
|
||||
await result.consume()
|
||||
return nodes
|
||||
await result.consume()
|
||||
return nodes
|
||||
|
||||
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||
"""Get all edges that are associated with the given chunk_ids.
|
||||
|
|
@ -614,7 +614,12 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
UNWIND $chunk_ids AS chunk_id
|
||||
MATCH (a:base)-[r]-(b:base)
|
||||
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
|
||||
RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
|
||||
WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
|
||||
// Ensure we only return each unique edge once by ordering the source and target
|
||||
WITH a, b, r,
|
||||
CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source,
|
||||
CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target
|
||||
RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties
|
||||
"""
|
||||
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
|
||||
edges = []
|
||||
|
|
@ -650,10 +655,10 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
result = KnowledgeGraph()
|
||||
seen_nodes = set()
|
||||
seen_edges = set()
|
||||
try:
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
try:
|
||||
if node_label == "*":
|
||||
count_query = "MATCH (n) RETURN count(n) as total"
|
||||
count_result = None
|
||||
|
|
@ -736,45 +741,46 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
logger.debug(f"No record found for node {node_label}")
|
||||
return result
|
||||
|
||||
for node_info in record["node_info"]:
|
||||
node = node_info["node"]
|
||||
node_id = node.id
|
||||
if node_id not in seen_nodes:
|
||||
seen_nodes.add(node_id)
|
||||
result.nodes.append(
|
||||
KnowledgeGraphNode(
|
||||
id=f"{node_id}",
|
||||
labels=[node.get("entity_id")],
|
||||
properties=dict(node),
|
||||
)
|
||||
)
|
||||
|
||||
for rel in record["relationships"]:
|
||||
edge_id = rel.id
|
||||
if edge_id not in seen_edges:
|
||||
seen_edges.add(edge_id)
|
||||
start = rel.start_node
|
||||
end = rel.end_node
|
||||
result.edges.append(
|
||||
KnowledgeGraphEdge(
|
||||
id=f"{edge_id}",
|
||||
type=rel.type,
|
||||
source=f"{start.id}",
|
||||
target=f"{end.id}",
|
||||
properties=dict(rel),
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
finally:
|
||||
if result_set:
|
||||
await result_set.consume()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting knowledge graph: {str(e)}")
|
||||
return result
|
||||
if record:
|
||||
for node_info in record["node_info"]:
|
||||
node = node_info["node"]
|
||||
node_id = node.id
|
||||
if node_id not in seen_nodes:
|
||||
seen_nodes.add(node_id)
|
||||
result.nodes.append(
|
||||
KnowledgeGraphNode(
|
||||
id=f"{node_id}",
|
||||
labels=[node.get("entity_id")],
|
||||
properties=dict(node),
|
||||
)
|
||||
)
|
||||
|
||||
for rel in record["relationships"]:
|
||||
edge_id = rel.id
|
||||
if edge_id not in seen_edges:
|
||||
seen_edges.add(edge_id)
|
||||
start = rel.start_node
|
||||
end = rel.end_node
|
||||
result.edges.append(
|
||||
KnowledgeGraphEdge(
|
||||
id=f"{edge_id}",
|
||||
type=rel.type,
|
||||
source=f"{start.id}",
|
||||
target=f"{end.id}",
|
||||
properties=dict(rel),
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting knowledge graph: {str(e)}")
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
- NetworkXStorage
|
||||
- Neo4JStorage
|
||||
- PGGraphStorage
|
||||
- MemgraphStorage
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue