fix-groupid-usage
This commit is contained in:
parent
bbf9cc6172
commit
459e708131
3 changed files with 70 additions and 56 deletions
|
|
@ -15,6 +15,8 @@ limitations under the License.
|
|||
"""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
|
||||
from typing import Any, Awaitable, Callable, TypeVar
|
||||
|
||||
from graphiti_core.driver.driver import GraphProvider
|
||||
|
|
@ -31,7 +33,13 @@ def handle_multiple_group_ids(func: F) -> F:
|
|||
"""
|
||||
@functools.wraps(func)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
group_ids_func_pos = get_parameter_position(func, 'group_ids')
|
||||
group_ids_pos = group_ids_func_pos - 1 if group_ids_func_pos is not None else None # Adjust for zero-based index
|
||||
group_ids = kwargs.get('group_ids')
|
||||
|
||||
# If not in kwargs and position exists, get from args
|
||||
if group_ids is None and group_ids_pos is not None and len(args) > group_ids_pos:
|
||||
group_ids = args[group_ids_pos]
|
||||
|
||||
# Only handle FalkorDB with multiple group_ids
|
||||
if (hasattr(self, 'clients') and hasattr(self.clients, 'driver') and
|
||||
|
|
@ -42,9 +50,14 @@ def handle_multiple_group_ids(func: F) -> F:
|
|||
driver = self.clients.driver
|
||||
|
||||
async def execute_for_group(gid: str):
|
||||
# Remove group_ids from args if it was passed positionally
|
||||
filtered_args = list(args)
|
||||
if group_ids_pos is not None and len(args) > group_ids_pos:
|
||||
filtered_args.pop(group_ids_pos)
|
||||
|
||||
return await func(
|
||||
self,
|
||||
*args,
|
||||
*filtered_args,
|
||||
**{**kwargs, "group_ids": [gid], "driver": driver.clone(database=gid)},
|
||||
)
|
||||
|
||||
|
|
@ -75,3 +88,15 @@ def handle_multiple_group_ids(func: F) -> F:
|
|||
return await func(self, *args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore
|
||||
|
||||
|
||||
def get_parameter_position(func: Callable, param_name: str) -> int | None:
|
||||
"""
|
||||
Returns the positional index of a parameter in the function signature.
|
||||
If the parameter is not found, returns None.
|
||||
"""
|
||||
sig = inspect.signature(func)
|
||||
for idx, (name, param) in enumerate(sig.parameters.items()):
|
||||
if name == param_name:
|
||||
return idx
|
||||
return None
|
||||
|
|
@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
|
@ -86,7 +86,7 @@ class FalkorDriver(GraphDriver):
|
|||
username: str | None = None,
|
||||
password: str | None = None,
|
||||
falkor_db: FalkorDB | None = None,
|
||||
database: str = '\\_',
|
||||
database: str = 'default_db',
|
||||
):
|
||||
"""
|
||||
Initialize the FalkorDB driver.
|
||||
|
|
@ -94,9 +94,16 @@ class FalkorDriver(GraphDriver):
|
|||
FalkorDB is a multi-tenant graph database.
|
||||
To connect, provide the host and port.
|
||||
The default parameters assume a local (on-premises) FalkorDB instance.
|
||||
|
||||
Args:
|
||||
host (str): The host where FalkorDB is running.
|
||||
port (int): The port on which FalkorDB is listening.
|
||||
username (str | None): The username for authentication (if required).
|
||||
password (str | None): The password for authentication (if required).
|
||||
falkor_db (FalkorDB | None): An existing FalkorDB instance to use instead of creating a new one.
|
||||
database (str): The name of the database to connect to. Defaults to 'default_db'.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self._database = database
|
||||
if falkor_db is not None:
|
||||
# If a FalkorDB instance is provided, use it directly
|
||||
|
|
@ -105,7 +112,6 @@ class FalkorDriver(GraphDriver):
|
|||
self.client = FalkorDB(host=host, port=port, username=username, password=password)
|
||||
|
||||
# Schedule the indices and constraints to be built
|
||||
import asyncio
|
||||
try:
|
||||
# Try to get the current event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
|
|
@ -167,65 +173,45 @@ class FalkorDriver(GraphDriver):
|
|||
await self.client.connection.close()
|
||||
|
||||
async def delete_all_indexes(self) -> None:
|
||||
from collections import defaultdict
|
||||
|
||||
result = await self.execute_query('CALL db.indexes()')
|
||||
if result is None:
|
||||
result = await self.execute_query("CALL db.indexes()")
|
||||
if not result:
|
||||
return
|
||||
|
||||
|
||||
records, _, _ = result
|
||||
|
||||
# Organize indexes by type and label
|
||||
range_indexes = defaultdict(list)
|
||||
fulltext_indexes = defaultdict(list)
|
||||
entity_types = {}
|
||||
|
||||
drop_tasks = []
|
||||
|
||||
for record in records:
|
||||
label = record['label']
|
||||
entity_types[label] = record['entitytype']
|
||||
|
||||
for field_name, index_type in record['types'].items():
|
||||
if 'RANGE' in index_type:
|
||||
range_indexes[label].append(field_name)
|
||||
if 'FULLTEXT' in index_type:
|
||||
fulltext_indexes[label].append(field_name)
|
||||
|
||||
# Drop all range indexes
|
||||
for label, fields in range_indexes.items():
|
||||
for field in fields:
|
||||
await self.execute_query(f'DROP INDEX ON :{label}({field})')
|
||||
label = record["label"]
|
||||
entity_type = record["entitytype"]
|
||||
|
||||
# Drop all fulltext indexes
|
||||
for label, fields in fulltext_indexes.items():
|
||||
entity_type = entity_types[label]
|
||||
for field in fields:
|
||||
if entity_type == 'NODE':
|
||||
await self.execute_query(
|
||||
f'DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field})'
|
||||
)
|
||||
elif entity_type == 'RELATIONSHIP':
|
||||
await self.execute_query(
|
||||
f'DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field})'
|
||||
for field_name, index_type in record["types"].items():
|
||||
if "RANGE" in index_type:
|
||||
drop_tasks.append(
|
||||
self.execute_query(f"DROP INDEX ON :{label}({field_name})")
|
||||
)
|
||||
elif "FULLTEXT" in index_type:
|
||||
if entity_type == "NODE":
|
||||
drop_tasks.append(
|
||||
self.execute_query(
|
||||
f"DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field_name})"
|
||||
)
|
||||
)
|
||||
elif entity_type == "RELATIONSHIP":
|
||||
drop_tasks.append(
|
||||
self.execute_query(
|
||||
f"DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field_name})"
|
||||
)
|
||||
)
|
||||
|
||||
async def build_indices_and_constraints(self, delete_existing: bool = False):
|
||||
if drop_tasks:
|
||||
await asyncio.gather(*drop_tasks)
|
||||
|
||||
async def build_indices_and_constraints(self, delete_existing=False):
|
||||
if delete_existing:
|
||||
await self.delete_all_indexes()
|
||||
|
||||
range_indices: list[LiteralString] = get_range_indices(self.provider)
|
||||
|
||||
fulltext_indices: list[LiteralString] = get_fulltext_indices(self.provider)
|
||||
|
||||
index_queries: list[LiteralString] = range_indices + fulltext_indices
|
||||
|
||||
await semaphore_gather(
|
||||
*[
|
||||
self.execute_query(
|
||||
query,
|
||||
)
|
||||
for query in index_queries
|
||||
]
|
||||
)
|
||||
index_queries = get_range_indices(self.provider) + get_fulltext_indices(self.provider)
|
||||
for query in index_queries:
|
||||
await self.execute_query(query)
|
||||
|
||||
def clone(self, database: str) -> 'GraphDriver':
|
||||
"""
|
||||
|
|
@ -234,6 +220,8 @@ class FalkorDriver(GraphDriver):
|
|||
"""
|
||||
if database == self._database:
|
||||
cloned = self
|
||||
elif database == self.default_group_id:
|
||||
cloned = FalkorDriver(falkor_db=self.client)
|
||||
else:
|
||||
# Create a new instance of FalkorDriver with the same connection but a different database
|
||||
cloned = FalkorDriver(falkor_db=self.client, database=database)
|
||||
|
|
|
|||
|
|
@ -447,6 +447,7 @@ class Graphiti:
|
|||
|
||||
if group_id is None:
|
||||
# if group_id is None, use the default group id by the provider
|
||||
# and the preset database name will be used
|
||||
group_id = self.driver.default_group_id
|
||||
else:
|
||||
validate_group_id(group_id)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue