From 459e70813179bf5bfbdbf494a0d2c4b4e33844ae Mon Sep 17 00:00:00 2001 From: Gal Shubeli Date: Thu, 14 Aug 2025 15:34:20 +0300 Subject: [PATCH] fix-groupid-usage --- graphiti_core/decorators.py | 27 ++++++- graphiti_core/driver/falkordb_driver.py | 98 +++++++++++-------------- graphiti_core/graphiti.py | 1 + 3 files changed, 70 insertions(+), 56 deletions(-) diff --git a/graphiti_core/decorators.py b/graphiti_core/decorators.py index f39f52ce..9cbdeefb 100644 --- a/graphiti_core/decorators.py +++ b/graphiti_core/decorators.py @@ -15,6 +15,8 @@ limitations under the License. """ import functools +import inspect + from typing import Any, Awaitable, Callable, TypeVar from graphiti_core.driver.driver import GraphProvider @@ -31,7 +33,13 @@ def handle_multiple_group_ids(func: F) -> F: """ @functools.wraps(func) async def wrapper(self, *args, **kwargs): + group_ids_func_pos = get_parameter_position(func, 'group_ids') + group_ids_pos = group_ids_func_pos - 1 if group_ids_func_pos is not None else None # Adjust for zero-based index group_ids = kwargs.get('group_ids') + + # If not in kwargs and position exists, get from args + if group_ids is None and group_ids_pos is not None and len(args) > group_ids_pos: + group_ids = args[group_ids_pos] # Only handle FalkorDB with multiple group_ids if (hasattr(self, 'clients') and hasattr(self.clients, 'driver') and @@ -42,9 +50,14 @@ def handle_multiple_group_ids(func: F) -> F: driver = self.clients.driver async def execute_for_group(gid: str): + # Remove group_ids from args if it was passed positionally + filtered_args = list(args) + if group_ids_pos is not None and len(args) > group_ids_pos: + filtered_args.pop(group_ids_pos) + return await func( self, - *args, + *filtered_args, **{**kwargs, "group_ids": [gid], "driver": driver.clone(database=gid)}, ) @@ -75,3 +88,15 @@ def handle_multiple_group_ids(func: F) -> F: return await func(self, *args, **kwargs) return wrapper # type: ignore + + +def get_parameter_position(func: Callable, param_name: str) -> int | None: + """ + Returns the positional index of a parameter in the function signature. + If the parameter is not found, returns None. + """ + sig = inspect.signature(func) + for idx, (name, param) in enumerate(sig.parameters.items()): + if name == param_name: + return idx + return None \ No newline at end of file diff --git a/graphiti_core/driver/falkordb_driver.py b/graphiti_core/driver/falkordb_driver.py index 6f684700..8dd51c76 100644 --- a/graphiti_core/driver/falkordb_driver.py +++ b/graphiti_core/driver/falkordb_driver.py @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ - +import asyncio import logging from datetime import datetime from typing import TYPE_CHECKING, Any @@ -86,7 +86,7 @@ class FalkorDriver(GraphDriver): username: str | None = None, password: str | None = None, falkor_db: FalkorDB | None = None, - database: str = '\\_', + database: str = 'default_db', ): """ Initialize the FalkorDB driver. @@ -94,9 +94,16 @@ class FalkorDriver(GraphDriver): FalkorDB is a multi-tenant graph database. To connect, provide the host and port. The default parameters assume a local (on-premises) FalkorDB instance. + + Args: + host (str): The host where FalkorDB is running. + port (int): The port on which FalkorDB is listening. + username (str | None): The username for authentication (if required). + password (str | None): The password for authentication (if required). + falkor_db (FalkorDB | None): An existing FalkorDB instance to use instead of creating a new one. + database (str): The name of the database to connect to. Defaults to 'default_db'. """ super().__init__() - self._database = database if falkor_db is not None: # If a FalkorDB instance is provided, use it directly @@ -105,7 +112,6 @@ class FalkorDriver(GraphDriver): self.client = FalkorDB(host=host, port=port, username=username, password=password) # Schedule the indices and constraints to be built - import asyncio try: # Try to get the current event loop loop = asyncio.get_running_loop() @@ -167,65 +173,45 @@ class FalkorDriver(GraphDriver): await self.client.connection.close() async def delete_all_indexes(self) -> None: - from collections import defaultdict - - result = await self.execute_query('CALL db.indexes()') - if result is None: + result = await self.execute_query("CALL db.indexes()") + if not result: return - + records, _, _ = result - - # Organize indexes by type and label - range_indexes = defaultdict(list) - fulltext_indexes = defaultdict(list) - entity_types = {} - + drop_tasks = [] + for record in records: - label = record['label'] - entity_types[label] = record['entitytype'] - - for field_name, index_type in record['types'].items(): - if 'RANGE' in index_type: - range_indexes[label].append(field_name) - if 'FULLTEXT' in index_type: - fulltext_indexes[label].append(field_name) - - # Drop all range indexes - for label, fields in range_indexes.items(): - for field in fields: - await self.execute_query(f'DROP INDEX ON :{label}({field})') + label = record["label"] + entity_type = record["entitytype"] - # Drop all fulltext indexes - for label, fields in fulltext_indexes.items(): - entity_type = entity_types[label] - for field in fields: - if entity_type == 'NODE': - await self.execute_query( - f'DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field})' - ) - elif entity_type == 'RELATIONSHIP': - await self.execute_query( - f'DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field})' + for field_name, index_type in record["types"].items(): + if "RANGE" in index_type: + drop_tasks.append( + self.execute_query(f"DROP INDEX ON :{label}({field_name})") ) + elif "FULLTEXT" in index_type: + if entity_type == "NODE": + drop_tasks.append( + self.execute_query( + f"DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field_name})" + ) + ) + elif entity_type == "RELATIONSHIP": + drop_tasks.append( + self.execute_query( + f"DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field_name})" + ) + ) - async def build_indices_and_constraints(self, delete_existing: bool = False): + if drop_tasks: + await asyncio.gather(*drop_tasks) + + async def build_indices_and_constraints(self, delete_existing=False): if delete_existing: await self.delete_all_indexes() - - range_indices: list[LiteralString] = get_range_indices(self.provider) - - fulltext_indices: list[LiteralString] = get_fulltext_indices(self.provider) - - index_queries: list[LiteralString] = range_indices + fulltext_indices - - await semaphore_gather( - *[ - self.execute_query( - query, - ) - for query in index_queries - ] - ) + index_queries = get_range_indices(self.provider) + get_fulltext_indices(self.provider) + for query in index_queries: + await self.execute_query(query) def clone(self, database: str) -> 'GraphDriver': """ @@ -234,6 +220,8 @@ class FalkorDriver(GraphDriver): """ if database == self._database: cloned = self + elif database == self.default_group_id: + cloned = FalkorDriver(falkor_db=self.client) else: # Create a new instance of FalkorDriver with the same connection but a different database cloned = FalkorDriver(falkor_db=self.client, database=database) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 8d0e9305..3e875073 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -447,6 +447,7 @@ class Graphiti: if group_id is None: # if group_id is None, use the default group id by the provider + # and the preset database name will be used group_id = self.driver.default_group_id else: validate_group_id(group_id)