fix-groupid-usage
This commit is contained in:
parent
bbf9cc6172
commit
459e708131
3 changed files with 70 additions and 56 deletions
|
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
import inspect
|
||||||
|
|
||||||
from typing import Any, Awaitable, Callable, TypeVar
|
from typing import Any, Awaitable, Callable, TypeVar
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphProvider
|
from graphiti_core.driver.driver import GraphProvider
|
||||||
|
|
@ -31,7 +33,13 @@ def handle_multiple_group_ids(func: F) -> F:
|
||||||
"""
|
"""
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(self, *args, **kwargs):
|
async def wrapper(self, *args, **kwargs):
|
||||||
|
group_ids_func_pos = get_parameter_position(func, 'group_ids')
|
||||||
|
group_ids_pos = group_ids_func_pos - 1 if group_ids_func_pos is not None else None # Adjust for zero-based index
|
||||||
group_ids = kwargs.get('group_ids')
|
group_ids = kwargs.get('group_ids')
|
||||||
|
|
||||||
|
# If not in kwargs and position exists, get from args
|
||||||
|
if group_ids is None and group_ids_pos is not None and len(args) > group_ids_pos:
|
||||||
|
group_ids = args[group_ids_pos]
|
||||||
|
|
||||||
# Only handle FalkorDB with multiple group_ids
|
# Only handle FalkorDB with multiple group_ids
|
||||||
if (hasattr(self, 'clients') and hasattr(self.clients, 'driver') and
|
if (hasattr(self, 'clients') and hasattr(self.clients, 'driver') and
|
||||||
|
|
@ -42,9 +50,14 @@ def handle_multiple_group_ids(func: F) -> F:
|
||||||
driver = self.clients.driver
|
driver = self.clients.driver
|
||||||
|
|
||||||
async def execute_for_group(gid: str):
|
async def execute_for_group(gid: str):
|
||||||
|
# Remove group_ids from args if it was passed positionally
|
||||||
|
filtered_args = list(args)
|
||||||
|
if group_ids_pos is not None and len(args) > group_ids_pos:
|
||||||
|
filtered_args.pop(group_ids_pos)
|
||||||
|
|
||||||
return await func(
|
return await func(
|
||||||
self,
|
self,
|
||||||
*args,
|
*filtered_args,
|
||||||
**{**kwargs, "group_ids": [gid], "driver": driver.clone(database=gid)},
|
**{**kwargs, "group_ids": [gid], "driver": driver.clone(database=gid)},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -75,3 +88,15 @@ def handle_multiple_group_ids(func: F) -> F:
|
||||||
return await func(self, *args, **kwargs)
|
return await func(self, *args, **kwargs)
|
||||||
|
|
||||||
return wrapper # type: ignore
|
return wrapper # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def get_parameter_position(func: Callable, param_name: str) -> int | None:
|
||||||
|
"""
|
||||||
|
Returns the positional index of a parameter in the function signature.
|
||||||
|
If the parameter is not found, returns None.
|
||||||
|
"""
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
for idx, (name, param) in enumerate(sig.parameters.items()):
|
||||||
|
if name == param_name:
|
||||||
|
return idx
|
||||||
|
return None
|
||||||
|
|
@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
@ -86,7 +86,7 @@ class FalkorDriver(GraphDriver):
|
||||||
username: str | None = None,
|
username: str | None = None,
|
||||||
password: str | None = None,
|
password: str | None = None,
|
||||||
falkor_db: FalkorDB | None = None,
|
falkor_db: FalkorDB | None = None,
|
||||||
database: str = '\\_',
|
database: str = 'default_db',
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the FalkorDB driver.
|
Initialize the FalkorDB driver.
|
||||||
|
|
@ -94,9 +94,16 @@ class FalkorDriver(GraphDriver):
|
||||||
FalkorDB is a multi-tenant graph database.
|
FalkorDB is a multi-tenant graph database.
|
||||||
To connect, provide the host and port.
|
To connect, provide the host and port.
|
||||||
The default parameters assume a local (on-premises) FalkorDB instance.
|
The default parameters assume a local (on-premises) FalkorDB instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
host (str): The host where FalkorDB is running.
|
||||||
|
port (int): The port on which FalkorDB is listening.
|
||||||
|
username (str | None): The username for authentication (if required).
|
||||||
|
password (str | None): The password for authentication (if required).
|
||||||
|
falkor_db (FalkorDB | None): An existing FalkorDB instance to use instead of creating a new one.
|
||||||
|
database (str): The name of the database to connect to. Defaults to 'default_db'.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self._database = database
|
self._database = database
|
||||||
if falkor_db is not None:
|
if falkor_db is not None:
|
||||||
# If a FalkorDB instance is provided, use it directly
|
# If a FalkorDB instance is provided, use it directly
|
||||||
|
|
@ -105,7 +112,6 @@ class FalkorDriver(GraphDriver):
|
||||||
self.client = FalkorDB(host=host, port=port, username=username, password=password)
|
self.client = FalkorDB(host=host, port=port, username=username, password=password)
|
||||||
|
|
||||||
# Schedule the indices and constraints to be built
|
# Schedule the indices and constraints to be built
|
||||||
import asyncio
|
|
||||||
try:
|
try:
|
||||||
# Try to get the current event loop
|
# Try to get the current event loop
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
@ -167,65 +173,45 @@ class FalkorDriver(GraphDriver):
|
||||||
await self.client.connection.close()
|
await self.client.connection.close()
|
||||||
|
|
||||||
async def delete_all_indexes(self) -> None:
|
async def delete_all_indexes(self) -> None:
|
||||||
from collections import defaultdict
|
result = await self.execute_query("CALL db.indexes()")
|
||||||
|
if not result:
|
||||||
result = await self.execute_query('CALL db.indexes()')
|
|
||||||
if result is None:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
records, _, _ = result
|
records, _, _ = result
|
||||||
|
drop_tasks = []
|
||||||
# Organize indexes by type and label
|
|
||||||
range_indexes = defaultdict(list)
|
|
||||||
fulltext_indexes = defaultdict(list)
|
|
||||||
entity_types = {}
|
|
||||||
|
|
||||||
for record in records:
|
for record in records:
|
||||||
label = record['label']
|
label = record["label"]
|
||||||
entity_types[label] = record['entitytype']
|
entity_type = record["entitytype"]
|
||||||
|
|
||||||
for field_name, index_type in record['types'].items():
|
|
||||||
if 'RANGE' in index_type:
|
|
||||||
range_indexes[label].append(field_name)
|
|
||||||
if 'FULLTEXT' in index_type:
|
|
||||||
fulltext_indexes[label].append(field_name)
|
|
||||||
|
|
||||||
# Drop all range indexes
|
|
||||||
for label, fields in range_indexes.items():
|
|
||||||
for field in fields:
|
|
||||||
await self.execute_query(f'DROP INDEX ON :{label}({field})')
|
|
||||||
|
|
||||||
# Drop all fulltext indexes
|
for field_name, index_type in record["types"].items():
|
||||||
for label, fields in fulltext_indexes.items():
|
if "RANGE" in index_type:
|
||||||
entity_type = entity_types[label]
|
drop_tasks.append(
|
||||||
for field in fields:
|
self.execute_query(f"DROP INDEX ON :{label}({field_name})")
|
||||||
if entity_type == 'NODE':
|
|
||||||
await self.execute_query(
|
|
||||||
f'DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field})'
|
|
||||||
)
|
|
||||||
elif entity_type == 'RELATIONSHIP':
|
|
||||||
await self.execute_query(
|
|
||||||
f'DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field})'
|
|
||||||
)
|
)
|
||||||
|
elif "FULLTEXT" in index_type:
|
||||||
|
if entity_type == "NODE":
|
||||||
|
drop_tasks.append(
|
||||||
|
self.execute_query(
|
||||||
|
f"DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field_name})"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif entity_type == "RELATIONSHIP":
|
||||||
|
drop_tasks.append(
|
||||||
|
self.execute_query(
|
||||||
|
f"DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field_name})"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
async def build_indices_and_constraints(self, delete_existing: bool = False):
|
if drop_tasks:
|
||||||
|
await asyncio.gather(*drop_tasks)
|
||||||
|
|
||||||
|
async def build_indices_and_constraints(self, delete_existing=False):
|
||||||
if delete_existing:
|
if delete_existing:
|
||||||
await self.delete_all_indexes()
|
await self.delete_all_indexes()
|
||||||
|
index_queries = get_range_indices(self.provider) + get_fulltext_indices(self.provider)
|
||||||
range_indices: list[LiteralString] = get_range_indices(self.provider)
|
for query in index_queries:
|
||||||
|
await self.execute_query(query)
|
||||||
fulltext_indices: list[LiteralString] = get_fulltext_indices(self.provider)
|
|
||||||
|
|
||||||
index_queries: list[LiteralString] = range_indices + fulltext_indices
|
|
||||||
|
|
||||||
await semaphore_gather(
|
|
||||||
*[
|
|
||||||
self.execute_query(
|
|
||||||
query,
|
|
||||||
)
|
|
||||||
for query in index_queries
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def clone(self, database: str) -> 'GraphDriver':
|
def clone(self, database: str) -> 'GraphDriver':
|
||||||
"""
|
"""
|
||||||
|
|
@ -234,6 +220,8 @@ class FalkorDriver(GraphDriver):
|
||||||
"""
|
"""
|
||||||
if database == self._database:
|
if database == self._database:
|
||||||
cloned = self
|
cloned = self
|
||||||
|
elif database == self.default_group_id:
|
||||||
|
cloned = FalkorDriver(falkor_db=self.client)
|
||||||
else:
|
else:
|
||||||
# Create a new instance of FalkorDriver with the same connection but a different database
|
# Create a new instance of FalkorDriver with the same connection but a different database
|
||||||
cloned = FalkorDriver(falkor_db=self.client, database=database)
|
cloned = FalkorDriver(falkor_db=self.client, database=database)
|
||||||
|
|
|
||||||
|
|
@ -447,6 +447,7 @@ class Graphiti:
|
||||||
|
|
||||||
if group_id is None:
|
if group_id is None:
|
||||||
# if group_id is None, use the default group id by the provider
|
# if group_id is None, use the default group id by the provider
|
||||||
|
# and the preset database name will be used
|
||||||
group_id = self.driver.default_group_id
|
group_id = self.driver.default_group_id
|
||||||
else:
|
else:
|
||||||
validate_group_id(group_id)
|
validate_group_id(group_id)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue