graphiti/graphiti_core/driver/falkordb_driver.py
Daniel Chalef cf29de4565 Apply ruff formatting to falkordb driver and node queries
- Quote style fixes in falkordb_driver.py
- Trailing whitespace cleanup in node_db_queries.py
- Update uv.lock

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-12 09:46:43 -07:00

308 lines
10 KiB
Python

"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import logging
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from falkordb import Graph as FalkorGraph
from falkordb.asyncio import FalkorDB
else:
try:
from falkordb import Graph as FalkorGraph
from falkordb.asyncio import FalkorDB
except ImportError:
# If falkordb is not installed, raise an ImportError
raise ImportError(
'falkordb is required for FalkorDriver. '
'Install it with: pip install graphiti-core[falkordb]'
) from None
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
logger = logging.getLogger(__name__)

# Lowercase English stopwords. ``FalkorDriver.build_fulltext_query`` drops any
# query word whose lowercase form appears here before building the search
# expression. This is an ordered ``list`` of plain strings.
STOPWORDS = (
    'a is the an and are as at be but by for if in into it no not of on or '
    'such that their then there these they this to was will with'
).split()
class FalkorDriverSession(GraphDriverSession):
    """Session adapter that forwards queries to a single FalkorDB graph handle.

    FalkorDB has no session or transaction object of its own, so the lifecycle
    hooks below are no-ops and the session itself is handed to write callbacks.
    """

    provider = GraphProvider.FALKORDB

    def __init__(self, graph: FalkorGraph):
        self.graph = graph

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Nothing to release; returning None lets any in-flight exception propagate.
        return None

    async def close(self):
        # This session holds no resources of its own, so closing is a no-op.
        return None

    async def execute_write(self, func, *args, **kwargs):
        # The session object doubles as the "transaction" argument for ``func``.
        return await func(self, *args, **kwargs)

    async def run(self, query: str | list, **kwargs: Any) -> Any:
        """Execute one Cypher statement, or an ordered batch of them.

        ``query`` is either a single Cypher string whose parameters are taken from
        ``kwargs``, or a list of ``(cypher, params)`` pairs run one after another
        (``kwargs`` is not used in that form). Datetime values in the parameters
        are converted to strings before each call. Always returns ``None``.
        """
        # FalkorDB does not support a parameter for a label set, so callers may
        # pass an already-expanded list of statements instead of a single query.
        if isinstance(query, list):
            batch = query
        else:
            batch = [(query, dict(kwargs))]

        for statement, raw_params in batch:
            converted = convert_datetimes_to_strings(raw_params)
            await self.graph.query(str(statement), converted)  # type: ignore[reportUnknownArgumentType]

        return None
class FalkorDriver(GraphDriver):
    """GraphDriver implementation that talks to a FalkorDB server.

    Every query runs against a named graph (``database``, default
    ``'default_db'``) selected from the FalkorDB client. Results are reshaped
    from FalkorDB's header + row-list format into the list-of-dicts records the
    rest of Graphiti consumes.
    """

    provider = GraphProvider.FALKORDB
    aoss_client: None = None

    # Characters FalkorDB's fulltext tokenization treats as separators, plus '?'.
    # ``sanitize`` maps each of them to a space. The translation table is built
    # once here instead of on every call.
    _SEPARATOR_CHARS = ',.<>{}[]"\':;!@#$%^&*()-+=~?'
    _SEPARATOR_TABLE = str.maketrans(_SEPARATOR_CHARS, ' ' * len(_SEPARATOR_CHARS))

    def __init__(
        self,
        host: str = 'localhost',
        port: int = 6379,
        username: str | None = None,
        password: str | None = None,
        falkor_db: FalkorDB | None = None,
        database: str = 'default_db',
    ):
        """
        Initialize the FalkorDB driver.

        FalkorDB is a multi-tenant graph database.
        To connect, provide the host and port.
        The default parameters assume a local (on-premises) FalkorDB instance.

        Args:
            host: Server host. Ignored when ``falkor_db`` is given.
            port: Server port. Ignored when ``falkor_db`` is given.
            username: Optional username. Ignored when ``falkor_db`` is given.
            password: Optional password. Ignored when ``falkor_db`` is given.
            falkor_db: An existing FalkorDB client to reuse instead of creating one.
            database: Name of the graph that queries run against by default.
        """
        super().__init__()

        self._database = database
        if falkor_db is not None:
            # Reuse the caller's client (this is how ``clone`` shares a connection).
            self.client = falkor_db
        else:
            self.client = FalkorDB(host=host, port=port, username=username, password=password)

        self.fulltext_syntax = '@'  # FalkorDB uses a redisearch-like syntax for fulltext queries see https://redis.io/docs/latest/develop/ai/search-and-query/query/full-text/

    def _get_graph(self, graph_name: str | None) -> FalkorGraph:
        """Return the graph handle for ``graph_name``, falling back to the default database."""
        # FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
        if graph_name is None:
            graph_name = self._database
        return self.client.select_graph(graph_name)

    async def execute_query(self, cypher_query_, **kwargs: Any):
        """Run ``cypher_query_`` on the default graph and return its rows as dicts.

        Keyword arguments are sent as query parameters. Datetimes are first
        converted to ISO strings because FalkorDB cannot bind datetime objects.

        Returns:
            ``(records, header, None)``, where ``records`` is a list of dicts keyed
            by column name and ``header`` is the list of column names. Returns
            ``None`` when the query failed only because the index already exists.

        Raises:
            Any other exception from the FalkorDB client, after logging it.
        """
        graph = self._get_graph(self._database)

        params = convert_datetimes_to_strings(dict(kwargs))

        try:
            result = await graph.query(cypher_query_, params)  # type: ignore[reportUnknownArgumentType]
        except Exception as e:
            if 'already indexed' in str(e):
                # Index creation is treated as idempotent: an existing index is not an error.
                logger.info('Index already exists: %s', e)
                return None
            logger.error('Error executing FalkorDB query: %s\n%s\n%s', e, cypher_query_, params)
            raise

        # The column name is taken from position 1 of each header entry.
        header = [h[1] for h in result.header]

        # Convert FalkorDB's list-of-lists rows into dicts. Columns missing from a
        # short row are filled with None.
        records = [
            {
                field_name: (row[i] if i < len(row) else None)
                for i, field_name in enumerate(header)
            }
            for row in result.result_set
        ]

        return records, header, None

    def session(self, database: str | None = None) -> GraphDriverSession:
        """Open a session bound to ``database`` (or the driver's default graph)."""
        return FalkorDriverSession(self._get_graph(database))

    async def close(self) -> None:
        """Close the driver connection, using whichever close method the client exposes."""
        if hasattr(self.client, 'aclose'):
            await self.client.aclose()  # type: ignore[reportUnknownMemberType]
            return

        # Not every client object has a ``connection`` attribute. Treat its absence
        # as "nothing to close" instead of raising AttributeError.
        connection = getattr(self.client, 'connection', None)
        if hasattr(connection, 'aclose'):
            await connection.aclose()
        elif hasattr(connection, 'close'):
            await connection.close()

    async def delete_all_indexes(self) -> None:
        """Drop every RANGE and FULLTEXT index listed by ``CALL db.indexes()``.

        The DROP statements are executed concurrently with ``asyncio.gather``.
        """
        result = await self.execute_query('CALL db.indexes()')
        if not result:
            return

        records, _, _ = result
        drop_tasks = []

        for record in records:
            label = record['label']
            entity_type = record['entitytype']

            for field_name, index_type in record['types'].items():
                if 'RANGE' in index_type:
                    drop_tasks.append(self.execute_query(f'DROP INDEX ON :{label}({field_name})'))
                elif 'FULLTEXT' in index_type:
                    if entity_type == 'NODE':
                        drop_tasks.append(
                            self.execute_query(
                                f'DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field_name})'
                            )
                        )
                    elif entity_type == 'RELATIONSHIP':
                        drop_tasks.append(
                            self.execute_query(
                                f'DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field_name})'
                            )
                        )

        if drop_tasks:
            await asyncio.gather(*drop_tasks)

    def clone(self, database: str) -> 'GraphDriver':
        """
        Returns a shallow copy of this driver with a different default database.
        Reuses the same connection (e.g. FalkorDB, Neo4j).
        """
        cloned = FalkorDriver(falkor_db=self.client, database=database)

        return cloned

    def sanitize(self, query: str) -> str:
        """
        Replace FalkorDB special characters with whitespace.

        Based on FalkorDB tokenization rules: ,.<>{}[]"':;!@#$%^&*()-+=~ (and '?').
        Each such character becomes a space. Runs of whitespace are then collapsed
        to single spaces, and leading/trailing whitespace is removed.
        """
        return ' '.join(query.translate(self._SEPARATOR_TABLE).split())

    def build_fulltext_query(
        self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
    ) -> str:
        """
        Build a fulltext query string for FalkorDB using RedisSearch syntax.

        FalkorDB uses RedisSearch-like syntax where:
        - Field queries use @ prefix: @field:value
        - Multiple values for same field: (@field:value1|value2)
        - Text search doesn't need @ prefix for content fields
        - AND is implicit with space: (@group_id:value) (text)
        - OR uses pipe within parentheses: (@group_id:value1|value2)

        The query text is sanitized and stopwords are removed. The remaining terms
        are OR-ed together.

        Returns:
            The query string, or ``''`` ("no query") when no searchable terms remain.
            Also returns ``''`` when the number of terms plus group ids reaches
            ``max_query_length``.
        """
        if group_ids:
            group_filter = f'(@group_id:{"|".join(group_ids)})'
        else:
            group_filter = ''

        search_terms = [
            word for word in self.sanitize(query).split() if word.lower() not in STOPWORDS
        ]

        # Nothing searchable left (empty input, only punctuation, or only stopwords).
        # An empty "()" group is not a valid search expression, so return no query.
        if not search_terms:
            return ''

        # If the query is too long, return no query. Count the terms themselves,
        # not the ' | '-joined string, which would also count every separator pipe.
        group_count = len(group_ids) if group_ids else 0
        if len(search_terms) + group_count >= max_query_length:
            return ''

        return group_filter + ' (' + ' | '.join(search_terms) + ')'