few improvements in mg driver
This commit is contained in:
parent
e43a756ac1
commit
489dffdc0c
1 changed files with 7 additions and 13 deletions
|
|
@ -39,9 +39,7 @@ class MemgraphDriver(GraphDriver):
|
|||
)
|
||||
self._database = database
|
||||
|
||||
async def execute_query(
|
||||
self, cypher_query_: LiteralString, **kwargs: Any
|
||||
) -> tuple[list, Any, Any]:
|
||||
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Execute a Cypher query against Memgraph using implicit transactions.
|
||||
Returns a tuple of (records, summary, keys) for compatibility with the GraphDriver interface.
|
||||
|
|
@ -49,28 +47,25 @@ class MemgraphDriver(GraphDriver):
|
|||
# Extract parameters from kwargs
|
||||
params = kwargs.pop('params', None)
|
||||
if params is None:
|
||||
# If no 'params' key, use the remaining kwargs as parameters
|
||||
# but first extract database-specific parameters
|
||||
database = kwargs.pop('database_', self._database)
|
||||
kwargs.pop('parameters_', None) # Remove if present (Neo4j async driver param)
|
||||
|
||||
# All remaining kwargs are query parameters
|
||||
kwargs.pop('parameters_', None)
|
||||
params = kwargs
|
||||
else:
|
||||
# Extract database parameter if params was provided separately
|
||||
database = kwargs.pop('database_', self._database)
|
||||
kwargs.pop('parameters_', None) # Remove if present
|
||||
kwargs.pop('parameters_', None)
|
||||
|
||||
async with self.client.session(database=database) as session:
|
||||
try:
|
||||
result = await session.run(cypher_query_, params)
|
||||
keys = result.keys()
|
||||
records = [record async for record in result]
|
||||
summary = await result.consume()
|
||||
keys = result.keys()
|
||||
return (records, summary, keys)
|
||||
except Exception as e:
|
||||
logger.error(f'Error executing Memgraph query: {e}\n{cypher_query_}\n{params}')
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
def session(self, database: str | None = None) -> GraphDriverSession:
|
||||
_database = database or self._database
|
||||
|
|
@ -80,5 +75,4 @@ class MemgraphDriver(GraphDriver):
|
|||
return await self.client.close()
|
||||
|
||||
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
||||
# TODO: Implement index deletion for Memgraph
|
||||
raise NotImplementedError('Index deletion not implemented for MemgraphDriver')
|
||||
return self.client.execute_query('DROP ALL INDEXES')
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue