feat/falkordb dynamic graph names (#761)

* graphiti-graph-name

* fix-lint

* fix-unittest

* clone-update

* groupid-none

* groupid-def-fulltext

* lint

* Remove redundant function definition for fulltext_query in search_utils.py

* Refactor get_default_group_id function and remove redundant code in falkordb_driver and search_utils. Added import statement in driver.py.

* Refactor test cases in test_falkordb_driver.py for improved readability by consolidating multi-line assertions into single lines. No functional changes made.

* Refactor fulltext_query function in search_utils.py to use double quotes for group_id in the filter list, enhancing consistency in query syntax.

* Remove duplicate assignment of fuzzy_query in episode_fulltext_search function in search_utils.py to eliminate redundancy.

* Remove duplicate assignment of fuzzy_query in community_fulltext_search function in search_utils.py to streamline code.

---------

Co-authored-by: Gal Shubeli <galshubeli93@gmail.com>
This commit is contained in:
Daniel Chalef 2025-07-24 07:17:59 -07:00 committed by GitHub
parent 17747ff58d
commit 721e92f8fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 35 additions and 25 deletions

View file

@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import copy
import logging
from abc import ABC, abstractmethod
from collections.abc import Coroutine
@ -49,6 +50,7 @@ class GraphDriver(ABC):
fulltext_syntax: str = (
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
)
_database: str
@abstractmethod
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
@ -63,5 +65,15 @@ class GraphDriver(ABC):
raise NotImplementedError()
@abstractmethod
def delete_all_indexes(self, database_: str | None = None) -> Coroutine:
def delete_all_indexes(self) -> Coroutine:
raise NotImplementedError()
def with_database(self, database: str) -> 'GraphDriver':
"""
Returns a shallow copy of this driver with a different default database.
Reuses the same connection (e.g. FalkorDB, Neo4j).
"""
cloned = copy.copy(self)
cloned._database = database
return cloned

View file

@ -90,12 +90,13 @@ class FalkorDriver(GraphDriver):
The default parameters assume a local (on-premises) FalkorDB instance.
"""
super().__init__()
self._database = database
if falkor_db is not None:
# If a FalkorDB instance is provided, use it directly
self.client = falkor_db
else:
self.client = FalkorDB(host=host, port=port, username=username, password=password)
self._database = database
self.fulltext_syntax = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries see https://redis.io/docs/latest/develop/ai/search-and-query/query/full-text/
@ -106,8 +107,7 @@ class FalkorDriver(GraphDriver):
return self.client.select_graph(graph_name)
async def execute_query(self, cypher_query_, **kwargs: Any):
graph_name = kwargs.pop('database_', self._database)
graph = self._get_graph(graph_name)
graph = self._get_graph(self._database)
# Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly)
params = convert_datetimes_to_strings(dict(kwargs))
@ -151,13 +151,20 @@ class FalkorDriver(GraphDriver):
elif hasattr(self.client.connection, 'close'):
await self.client.connection.close()
async def delete_all_indexes(self, database_: str | None = None) -> None:
database = database_ or self._database
async def delete_all_indexes(self) -> None:
await self.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name',
database_=database,
)
def clone(self, database: str) -> 'GraphDriver':
"""
Returns a shallow copy of this driver with a different default database.
Reuses the same connection (e.g. FalkorDB, Neo4j).
"""
cloned = FalkorDriver(falkor_db=self.client, database=database)
return cloned
def convert_datetimes_to_strings(obj):
if isinstance(obj, dict):

View file

@ -56,9 +56,7 @@ class Neo4jDriver(GraphDriver):
async def close(self) -> None:
return await self.client.close()
def delete_all_indexes(self, database_: str | None = None) -> Coroutine[Any, Any, EagerResult]:
database = database_ or self._database
def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]:
return self.client.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name',
database_=database,
)

View file

@ -113,7 +113,7 @@ class Graphiti:
"""
Initialize a Graphiti instance.
This constructor sets up a connection to the Neo4j database and initializes
This constructor sets up a connection to a graph database and initializes
the LLM client for natural language processing tasks.
Parameters
@ -148,11 +148,11 @@ class Graphiti:
Notes
-----
This method establishes a connection to the Neo4j database using the provided
This method establishes a connection to a graph database (Neo4j by default) using the provided
credentials. It also sets up the LLM client, either using the provided client
or by creating a default OpenAIClient.
The default database name is set to 'neo4j'. If a different database name
The default database name is defined during the drivers construction. If a different database name
is required, it should be specified in the URI or set separately after
initialization.

View file

@ -101,11 +101,8 @@ class TestFalkorDriver:
mock_graph.query = AsyncMock(return_value=mock_result)
self.mock_client.select_graph.return_value = mock_graph
result = await self.driver.execute_query(
'MATCH (n) RETURN n', param1='value1', database_='test_db'
)
result = await self.driver.execute_query('MATCH (n) RETURN n', param1='value1')
self.mock_client.select_graph.assert_called_once_with('test_db')
mock_graph.query.assert_called_once_with('MATCH (n) RETURN n', {'param1': 'value1'})
result_set, header, summary = result
@ -167,11 +164,10 @@ class TestFalkorDriver:
mock_graph = MagicMock()
self.mock_client.select_graph.return_value = mock_graph
session = self.driver.session('test_db')
session = self.driver.session()
assert isinstance(session, FalkorDriverSession)
assert session.graph is mock_graph
self.mock_client.select_graph.assert_called_once_with('test_db')
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_session_creation_with_none_uses_default_database(self):
@ -179,10 +175,9 @@ class TestFalkorDriver:
mock_graph = MagicMock()
self.mock_client.select_graph.return_value = mock_graph
session = self.driver.session(None)
session = self.driver.session()
assert isinstance(session, FalkorDriverSession)
self.mock_client.select_graph.assert_called_once_with('default_db')
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
@ -212,11 +207,9 @@ class TestFalkorDriver:
async def test_delete_all_indexes(self):
"""Test delete_all_indexes method."""
with patch.object(self.driver, 'execute_query', new_callable=AsyncMock) as mock_execute:
await self.driver.delete_all_indexes('test_db')
await self.driver.delete_all_indexes()
mock_execute.assert_called_once_with(
'CALL db.indexes() YIELD name DROP INDEX name', database_='test_db'
)
mock_execute.assert_called_once_with('CALL db.indexes() YIELD name DROP INDEX name')
class TestFalkorDriverSession: