fix lint
This commit is contained in:
parent
780fa15daa
commit
2796126cce
9 changed files with 39 additions and 35 deletions
|
|
@ -78,7 +78,6 @@ async def main():
|
|||
graphiti = Graphiti(graph_driver=falkor_driver)
|
||||
|
||||
try:
|
||||
|
||||
#################################################
|
||||
# ADDING EPISODES
|
||||
#################################################
|
||||
|
|
|
|||
|
|
@ -67,7 +67,6 @@ async def main():
|
|||
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
||||
|
||||
try:
|
||||
|
||||
#################################################
|
||||
# ADDING EPISODES
|
||||
#################################################
|
||||
|
|
|
|||
|
|
@ -31,41 +31,47 @@ def handle_multiple_group_ids(func: F) -> F:
|
|||
Decorator for FalkorDB methods that need to handle multiple group_ids.
|
||||
Runs the function for each group_id separately and merges results.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
group_ids_func_pos = get_parameter_position(func, 'group_ids')
|
||||
group_ids_pos = group_ids_func_pos - 1 if group_ids_func_pos is not None else None # Adjust for zero-based index
|
||||
group_ids_pos = (
|
||||
group_ids_func_pos - 1 if group_ids_func_pos is not None else None
|
||||
) # Adjust for zero-based index
|
||||
group_ids = kwargs.get('group_ids')
|
||||
|
||||
# If not in kwargs and position exists, get from args
|
||||
if group_ids is None and group_ids_pos is not None and len(args) > group_ids_pos:
|
||||
group_ids = args[group_ids_pos]
|
||||
|
||||
|
||||
# Only handle FalkorDB with multiple group_ids
|
||||
if (hasattr(self, 'clients') and hasattr(self.clients, 'driver') and
|
||||
self.clients.driver.provider == GraphProvider.FALKORDB and
|
||||
group_ids and len(group_ids) > 1):
|
||||
|
||||
if (
|
||||
hasattr(self, 'clients')
|
||||
and hasattr(self.clients, 'driver')
|
||||
and self.clients.driver.provider == GraphProvider.FALKORDB
|
||||
and group_ids
|
||||
and len(group_ids) > 1
|
||||
):
|
||||
# Execute for each group_id concurrently
|
||||
driver = self.clients.driver
|
||||
|
||||
|
||||
async def execute_for_group(gid: str):
|
||||
# Remove group_ids from args if it was passed positionally
|
||||
filtered_args = list(args)
|
||||
if group_ids_pos is not None and len(args) > group_ids_pos:
|
||||
filtered_args.pop(group_ids_pos)
|
||||
|
||||
|
||||
return await func(
|
||||
self,
|
||||
*filtered_args,
|
||||
**{**kwargs, "group_ids": [gid], "driver": driver.clone(database=gid)},
|
||||
**{**kwargs, 'group_ids': [gid], 'driver': driver.clone(database=gid)},
|
||||
)
|
||||
|
||||
|
||||
results = await semaphore_gather(
|
||||
*[execute_for_group(gid) for gid in group_ids],
|
||||
max_coroutines=getattr(self, 'max_coroutines', None)
|
||||
max_coroutines=getattr(self, 'max_coroutines', None),
|
||||
)
|
||||
|
||||
|
||||
# Merge results based on type
|
||||
if isinstance(results[0], SearchResults):
|
||||
return SearchResults.merge(results)
|
||||
|
|
@ -77,16 +83,18 @@ def handle_multiple_group_ids(func: F) -> F:
|
|||
for i in range(len(results[0])):
|
||||
component_results = [result[i] for result in results]
|
||||
if isinstance(component_results[0], list):
|
||||
merged_tuple.append([item for component in component_results for item in component])
|
||||
merged_tuple.append(
|
||||
[item for component in component_results for item in component]
|
||||
)
|
||||
else:
|
||||
merged_tuple.append(component_results)
|
||||
return tuple(merged_tuple)
|
||||
else:
|
||||
return results
|
||||
|
||||
|
||||
# Normal execution
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
|
||||
return wrapper # type: ignore
|
||||
|
||||
|
||||
|
|
@ -99,4 +107,4 @@ def get_parameter_position(func: Callable, param_name: str) -> int | None:
|
|||
for idx, (name, _param) in enumerate(sig.parameters.items()):
|
||||
if name == param_name:
|
||||
return idx
|
||||
return None
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -260,7 +260,7 @@ class FalkorDriver(GraphDriver):
|
|||
else:
|
||||
# Create a new instance of FalkorDriver with the same connection but a different database
|
||||
cloned = FalkorDriver(falkor_db=self.client, database=database)
|
||||
|
||||
|
||||
return cloned
|
||||
|
||||
def sanitize(self, query: str) -> str:
|
||||
|
|
|
|||
|
|
@ -45,9 +45,10 @@ class Neo4jDriver(GraphDriver):
|
|||
auth=(user or '', password or ''),
|
||||
)
|
||||
self._database = database
|
||||
|
||||
|
||||
# Schedule the indices and constraints to be built
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
# Try to get the current event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
|
|
|
|||
|
|
@ -609,7 +609,7 @@ class Graphiti:
|
|||
"""
|
||||
if driver is None:
|
||||
driver = self.clients.driver
|
||||
|
||||
|
||||
return await retrieve_episodes(driver, reference_time, last_n, group_ids, source)
|
||||
|
||||
async def add_episode(
|
||||
|
|
@ -688,7 +688,7 @@ class Graphiti:
|
|||
|
||||
validate_entity_types(entity_types)
|
||||
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
||||
|
||||
|
||||
if group_id is None:
|
||||
# if group_id is None, use the default group id by the provider
|
||||
# and the preset database name will be used
|
||||
|
|
@ -1012,8 +1012,7 @@ class Graphiti:
|
|||
|
||||
@handle_multiple_group_ids
|
||||
async def build_communities(
|
||||
self, group_ids: list[str] | None = None,
|
||||
driver: GraphDriver | None = None
|
||||
self, group_ids: list[str] | None = None, driver: GraphDriver | None = None
|
||||
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
||||
"""
|
||||
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
|
||||
|
|
@ -1056,7 +1055,7 @@ class Graphiti:
|
|||
group_ids: list[str] | None = None,
|
||||
num_results=DEFAULT_SEARCH_LIMIT,
|
||||
search_filter: SearchFilters | None = None,
|
||||
driver: GraphDriver | None = None
|
||||
driver: GraphDriver | None = None,
|
||||
) -> list[EntityEdge]:
|
||||
"""
|
||||
Perform a hybrid search on the knowledge graph.
|
||||
|
|
@ -1104,7 +1103,7 @@ class Graphiti:
|
|||
search_config,
|
||||
search_filter if search_filter is not None else SearchFilters(),
|
||||
driver=driver,
|
||||
center_node_uuid=center_node_uuid
|
||||
center_node_uuid=center_node_uuid,
|
||||
)
|
||||
).edges
|
||||
|
||||
|
|
@ -1133,7 +1132,7 @@ class Graphiti:
|
|||
center_node_uuid: str | None = None,
|
||||
bfs_origin_node_uuids: list[str] | None = None,
|
||||
search_filter: SearchFilters | None = None,
|
||||
driver: GraphDriver | None = None
|
||||
driver: GraphDriver | None = None,
|
||||
) -> SearchResults:
|
||||
"""search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
|
||||
than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
|
||||
|
|
@ -1150,7 +1149,7 @@ class Graphiti:
|
|||
search_filter if search_filter is not None else SearchFilters(),
|
||||
center_node_uuid,
|
||||
bfs_origin_node_uuids,
|
||||
driver=driver
|
||||
driver=driver,
|
||||
)
|
||||
|
||||
async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults:
|
||||
|
|
|
|||
|
|
@ -328,4 +328,4 @@ COMMUNITY_NODE_RETURN_NEPTUNE = """
|
|||
n.group_id AS group_id,
|
||||
n.summary AS summary,
|
||||
n.created_at AS created_at
|
||||
"""
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -132,12 +132,12 @@ class SearchResults(BaseModel):
|
|||
def merge(cls, results_list: list['SearchResults']) -> 'SearchResults':
|
||||
"""
|
||||
Merge multiple SearchResults objects into a single SearchResults object.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
results_list : list[SearchResults]
|
||||
List of SearchResults objects to merge
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
SearchResults
|
||||
|
|
@ -145,7 +145,7 @@ class SearchResults(BaseModel):
|
|||
"""
|
||||
if not results_list:
|
||||
return cls()
|
||||
|
||||
|
||||
merged = cls()
|
||||
for result in results_list:
|
||||
merged.edges.extend(result.edges)
|
||||
|
|
@ -156,5 +156,5 @@ class SearchResults(BaseModel):
|
|||
merged.episode_reranker_scores.extend(result.episode_reranker_scores)
|
||||
merged.communities.extend(result.communities)
|
||||
merged.community_reranker_scores.extend(result.community_reranker_scores)
|
||||
|
||||
|
||||
return merged
|
||||
|
|
|
|||
|
|
@ -20,8 +20,6 @@ from datetime import datetime
|
|||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
||||
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
from graphiti_core.models.nodes.node_db_queries import (
|
||||
EPISODIC_NODE_RETURN,
|
||||
EPISODIC_NODE_RETURN_NEPTUNE,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue