quickstart working with memgraph
This commit is contained in:
parent
b0d0041429
commit
1641b9c1c1
6 changed files with 41 additions and 21 deletions
|
|
@ -18,7 +18,7 @@ import logging
|
||||||
from collections.abc import Coroutine
|
from collections.abc import Coroutine
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from neo4j import GraphDatabase
|
from neo4j import AsyncGraphDatabase
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||||
|
|
@ -31,7 +31,7 @@ class MemgraphDriver(GraphDriver):
|
||||||
|
|
||||||
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'):
|
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.client = GraphDatabase.driver(
|
self.client = AsyncGraphDatabase.driver(
|
||||||
uri=uri,
|
uri=uri,
|
||||||
auth=(user or '', password or ''),
|
auth=(user or '', password or ''),
|
||||||
)
|
)
|
||||||
|
|
@ -57,15 +57,11 @@ class MemgraphDriver(GraphDriver):
|
||||||
database = kwargs.pop('database_', self._database)
|
database = kwargs.pop('database_', self._database)
|
||||||
kwargs.pop('parameters_', None) # Remove if present
|
kwargs.pop('parameters_', None) # Remove if present
|
||||||
|
|
||||||
with self.client.session(database=database) as session:
|
async with self.client.session(database=database) as session:
|
||||||
try:
|
try:
|
||||||
# Debug: Print the query and parameters
|
result = await session.run(cypher_query_, params)
|
||||||
print(f"DEBUG - Memgraph Query: {cypher_query_}")
|
records = [record async for record in result]
|
||||||
print(f"DEBUG - Memgraph Params: {params}")
|
summary = await result.consume()
|
||||||
|
|
||||||
result = session.run(cypher_query_, params)
|
|
||||||
records = list(result)
|
|
||||||
summary = result.consume()
|
|
||||||
keys = result.keys()
|
keys = result.keys()
|
||||||
return (records, summary, keys)
|
return (records, summary, keys)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -77,7 +73,7 @@ class MemgraphDriver(GraphDriver):
|
||||||
return self.client.session(database=_database) # type: ignore
|
return self.client.session(database=_database) # type: ignore
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
return self.client.close()
|
return await self.client.close()
|
||||||
|
|
||||||
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
||||||
# TODO: Implement index deletion for Memgraph
|
# TODO: Implement index deletion for Memgraph
|
||||||
|
|
|
||||||
|
|
@ -140,7 +140,7 @@ def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider)
|
||||||
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)"
|
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)"
|
||||||
|
|
||||||
if provider == GraphProvider.MEMGRAPH:
|
if provider == GraphProvider.MEMGRAPH:
|
||||||
return f'CALL text_search.search("{name}", {query}) YIELD node'
|
return f'CALL text_search.search_all("{name}", {query})'
|
||||||
|
|
||||||
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
|
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
|
||||||
|
|
||||||
|
|
@ -154,7 +154,7 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
|
||||||
return f'array_cosine_similarity({vec1}, {vec2})'
|
return f'array_cosine_similarity({vec1}, {vec2})'
|
||||||
|
|
||||||
if provider == GraphProvider.MEMGRAPH:
|
if provider == GraphProvider.MEMGRAPH:
|
||||||
return f'CALL vector_search.cosine_similarity({vec1}, {vec2}) YIELD similarity RETURN similarity AS score'
|
return f'cosineSimilarity({vec1}, {vec2})'
|
||||||
|
|
||||||
return f'vector.similarity.cosine({vec1}, {vec2})'
|
return f'vector.similarity.cosine({vec1}, {vec2})'
|
||||||
|
|
||||||
|
|
@ -169,6 +169,6 @@ def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> s
|
||||||
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)"
|
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)"
|
||||||
|
|
||||||
if provider == GraphProvider.MEMGRAPH:
|
if provider == GraphProvider.MEMGRAPH:
|
||||||
return f'CALL text_search.search_edges("{name}", $query) YIELD node'
|
return f'CALL text_search.search_all_edges("{name}", $query)'
|
||||||
|
|
||||||
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
|
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
|
||||||
|
|
|
||||||
|
|
@ -167,8 +167,8 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
|
||||||
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
||||||
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
|
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
|
||||||
SET e = edge
|
SET e = edge
|
||||||
WITH e, edge e.fact_embedding = edge.fact_embedding
|
SET e.fact_embedding = edge.fact_embedding
|
||||||
RETURN edge.uuid AS uuid
|
RETURN edge.uuid AS uuid;
|
||||||
"""
|
"""
|
||||||
case _:
|
case _:
|
||||||
return """
|
return """
|
||||||
|
|
|
||||||
|
|
@ -92,7 +92,27 @@ def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
|
||||||
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
|
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
|
||||||
RETURN n.uuid AS uuid
|
RETURN n.uuid AS uuid
|
||||||
"""
|
"""
|
||||||
case _: # Neo4j and Memgraph
|
case GraphProvider.MEMGRAPH:
|
||||||
|
return """
|
||||||
|
UNWIND $episodes AS episode
|
||||||
|
MERGE (n:Episodic {uuid: episode.uuid})
|
||||||
|
FOREACH (_ IN CASE WHEN episode.group_label IS NOT NULL THEN [1] ELSE [] END |
|
||||||
|
SET n:`${episode.group_label}`
|
||||||
|
)
|
||||||
|
SET n = {
|
||||||
|
uuid: episode.uuid,
|
||||||
|
name: episode.name,
|
||||||
|
group_id: episode.group_id,
|
||||||
|
source_description: episode.source_description,
|
||||||
|
source: episode.source,
|
||||||
|
content: episode.content,
|
||||||
|
entity_edges: episode.entity_edges,
|
||||||
|
created_at: episode.created_at,
|
||||||
|
valid_at: episode.valid_at
|
||||||
|
}
|
||||||
|
RETURN n.uuid AS uuid;
|
||||||
|
"""
|
||||||
|
case _: # Neo4j
|
||||||
return """
|
return """
|
||||||
UNWIND $episodes AS episode
|
UNWIND $episodes AS episode
|
||||||
MERGE (n:Episodic {uuid: episode.uuid})
|
MERGE (n:Episodic {uuid: episode.uuid})
|
||||||
|
|
@ -235,10 +255,13 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
|
||||||
return """
|
return """
|
||||||
UNWIND $nodes AS node
|
UNWIND $nodes AS node
|
||||||
MERGE (n:Entity {uuid: node.uuid})
|
MERGE (n:Entity {uuid: node.uuid})
|
||||||
SET n:$(node.labels)
|
FOREACH (label IN CASE WHEN node.labels IS NOT NULL THEN node.labels ELSE [] END |
|
||||||
|
SET n:`${label}`
|
||||||
|
)
|
||||||
SET n = node
|
SET n = node
|
||||||
WITH n, node SET n.name_embedding = node.name_embedding
|
WITH n, node
|
||||||
RETURN n.uuid AS uuid
|
SET n.name_embedding = node.name_embedding
|
||||||
|
RETURN n.uuid AS uuid;
|
||||||
"""
|
"""
|
||||||
case _: # Neo4j
|
case _: # Neo4j
|
||||||
return """
|
return """
|
||||||
|
|
|
||||||
|
|
@ -563,7 +563,7 @@ async def node_fulltext_search(
|
||||||
if driver.provider == GraphProvider.KUZU:
|
if driver.provider == GraphProvider.KUZU:
|
||||||
yield_query = 'WITH node AS n, score'
|
yield_query = 'WITH node AS n, score'
|
||||||
elif driver.provider == GraphProvider.MEMGRAPH:
|
elif driver.provider == GraphProvider.MEMGRAPH:
|
||||||
yield_query = ' WITH node AS n, 1.0 AS score' # Memgraph: continue from YIELD node
|
yield_query = ' YIELD node AS n WITH n, 1.0 AS score'
|
||||||
|
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
||||||
|
|
|
||||||
|
|
@ -191,6 +191,7 @@ async def add_nodes_and_edges_bulk_tx(
|
||||||
for edge in episodic_edges:
|
for edge in episodic_edges:
|
||||||
await tx.run(episodic_edge_query, **edge.model_dump())
|
await tx.run(episodic_edge_query, **edge.model_dump())
|
||||||
else:
|
else:
|
||||||
|
print(episodes)
|
||||||
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
|
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
|
||||||
await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes)
|
await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes)
|
||||||
await tx.run(
|
await tx.run(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue