From 489dffdc0c04336a8995e25cafd319d0ce796fce Mon Sep 17 00:00:00 2001 From: DavIvek Date: Wed, 10 Sep 2025 09:34:04 +0200 Subject: [PATCH] few improvements in mg driver --- graphiti_core/driver/memgraph_driver.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/graphiti_core/driver/memgraph_driver.py b/graphiti_core/driver/memgraph_driver.py index d419bee9..2b23b165 100644 --- a/graphiti_core/driver/memgraph_driver.py +++ b/graphiti_core/driver/memgraph_driver.py @@ -39,9 +39,7 @@ class MemgraphDriver(GraphDriver): ) self._database = database - async def execute_query( - self, cypher_query_: LiteralString, **kwargs: Any - ) -> tuple[list, Any, Any]: + async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> Any: """ Execute a Cypher query against Memgraph using implicit transactions. Returns a tuple of (records, summary, keys) for compatibility with the GraphDriver interface. @@ -49,28 +47,25 @@ class MemgraphDriver(GraphDriver): # Extract parameters from kwargs params = kwargs.pop('params', None) if params is None: - # If no 'params' key, use the remaining kwargs as parameters - # but first extract database-specific parameters database = kwargs.pop('database_', self._database) - kwargs.pop('parameters_', None) # Remove if present (Neo4j async driver param) - - # All remaining kwargs are query parameters + kwargs.pop('parameters_', None) params = kwargs else: - # Extract database parameter if params was provided separately database = kwargs.pop('database_', self._database) - kwargs.pop('parameters_', None) # Remove if present + kwargs.pop('parameters_', None) async with self.client.session(database=database) as session: try: result = await session.run(cypher_query_, params) + keys = result.keys() records = [record async for record in result] summary = await result.consume() - keys = result.keys() return (records, summary, keys) except Exception as e: logger.error(f'Error executing Memgraph query: {e}\n{cypher_query_}\n{params}') raise + finally: + await session.close() def session(self, database: str | None = None) -> GraphDriverSession: _database = database or self._database @@ -80,5 +75,4 @@ class MemgraphDriver(GraphDriver): return await self.client.close() def delete_all_indexes(self) -> Coroutine[Any, Any, Any]: - # TODO: Implement index deletion for Memgraph - raise NotImplementedError('Index deletion not implemented for MemgraphDriver') + return self.client.execute_query('DROP ALL INDEXES')