"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
|
|
|
|
import functools
import inspect
from typing import Any, Awaitable, Callable, Optional, TypeVar

from graphiti_core.driver.driver import GraphProvider
from graphiti_core.helpers import semaphore_gather
from graphiti_core.search.search_config import SearchResults
|
|
|
|
# Type variable for the decorated callable (any callable returning an awaitable).
# Binding F lets `handle_multiple_group_ids` be annotated as returning the same
# callable type it receives, so decorated methods keep their signatures for type checkers.
F = TypeVar('F', bound=Callable[..., Awaitable[Any]])
|
|
|
|
|
|
def _positional_index(func: Callable[..., Any], name: str) -> Optional[int]:
    """Return the index of parameter ``name`` within a method's ``*args``.

    The index is relative to the arguments that follow ``self`` (the wrapper
    produced by ``handle_multiple_group_ids`` receives ``self`` separately).
    Returns ``None`` when ``func`` has no positional parameter called ``name``
    or its signature cannot be introspected; ``name`` can then only arrive as
    a keyword argument.
    """
    try:
        parameters = inspect.signature(func).parameters.values()
    except (TypeError, ValueError):
        # Some callables have no retrievable signature; fall back to keyword-only handling.
        return None
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    positional_names = [p.name for p in parameters if p.kind in positional_kinds]
    if name not in positional_names:
        return None
    index = positional_names.index(name)
    # Slot 0 is ``self``; shift by one so the result indexes into ``*args``.
    return index - 1 if index > 0 else None


def _with_argument(
    args: tuple, kwargs: dict, name: str, index: Optional[int], value: Any
) -> tuple:
    """Return an ``(args, kwargs)`` pair with argument ``name`` set to ``value``.

    If the caller supplied ``name`` positionally (``index`` falls inside
    ``args``), that slot is replaced so the value is not passed twice;
    otherwise ``value`` is placed in a copy of ``kwargs``. The inputs are not
    mutated.
    """
    if index is not None and index < len(args):
        new_args = list(args)
        new_args[index] = value
        return tuple(new_args), kwargs
    return args, {**kwargs, name: value}


def _merge_group_results(results: list) -> Any:
    """Combine the per-group results of a fanned-out call into one value.

    ``results`` must be non-empty; the merge strategy is chosen from the type
    of the first element (all elements come from the same function):

    - ``SearchResults``: combined with ``SearchResults.merge``.
    - ``list``: concatenated in group order.
    - ``tuple``: merged position by position; list components are
      concatenated, any other component becomes a list holding that
      component from every group.
    - anything else: the list of per-group results is returned as-is.
    """
    first = results[0]
    if isinstance(first, SearchResults):
        return SearchResults.merge(results)
    if isinstance(first, list):
        return [item for result in results for item in result]
    if isinstance(first, tuple):
        # Handle tuple outputs (like build_communities returning (nodes, edges)).
        # All tuples come from the same function, so they share one arity.
        merged = []
        for components in zip(*results):
            if isinstance(components[0], list):
                merged.append([item for component in components for item in component])
            else:
                merged.append(list(components))
        return tuple(merged)
    return results


def handle_multiple_group_ids(func: F) -> F:
    """
    Decorator for FalkorDB methods that need to handle multiple group_ids.

    When ``self.clients.driver.provider`` is ``GraphProvider.FALKORDB`` and more
    than one group_id is supplied, the decorated method is run once per
    group_id. Each call gets ``group_ids=[gid]`` and a driver obtained from
    ``self.clients.driver.clone(database=gid)``. The calls are executed through
    ``semaphore_gather`` (``max_coroutines`` is taken from ``self.max_coroutines``
    when that attribute exists) and their results are merged by
    ``_merge_group_results``.

    In every other case (another provider, no ``clients.driver`` on ``self``,
    or zero/one group_id) the method is called unchanged.

    ``group_ids`` and ``driver`` may be passed either as keywords or
    positionally; a positional value is replaced in place rather than
    duplicated as a keyword.
    """
    # Resolved once at decoration time rather than on every call.
    group_ids_index = _positional_index(func, 'group_ids')
    driver_index = _positional_index(func, 'driver')

    @functools.wraps(func)
    async def wrapper(self, *args, **kwargs):
        group_ids = kwargs.get('group_ids')
        if group_ids is None and group_ids_index is not None and group_ids_index < len(args):
            # group_ids was supplied positionally.
            group_ids = args[group_ids_index]

        clients = getattr(self, 'clients', None)
        driver = getattr(clients, 'driver', None)
        fan_out = bool(
            driver is not None
            and driver.provider == GraphProvider.FALKORDB
            and group_ids
            and len(group_ids) > 1
        )
        if not fan_out:
            # Normal execution
            return await func(self, *args, **kwargs)

        async def execute_for_group(gid: str):
            call_args, call_kwargs = _with_argument(
                args, kwargs, 'group_ids', group_ids_index, [gid]
            )
            call_args, call_kwargs = _with_argument(
                call_args, call_kwargs, 'driver', driver_index, driver.clone(database=gid)
            )
            return await func(self, *call_args, **call_kwargs)

        # Execute for each group_id concurrently, then merge.
        results = await semaphore_gather(
            *[execute_for_group(gid) for gid in group_ids],
            max_coroutines=getattr(self, 'max_coroutines', None),
        )
        return _merge_group_results(results)

    return wrapper  # type: ignore
|