fix memgraph_impl.py according to test_graph_storage.py
This commit is contained in:
parent
9aaa7d2dd3
commit
c0a3638d01
2 changed files with 52 additions and 45 deletions
|
|
@ -594,8 +594,8 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
node_dict = dict(node)
|
node_dict = dict(node)
|
||||||
node_dict["id"] = node_dict.get("entity_id")
|
node_dict["id"] = node_dict.get("entity_id")
|
||||||
nodes.append(node_dict)
|
nodes.append(node_dict)
|
||||||
await result.consume()
|
await result.consume()
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
||||||
"""Get all edges that are associated with the given chunk_ids.
|
"""Get all edges that are associated with the given chunk_ids.
|
||||||
|
|
@ -614,7 +614,12 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
UNWIND $chunk_ids AS chunk_id
|
UNWIND $chunk_ids AS chunk_id
|
||||||
MATCH (a:base)-[r]-(b:base)
|
MATCH (a:base)-[r]-(b:base)
|
||||||
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
|
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
|
||||||
RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
|
WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
|
||||||
|
// Ensure we only return each unique edge once by ordering the source and target
|
||||||
|
WITH a, b, r,
|
||||||
|
CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source,
|
||||||
|
CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target
|
||||||
|
RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties
|
||||||
"""
|
"""
|
||||||
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
|
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
|
||||||
edges = []
|
edges = []
|
||||||
|
|
@ -650,10 +655,10 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
result = KnowledgeGraph()
|
result = KnowledgeGraph()
|
||||||
seen_nodes = set()
|
seen_nodes = set()
|
||||||
seen_edges = set()
|
seen_edges = set()
|
||||||
try:
|
async with self._driver.session(
|
||||||
async with self._driver.session(
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
) as session:
|
||||||
) as session:
|
try:
|
||||||
if node_label == "*":
|
if node_label == "*":
|
||||||
count_query = "MATCH (n) RETURN count(n) as total"
|
count_query = "MATCH (n) RETURN count(n) as total"
|
||||||
count_result = None
|
count_result = None
|
||||||
|
|
@ -736,45 +741,46 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
logger.debug(f"No record found for node {node_label}")
|
logger.debug(f"No record found for node {node_label}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
for node_info in record["node_info"]:
|
|
||||||
node = node_info["node"]
|
|
||||||
node_id = node.id
|
|
||||||
if node_id not in seen_nodes:
|
|
||||||
seen_nodes.add(node_id)
|
|
||||||
result.nodes.append(
|
|
||||||
KnowledgeGraphNode(
|
|
||||||
id=f"{node_id}",
|
|
||||||
labels=[node.get("entity_id")],
|
|
||||||
properties=dict(node),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
for rel in record["relationships"]:
|
|
||||||
edge_id = rel.id
|
|
||||||
if edge_id not in seen_edges:
|
|
||||||
seen_edges.add(edge_id)
|
|
||||||
start = rel.start_node
|
|
||||||
end = rel.end_node
|
|
||||||
result.edges.append(
|
|
||||||
KnowledgeGraphEdge(
|
|
||||||
id=f"{edge_id}",
|
|
||||||
type=rel.type,
|
|
||||||
source=f"{start.id}",
|
|
||||||
target=f"{end.id}",
|
|
||||||
properties=dict(rel),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
if result_set:
|
if result_set:
|
||||||
await result_set.consume()
|
await result_set.consume()
|
||||||
|
|
||||||
except Exception as e:
|
if record:
|
||||||
logger.error(f"Error getting knowledge graph: {str(e)}")
|
for node_info in record["node_info"]:
|
||||||
return result
|
node = node_info["node"]
|
||||||
|
node_id = node.id
|
||||||
|
if node_id not in seen_nodes:
|
||||||
|
seen_nodes.add(node_id)
|
||||||
|
result.nodes.append(
|
||||||
|
KnowledgeGraphNode(
|
||||||
|
id=f"{node_id}",
|
||||||
|
labels=[node.get("entity_id")],
|
||||||
|
properties=dict(node),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for rel in record["relationships"]:
|
||||||
|
edge_id = rel.id
|
||||||
|
if edge_id not in seen_edges:
|
||||||
|
seen_edges.add(edge_id)
|
||||||
|
start = rel.start_node
|
||||||
|
end = rel.end_node
|
||||||
|
result.edges.append(
|
||||||
|
KnowledgeGraphEdge(
|
||||||
|
id=f"{edge_id}",
|
||||||
|
type=rel.type,
|
||||||
|
source=f"{start.id}",
|
||||||
|
target=f"{end.id}",
|
||||||
|
properties=dict(rel),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting knowledge graph: {str(e)}")
|
||||||
|
return result
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@
|
||||||
- NetworkXStorage
|
- NetworkXStorage
|
||||||
- Neo4JStorage
|
- Neo4JStorage
|
||||||
- PGGraphStorage
|
- PGGraphStorage
|
||||||
|
- MemgraphStorage
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue