fix-groupid-usage

This commit is contained in:
Gal Shubeli 2025-08-14 15:34:20 +03:00
parent bbf9cc6172
commit 459e708131
3 changed files with 70 additions and 56 deletions

View file

@ -15,6 +15,8 @@ limitations under the License.
"""
import functools
import inspect
from typing import Any, Awaitable, Callable, TypeVar
from graphiti_core.driver.driver import GraphProvider
@ -31,7 +33,13 @@ def handle_multiple_group_ids(func: F) -> F:
"""
@functools.wraps(func)
async def wrapper(self, *args, **kwargs):
group_ids_func_pos = get_parameter_position(func, 'group_ids')
group_ids_pos = group_ids_func_pos - 1 if group_ids_func_pos is not None else None # Adjust for zero-based index
group_ids = kwargs.get('group_ids')
# If not in kwargs and position exists, get from args
if group_ids is None and group_ids_pos is not None and len(args) > group_ids_pos:
group_ids = args[group_ids_pos]
# Only handle FalkorDB with multiple group_ids
if (hasattr(self, 'clients') and hasattr(self.clients, 'driver') and
@ -42,9 +50,14 @@ def handle_multiple_group_ids(func: F) -> F:
driver = self.clients.driver
async def execute_for_group(gid: str):
# Remove group_ids from args if it was passed positionally
filtered_args = list(args)
if group_ids_pos is not None and len(args) > group_ids_pos:
filtered_args.pop(group_ids_pos)
return await func(
self,
*args,
*filtered_args,
**{**kwargs, "group_ids": [gid], "driver": driver.clone(database=gid)},
)
@ -75,3 +88,15 @@ def handle_multiple_group_ids(func: F) -> F:
return await func(self, *args, **kwargs)
return wrapper # type: ignore
def get_parameter_position(func: Callable, param_name: str) -> int | None:
"""
Returns the positional index of a parameter in the function signature.
If the parameter is not found, returns None.
"""
sig = inspect.signature(func)
for idx, (name, param) in enumerate(sig.parameters.items()):
if name == param_name:
return idx
return None

View file

@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import logging
from datetime import datetime
from typing import TYPE_CHECKING, Any
@ -86,7 +86,7 @@ class FalkorDriver(GraphDriver):
username: str | None = None,
password: str | None = None,
falkor_db: FalkorDB | None = None,
database: str = '\\_',
database: str = 'default_db',
):
"""
Initialize the FalkorDB driver.
@ -94,9 +94,16 @@ class FalkorDriver(GraphDriver):
FalkorDB is a multi-tenant graph database.
To connect, provide the host and port.
The default parameters assume a local (on-premises) FalkorDB instance.
Args:
host (str): The host where FalkorDB is running.
port (int): The port on which FalkorDB is listening.
username (str | None): The username for authentication (if required).
password (str | None): The password for authentication (if required).
falkor_db (FalkorDB | None): An existing FalkorDB instance to use instead of creating a new one.
database (str): The name of the database to connect to. Defaults to 'default_db'.
"""
super().__init__()
self._database = database
if falkor_db is not None:
# If a FalkorDB instance is provided, use it directly
@ -105,7 +112,6 @@ class FalkorDriver(GraphDriver):
self.client = FalkorDB(host=host, port=port, username=username, password=password)
# Schedule the indices and constraints to be built
import asyncio
try:
# Try to get the current event loop
loop = asyncio.get_running_loop()
@ -167,65 +173,45 @@ class FalkorDriver(GraphDriver):
await self.client.connection.close()
async def delete_all_indexes(self) -> None:
from collections import defaultdict
result = await self.execute_query('CALL db.indexes()')
if result is None:
result = await self.execute_query("CALL db.indexes()")
if not result:
return
records, _, _ = result
# Organize indexes by type and label
range_indexes = defaultdict(list)
fulltext_indexes = defaultdict(list)
entity_types = {}
drop_tasks = []
for record in records:
label = record['label']
entity_types[label] = record['entitytype']
for field_name, index_type in record['types'].items():
if 'RANGE' in index_type:
range_indexes[label].append(field_name)
if 'FULLTEXT' in index_type:
fulltext_indexes[label].append(field_name)
# Drop all range indexes
for label, fields in range_indexes.items():
for field in fields:
await self.execute_query(f'DROP INDEX ON :{label}({field})')
label = record["label"]
entity_type = record["entitytype"]
# Drop all fulltext indexes
for label, fields in fulltext_indexes.items():
entity_type = entity_types[label]
for field in fields:
if entity_type == 'NODE':
await self.execute_query(
f'DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field})'
)
elif entity_type == 'RELATIONSHIP':
await self.execute_query(
f'DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field})'
for field_name, index_type in record["types"].items():
if "RANGE" in index_type:
drop_tasks.append(
self.execute_query(f"DROP INDEX ON :{label}({field_name})")
)
elif "FULLTEXT" in index_type:
if entity_type == "NODE":
drop_tasks.append(
self.execute_query(
f"DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field_name})"
)
)
elif entity_type == "RELATIONSHIP":
drop_tasks.append(
self.execute_query(
f"DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field_name})"
)
)
async def build_indices_and_constraints(self, delete_existing: bool = False):
if drop_tasks:
await asyncio.gather(*drop_tasks)
async def build_indices_and_constraints(self, delete_existing=False):
if delete_existing:
await self.delete_all_indexes()
range_indices: list[LiteralString] = get_range_indices(self.provider)
fulltext_indices: list[LiteralString] = get_fulltext_indices(self.provider)
index_queries: list[LiteralString] = range_indices + fulltext_indices
await semaphore_gather(
*[
self.execute_query(
query,
)
for query in index_queries
]
)
index_queries = get_range_indices(self.provider) + get_fulltext_indices(self.provider)
for query in index_queries:
await self.execute_query(query)
def clone(self, database: str) -> 'GraphDriver':
"""
@ -234,6 +220,8 @@ class FalkorDriver(GraphDriver):
"""
if database == self._database:
cloned = self
elif database == self.default_group_id:
cloned = FalkorDriver(falkor_db=self.client)
else:
# Create a new instance of FalkorDriver with the same connection but a different database
cloned = FalkorDriver(falkor_db=self.client, database=database)

View file

@ -447,6 +447,7 @@ class Graphiti:
if group_id is None:
# if group_id is None, use the default group id by the provider
# and the preset database name will be used
group_id = self.driver.default_group_id
else:
validate_group_id(group_id)