wip fix quick start issues
This commit is contained in:
parent
0f9dd11cc8
commit
b0d0041429
3 changed files with 69 additions and 48 deletions
|
|
@ -18,7 +18,7 @@ import logging
|
||||||
from collections.abc import Coroutine
|
from collections.abc import Coroutine
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from neo4j import AsyncGraphDatabase, EagerResult
|
from neo4j import GraphDatabase
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||||
|
|
@ -31,37 +31,54 @@ class MemgraphDriver(GraphDriver):
|
||||||
|
|
||||||
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'):
|
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'memgraph'):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.client = AsyncGraphDatabase.driver(
|
self.client = GraphDatabase.driver(
|
||||||
uri=uri,
|
uri=uri,
|
||||||
auth=(user or '', password or ''),
|
auth=(user or '', password or ''),
|
||||||
)
|
)
|
||||||
self._database = database
|
self._database = database
|
||||||
|
|
||||||
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
|
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> tuple[list, Any, Any]:
|
||||||
# Check if database_ is provided in kwargs.
|
"""
|
||||||
# If not populated, set the value to retain backwards compatibility
|
Execute a Cypher query against Memgraph using implicit transactions.
|
||||||
|
Returns a tuple of (records, summary, keys) for compatibility with the GraphDriver interface.
|
||||||
|
"""
|
||||||
|
# Extract parameters from kwargs
|
||||||
params = kwargs.pop('params', None)
|
params = kwargs.pop('params', None)
|
||||||
if params is None:
|
if params is None:
|
||||||
params = {}
|
# If no 'params' key, use the remaining kwargs as parameters
|
||||||
params.setdefault('database_', self._database)
|
# but first extract database-specific parameters
|
||||||
|
database = kwargs.pop('database_', self._database)
|
||||||
|
kwargs.pop('parameters_', None) # Remove if present (Neo4j async driver param)
|
||||||
|
|
||||||
try:
|
# All remaining kwargs are query parameters
|
||||||
result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
|
params = kwargs
|
||||||
except Exception as e:
|
else:
|
||||||
logger.error(f'Error executing Memgraph query: {e}\n{cypher_query_}\n{params}')
|
# Extract database parameter if params was provided separately
|
||||||
raise
|
database = kwargs.pop('database_', self._database)
|
||||||
|
kwargs.pop('parameters_', None) # Remove if present
|
||||||
|
|
||||||
return result
|
with self.client.session(database=database) as session:
|
||||||
|
try:
|
||||||
|
# Debug: Print the query and parameters
|
||||||
|
print(f"DEBUG - Memgraph Query: {cypher_query_}")
|
||||||
|
print(f"DEBUG - Memgraph Params: {params}")
|
||||||
|
|
||||||
|
result = session.run(cypher_query_, params)
|
||||||
|
records = list(result)
|
||||||
|
summary = result.consume()
|
||||||
|
keys = result.keys()
|
||||||
|
return (records, summary, keys)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f'Error executing Memgraph query: {e}\n{cypher_query_}\n{params}')
|
||||||
|
raise
|
||||||
|
|
||||||
def session(self, database: str | None = None) -> GraphDriverSession:
|
def session(self, database: str | None = None) -> GraphDriverSession:
|
||||||
_database = database or self._database
|
_database = database or self._database
|
||||||
return self.client.session(database=_database) # type: ignore
|
return self.client.session(database=_database) # type: ignore
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
return await self.client.close()
|
return self.client.close()
|
||||||
|
|
||||||
def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]:
|
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
||||||
# TODO
|
# TODO: Implement index deletion for Memgraph
|
||||||
return self.client.execute_query(
|
raise NotImplementedError("Index deletion not implemented for MemgraphDriver")
|
||||||
'SHOW INDEX INFO;',
|
|
||||||
)
|
|
||||||
|
|
@ -47,26 +47,26 @@ def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
|
||||||
|
|
||||||
if provider == GraphProvider.MEMGRAPH:
|
if provider == GraphProvider.MEMGRAPH:
|
||||||
return [
|
return [
|
||||||
'CREATE INDEX ON :Entity(uuid)',
|
'CREATE INDEX ON :Entity(uuid);',
|
||||||
'CREATE INDEX ON :Entity(group_id)',
|
'CREATE INDEX ON :Entity(group_id);',
|
||||||
'CREATE INDEX ON :Entity(name)',
|
'CREATE INDEX ON :Entity(name);',
|
||||||
'CREATE INDEX ON :Entity(created_at)',
|
'CREATE INDEX ON :Entity(created_at);',
|
||||||
'CREATE INDEX ON :Episodic(uuid)',
|
'CREATE INDEX ON :Episodic(uuid);',
|
||||||
'CREATE INDEX ON :Episodic(group_id)',
|
'CREATE INDEX ON :Episodic(group_id);',
|
||||||
'CREATE INDEX ON :Episodic(created_at)',
|
'CREATE INDEX ON :Episodic(created_at);',
|
||||||
'CREATE INDEX ON :Episodic(valid_at)',
|
'CREATE INDEX ON :Episodic(valid_at);',
|
||||||
'CREATE INDEX ON :Community(uuid)',
|
'CREATE INDEX ON :Community(uuid);',
|
||||||
'CREATE INDEX ON :Community(group_id)',
|
'CREATE INDEX ON :Community(group_id);',
|
||||||
'CREATE INDEX ON :RELATES_TO(uuid)',
|
'CREATE INDEX ON :RELATES_TO(uuid);',
|
||||||
'CREATE INDEX ON :RELATES_TO(group_id)',
|
'CREATE INDEX ON :RELATES_TO(group_id);',
|
||||||
'CREATE INDEX ON :RELATES_TO(name)',
|
'CREATE INDEX ON :RELATES_TO(name);',
|
||||||
'CREATE INDEX ON :RELATES_TO(created_at)',
|
'CREATE INDEX ON :RELATES_TO(created_at);',
|
||||||
'CREATE INDEX ON :RELATES_TO(expired_at)',
|
'CREATE INDEX ON :RELATES_TO(expired_at);',
|
||||||
'CREATE INDEX ON :RELATES_TO(valid_at)',
|
'CREATE INDEX ON :RELATES_TO(valid_at);',
|
||||||
'CREATE INDEX ON :RELATES_TO(invalid_at)',
|
'CREATE INDEX ON :RELATES_TO(invalid_at);',
|
||||||
'CREATE INDEX ON :MENTIONS(uuid)',
|
'CREATE INDEX ON :MENTIONS(uuid);',
|
||||||
'CREATE INDEX ON :MENTIONS(group_id)',
|
'CREATE INDEX ON :MENTIONS(group_id);',
|
||||||
'CREATE INDEX ON :HAS_MEMBER(uuid)',
|
'CREATE INDEX ON :HAS_MEMBER(uuid);',
|
||||||
]
|
]
|
||||||
|
|
||||||
return [
|
return [
|
||||||
|
|
@ -112,10 +112,10 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
|
||||||
|
|
||||||
if provider == GraphProvider.MEMGRAPH:
|
if provider == GraphProvider.MEMGRAPH:
|
||||||
return [
|
return [
|
||||||
"""CREATE TEXT INDEX episode_content ON :Episodic(content, source, source_description, group_id)""",
|
"""CREATE TEXT INDEX episode_content ON :Episodic(content, source, source_description, group_id);""",
|
||||||
"""CREATE TEXT INDEX node_name_and_summary ON :Entity(name, summary, group_id)""",
|
"""CREATE TEXT INDEX node_name_and_summary ON :Entity(name, summary, group_id);""",
|
||||||
"""CREATE TEXT INDEX community_name ON :Community(name, group_id)""",
|
"""CREATE TEXT INDEX community_name ON :Community(name, group_id);""",
|
||||||
"""CREATE TEXT EDGE INDEX edge_name_and_fact ON :RELATES_TO(name, fact, group_id)""",
|
"""CREATE TEXT EDGE INDEX edge_name_and_fact ON :RELATES_TO(name, fact, group_id);""",
|
||||||
]
|
]
|
||||||
|
|
||||||
return [
|
return [
|
||||||
|
|
@ -140,7 +140,7 @@ def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider)
|
||||||
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)"
|
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)"
|
||||||
|
|
||||||
if provider == GraphProvider.MEMGRAPH:
|
if provider == GraphProvider.MEMGRAPH:
|
||||||
return f'CALL text_search.search("{name}", {query}) YIELD node RETURN node LIMIT $limit'
|
return f'CALL text_search.search("{name}", {query}) YIELD node'
|
||||||
|
|
||||||
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
|
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
|
||||||
|
|
||||||
|
|
@ -154,7 +154,7 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
|
||||||
return f'array_cosine_similarity({vec1}, {vec2})'
|
return f'array_cosine_similarity({vec1}, {vec2})'
|
||||||
|
|
||||||
if provider == GraphProvider.MEMGRAPH:
|
if provider == GraphProvider.MEMGRAPH:
|
||||||
return "TODO"
|
return f'CALL vector_search.cosine_similarity({vec1}, {vec2}) YIELD similarity RETURN similarity AS score'
|
||||||
|
|
||||||
return f'vector.similarity.cosine({vec1}, {vec2})'
|
return f'vector.similarity.cosine({vec1}, {vec2})'
|
||||||
|
|
||||||
|
|
@ -169,6 +169,6 @@ def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> s
|
||||||
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)"
|
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)"
|
||||||
|
|
||||||
if provider == GraphProvider.MEMGRAPH:
|
if provider == GraphProvider.MEMGRAPH:
|
||||||
return f'CALL text_search.search("{name}", $query) YIELD node RETURN node LIMIT $limit'
|
return f'CALL text_search.search_edges("{name}", $query) YIELD node'
|
||||||
|
|
||||||
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
|
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
|
||||||
|
|
|
||||||
|
|
@ -562,6 +562,8 @@ async def node_fulltext_search(
|
||||||
yield_query = 'YIELD node AS n, score'
|
yield_query = 'YIELD node AS n, score'
|
||||||
if driver.provider == GraphProvider.KUZU:
|
if driver.provider == GraphProvider.KUZU:
|
||||||
yield_query = 'WITH node AS n, score'
|
yield_query = 'WITH node AS n, score'
|
||||||
|
elif driver.provider == GraphProvider.MEMGRAPH:
|
||||||
|
yield_query = ' WITH node AS n, 1.0 AS score' # Memgraph: continue from YIELD node
|
||||||
|
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
||||||
|
|
@ -968,6 +970,8 @@ async def community_fulltext_search(
|
||||||
yield_query = 'YIELD node AS c, score'
|
yield_query = 'YIELD node AS c, score'
|
||||||
if driver.provider == GraphProvider.KUZU:
|
if driver.provider == GraphProvider.KUZU:
|
||||||
yield_query = 'WITH node AS c, score'
|
yield_query = 'WITH node AS c, score'
|
||||||
|
elif driver.provider == GraphProvider.MEMGRAPH:
|
||||||
|
yield_query = ' WITH node AS c, 1.0 AS score' # Memgraph: continue from YIELD node
|
||||||
|
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
res = driver.run_aoss_query('community_name', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
res = driver.run_aoss_query('community_name', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue