From 721e92f8fbe0c78ab449abd230ea849101eed74d Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Thu, 24 Jul 2025 07:17:59 -0700 Subject: [PATCH] feat/falkordb dynamic graph names (#761) * graphiti-graph-name * fix-lint * fix-unittest * clone-update * groupid-none * groupid-def-fulltext * lint * Remove redundant function definition for fulltext_query in search_utils.py * Refactor get_default_group_id function and remove redundant code in falkordb_driver and search_utils. Added import statement in driver.py. * Refactor test cases in test_falkordb_driver.py for improved readability by consolidating multi-line assertions into single lines. No functional changes made. * Refactor fulltext_query function in search_utils.py to use double quotes for group_id in the filter list, enhancing consistency in query syntax. * Remove duplicate assignment of fuzzy_query in episode_fulltext_search function in search_utils.py to eliminate redundancy. * Remove duplicate assignment of fuzzy_query in community_fulltext_search function in search_utils.py to streamline code. --------- Co-authored-by: Gal Shubeli --- graphiti_core/driver/driver.py | 14 +++++++++++++- graphiti_core/driver/falkordb_driver.py | 19 +++++++++++++------ graphiti_core/driver/neo4j_driver.py | 4 +--- graphiti_core/graphiti.py | 6 +++--- tests/driver/test_falkordb_driver.py | 17 +++++------------ 5 files changed, 35 insertions(+), 25 deletions(-) diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 9c8f1642..4efe230a 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +import copy import logging from abc import ABC, abstractmethod from collections.abc import Coroutine @@ -49,6 +50,7 @@ class GraphDriver(ABC): fulltext_syntax: str = ( '' # Neo4j (default) syntax does not require a prefix for fulltext queries ) + _database: str @abstractmethod def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine: @@ -63,5 +65,15 @@ class GraphDriver(ABC): raise NotImplementedError() @abstractmethod - def delete_all_indexes(self, database_: str | None = None) -> Coroutine: + def delete_all_indexes(self) -> Coroutine: raise NotImplementedError() + + def with_database(self, database: str) -> 'GraphDriver': + """ + Returns a shallow copy of this driver with a different default database. + Reuses the same connection (e.g. FalkorDB, Neo4j). + """ + cloned = copy.copy(self) + cloned._database = database + + return cloned diff --git a/graphiti_core/driver/falkordb_driver.py b/graphiti_core/driver/falkordb_driver.py index ac71c402..acf2c66f 100644 --- a/graphiti_core/driver/falkordb_driver.py +++ b/graphiti_core/driver/falkordb_driver.py @@ -90,12 +90,13 @@ class FalkorDriver(GraphDriver): The default parameters assume a local (on-premises) FalkorDB instance. """ super().__init__() + + self._database = database if falkor_db is not None: # If a FalkorDB instance is provided, use it directly self.client = falkor_db else: self.client = FalkorDB(host=host, port=port, username=username, password=password) - self._database = database self.fulltext_syntax = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries see https://redis.io/docs/latest/develop/ai/search-and-query/query/full-text/ @@ -106,8 +107,7 @@ class FalkorDriver(GraphDriver): return self.client.select_graph(graph_name) async def execute_query(self, cypher_query_, **kwargs: Any): - graph_name = kwargs.pop('database_', self._database) - graph = self._get_graph(graph_name) + graph = self._get_graph(self._database) # Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly) params = convert_datetimes_to_strings(dict(kwargs)) @@ -151,13 +151,20 @@ class FalkorDriver(GraphDriver): elif hasattr(self.client.connection, 'close'): await self.client.connection.close() - async def delete_all_indexes(self, database_: str | None = None) -> None: - database = database_ or self._database + async def delete_all_indexes(self) -> None: await self.execute_query( 'CALL db.indexes() YIELD name DROP INDEX name', - database_=database, ) + def clone(self, database: str) -> 'GraphDriver': + """ + Returns a shallow copy of this driver with a different default database. + Reuses the same connection (e.g. FalkorDB, Neo4j). + """ + cloned = FalkorDriver(falkor_db=self.client, database=database) + + return cloned + def convert_datetimes_to_strings(obj): if isinstance(obj, dict): diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py index 1f542c96..bd82e8d9 100644 --- a/graphiti_core/driver/neo4j_driver.py +++ b/graphiti_core/driver/neo4j_driver.py @@ -56,9 +56,7 @@ class Neo4jDriver(GraphDriver): async def close(self) -> None: return await self.client.close() - def delete_all_indexes(self, database_: str | None = None) -> Coroutine[Any, Any, EagerResult]: - database = database_ or self._database + def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]: return self.client.execute_query( 'CALL db.indexes() YIELD name DROP INDEX name', - database_=database, ) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index fd87234e..3459f8ae 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -113,7 +113,7 @@ class Graphiti: """ Initialize a Graphiti instance. - This constructor sets up a connection to the Neo4j database and initializes + This constructor sets up a connection to a graph database and initializes the LLM client for natural language processing tasks. Parameters @@ -148,11 +148,11 @@ class Graphiti: Notes ----- - This method establishes a connection to the Neo4j database using the provided + This method establishes a connection to a graph database (Neo4j by default) using the provided credentials. It also sets up the LLM client, either using the provided client or by creating a default OpenAIClient. - The default database name is set to 'neo4j'. If a different database name + The default database name is defined during the driver’s construction. If a different database name is required, it should be specified in the URI or set separately after initialization. diff --git a/tests/driver/test_falkordb_driver.py b/tests/driver/test_falkordb_driver.py index 735a41f9..260e24d2 100644 --- a/tests/driver/test_falkordb_driver.py +++ b/tests/driver/test_falkordb_driver.py @@ -101,11 +101,8 @@ class TestFalkorDriver: mock_graph.query = AsyncMock(return_value=mock_result) self.mock_client.select_graph.return_value = mock_graph - result = await self.driver.execute_query( - 'MATCH (n) RETURN n', param1='value1', database_='test_db' - ) + result = await self.driver.execute_query('MATCH (n) RETURN n', param1='value1') - self.mock_client.select_graph.assert_called_once_with('test_db') mock_graph.query.assert_called_once_with('MATCH (n) RETURN n', {'param1': 'value1'}) result_set, header, summary = result @@ -167,11 +164,10 @@ class TestFalkorDriver: mock_graph = MagicMock() self.mock_client.select_graph.return_value = mock_graph - session = self.driver.session('test_db') + session = self.driver.session() assert isinstance(session, FalkorDriverSession) assert session.graph is mock_graph - self.mock_client.select_graph.assert_called_once_with('test_db') @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed') def test_session_creation_with_none_uses_default_database(self): @@ -179,10 +175,9 @@ class TestFalkorDriver: mock_graph = MagicMock() self.mock_client.select_graph.return_value = mock_graph - session = self.driver.session(None) + session = self.driver.session() assert isinstance(session, FalkorDriverSession) - self.mock_client.select_graph.assert_called_once_with('default_db') @pytest.mark.asyncio @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed') @@ -212,11 +207,9 @@ class TestFalkorDriver: async def test_delete_all_indexes(self): """Test delete_all_indexes method.""" with patch.object(self.driver, 'execute_query', new_callable=AsyncMock) as mock_execute: - await self.driver.delete_all_indexes('test_db') + await self.driver.delete_all_indexes() - mock_execute.assert_called_once_with( - 'CALL db.indexes() YIELD name DROP INDEX name', database_='test_db' - ) + mock_execute.assert_called_once_with('CALL db.indexes() YIELD name DROP INDEX name') class TestFalkorDriverSession: