fix memgraph_impl.py according to test_graph_storage.py
This commit is contained in:
parent
9aaa7d2dd3
commit
c0a3638d01
2 changed files with 52 additions and 45 deletions
|
|
@ -614,7 +614,12 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
UNWIND $chunk_ids AS chunk_id
|
UNWIND $chunk_ids AS chunk_id
|
||||||
MATCH (a:base)-[r]-(b:base)
|
MATCH (a:base)-[r]-(b:base)
|
||||||
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
|
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
|
||||||
RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
|
WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
|
||||||
|
// Ensure we only return each unique edge once by ordering the source and target
|
||||||
|
WITH a, b, r,
|
||||||
|
CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source,
|
||||||
|
CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target
|
||||||
|
RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties
|
||||||
"""
|
"""
|
||||||
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
|
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
|
||||||
edges = []
|
edges = []
|
||||||
|
|
@ -650,10 +655,10 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
result = KnowledgeGraph()
|
result = KnowledgeGraph()
|
||||||
seen_nodes = set()
|
seen_nodes = set()
|
||||||
seen_edges = set()
|
seen_edges = set()
|
||||||
try:
|
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
) as session:
|
) as session:
|
||||||
|
try:
|
||||||
if node_label == "*":
|
if node_label == "*":
|
||||||
count_query = "MATCH (n) RETURN count(n) as total"
|
count_query = "MATCH (n) RETURN count(n) as total"
|
||||||
count_result = None
|
count_result = None
|
||||||
|
|
@ -736,6 +741,11 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
logger.debug(f"No record found for node {node_label}")
|
logger.debug(f"No record found for node {node_label}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if result_set:
|
||||||
|
await result_set.consume()
|
||||||
|
|
||||||
|
if record:
|
||||||
for node_info in record["node_info"]:
|
for node_info in record["node_info"]:
|
||||||
node = node_info["node"]
|
node = node_info["node"]
|
||||||
node_id = node.id
|
node_id = node.id
|
||||||
|
|
@ -771,10 +781,6 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
finally:
|
|
||||||
if result_set:
|
|
||||||
await result_set.consume()
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting knowledge graph: {str(e)}")
|
logger.error(f"Error getting knowledge graph: {str(e)}")
|
||||||
return result
|
return result
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@
|
||||||
- NetworkXStorage
|
- NetworkXStorage
|
||||||
- Neo4JStorage
|
- Neo4JStorage
|
||||||
- PGGraphStorage
|
- PGGraphStorage
|
||||||
|
- MemgraphStorage
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue