This commit is contained in:
Naseem Ali 2025-10-18 22:00:20 +03:00
parent 780fa15daa
commit 2796126cce
9 changed files with 39 additions and 35 deletions

View file

@ -78,7 +78,6 @@ async def main():
graphiti = Graphiti(graph_driver=falkor_driver)
try:
#################################################
# ADDING EPISODES
#################################################

View file

@ -67,7 +67,6 @@ async def main():
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
try:
#################################################
# ADDING EPISODES
#################################################

View file

@ -31,41 +31,47 @@ def handle_multiple_group_ids(func: F) -> F:
Decorator for FalkorDB methods that need to handle multiple group_ids.
Runs the function for each group_id separately and merges results.
"""
@functools.wraps(func)
async def wrapper(self, *args, **kwargs):
group_ids_func_pos = get_parameter_position(func, 'group_ids')
group_ids_pos = group_ids_func_pos - 1 if group_ids_func_pos is not None else None # Adjust for zero-based index
group_ids_pos = (
group_ids_func_pos - 1 if group_ids_func_pos is not None else None
) # Adjust for zero-based index
group_ids = kwargs.get('group_ids')
# If not in kwargs and position exists, get from args
if group_ids is None and group_ids_pos is not None and len(args) > group_ids_pos:
group_ids = args[group_ids_pos]
# Only handle FalkorDB with multiple group_ids
if (hasattr(self, 'clients') and hasattr(self.clients, 'driver') and
self.clients.driver.provider == GraphProvider.FALKORDB and
group_ids and len(group_ids) > 1):
if (
hasattr(self, 'clients')
and hasattr(self.clients, 'driver')
and self.clients.driver.provider == GraphProvider.FALKORDB
and group_ids
and len(group_ids) > 1
):
# Execute for each group_id concurrently
driver = self.clients.driver
async def execute_for_group(gid: str):
# Remove group_ids from args if it was passed positionally
filtered_args = list(args)
if group_ids_pos is not None and len(args) > group_ids_pos:
filtered_args.pop(group_ids_pos)
return await func(
self,
*filtered_args,
**{**kwargs, "group_ids": [gid], "driver": driver.clone(database=gid)},
**{**kwargs, 'group_ids': [gid], 'driver': driver.clone(database=gid)},
)
results = await semaphore_gather(
*[execute_for_group(gid) for gid in group_ids],
max_coroutines=getattr(self, 'max_coroutines', None)
max_coroutines=getattr(self, 'max_coroutines', None),
)
# Merge results based on type
if isinstance(results[0], SearchResults):
return SearchResults.merge(results)
@ -77,16 +83,18 @@ def handle_multiple_group_ids(func: F) -> F:
for i in range(len(results[0])):
component_results = [result[i] for result in results]
if isinstance(component_results[0], list):
merged_tuple.append([item for component in component_results for item in component])
merged_tuple.append(
[item for component in component_results for item in component]
)
else:
merged_tuple.append(component_results)
return tuple(merged_tuple)
else:
return results
# Normal execution
return await func(self, *args, **kwargs)
return wrapper # type: ignore
@ -99,4 +107,4 @@ def get_parameter_position(func: Callable, param_name: str) -> int | None:
for idx, (name, _param) in enumerate(sig.parameters.items()):
if name == param_name:
return idx
return None
return None

View file

@ -260,7 +260,7 @@ class FalkorDriver(GraphDriver):
else:
# Create a new instance of FalkorDriver with the same connection but a different database
cloned = FalkorDriver(falkor_db=self.client, database=database)
return cloned
def sanitize(self, query: str) -> str:

View file

@ -45,9 +45,10 @@ class Neo4jDriver(GraphDriver):
auth=(user or '', password or ''),
)
self._database = database
# Schedule the indices and constraints to be built
import asyncio
try:
# Try to get the current event loop
loop = asyncio.get_running_loop()

View file

@ -609,7 +609,7 @@ class Graphiti:
"""
if driver is None:
driver = self.clients.driver
return await retrieve_episodes(driver, reference_time, last_n, group_ids, source)
async def add_episode(
@ -688,7 +688,7 @@ class Graphiti:
validate_entity_types(entity_types)
validate_excluded_entity_types(excluded_entity_types, entity_types)
if group_id is None:
# if group_id is None, use the default group id by the provider
# and the preset database name will be used
@ -1012,8 +1012,7 @@ class Graphiti:
@handle_multiple_group_ids
async def build_communities(
self, group_ids: list[str] | None = None,
driver: GraphDriver | None = None
self, group_ids: list[str] | None = None, driver: GraphDriver | None = None
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
"""
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
@ -1056,7 +1055,7 @@ class Graphiti:
group_ids: list[str] | None = None,
num_results=DEFAULT_SEARCH_LIMIT,
search_filter: SearchFilters | None = None,
driver: GraphDriver | None = None
driver: GraphDriver | None = None,
) -> list[EntityEdge]:
"""
Perform a hybrid search on the knowledge graph.
@ -1104,7 +1103,7 @@ class Graphiti:
search_config,
search_filter if search_filter is not None else SearchFilters(),
driver=driver,
center_node_uuid=center_node_uuid
center_node_uuid=center_node_uuid,
)
).edges
@ -1133,7 +1132,7 @@ class Graphiti:
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
search_filter: SearchFilters | None = None,
driver: GraphDriver | None = None
driver: GraphDriver | None = None,
) -> SearchResults:
"""search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
@ -1150,7 +1149,7 @@ class Graphiti:
search_filter if search_filter is not None else SearchFilters(),
center_node_uuid,
bfs_origin_node_uuids,
driver=driver
driver=driver,
)
async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults:

View file

@ -328,4 +328,4 @@ COMMUNITY_NODE_RETURN_NEPTUNE = """
n.group_id AS group_id,
n.summary AS summary,
n.created_at AS created_at
"""
"""

View file

@ -132,12 +132,12 @@ class SearchResults(BaseModel):
def merge(cls, results_list: list['SearchResults']) -> 'SearchResults':
"""
Merge multiple SearchResults objects into a single SearchResults object.
Parameters
----------
results_list : list[SearchResults]
List of SearchResults objects to merge
Returns
-------
SearchResults
@ -145,7 +145,7 @@ class SearchResults(BaseModel):
"""
if not results_list:
return cls()
merged = cls()
for result in results_list:
merged.edges.extend(result.edges)
@ -156,5 +156,5 @@ class SearchResults(BaseModel):
merged.episode_reranker_scores.extend(result.episode_reranker_scores)
merged.communities.extend(result.communities)
merged.community_reranker_scores.extend(result.community_reranker_scores)
return merged

View file

@ -20,8 +20,6 @@ from datetime import datetime
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.helpers import semaphore_gather
from graphiti_core.models.nodes.node_db_queries import (
EPISODIC_NODE_RETURN,
EPISODIC_NODE_RETURN_NEPTUNE,