From 2796126cce61a238b58248a370500494d66ff365 Mon Sep 17 00:00:00 2001 From: Naseem Ali <34807727+Naseem77@users.noreply.github.com> Date: Sat, 18 Oct 2025 22:00:20 +0300 Subject: [PATCH] fix lint --- examples/quickstart/quickstart_falkordb.py | 1 - examples/quickstart/quickstart_neo4j.py | 1 - graphiti_core/decorators.py | 40 +++++++++++-------- graphiti_core/driver/falkordb_driver.py | 2 +- graphiti_core/driver/neo4j_driver.py | 3 +- graphiti_core/graphiti.py | 15 ++++--- graphiti_core/models/nodes/node_db_queries.py | 2 +- graphiti_core/search/search_config.py | 8 ++-- .../maintenance/graph_data_operations.py | 2 - 9 files changed, 39 insertions(+), 35 deletions(-) diff --git a/examples/quickstart/quickstart_falkordb.py b/examples/quickstart/quickstart_falkordb.py index e35e958b..ea101a73 100644 --- a/examples/quickstart/quickstart_falkordb.py +++ b/examples/quickstart/quickstart_falkordb.py @@ -78,7 +78,6 @@ async def main(): graphiti = Graphiti(graph_driver=falkor_driver) try: - ################################################# # ADDING EPISODES ################################################# diff --git a/examples/quickstart/quickstart_neo4j.py b/examples/quickstart/quickstart_neo4j.py index e34d6eed..29db600a 100644 --- a/examples/quickstart/quickstart_neo4j.py +++ b/examples/quickstart/quickstart_neo4j.py @@ -67,7 +67,6 @@ async def main(): graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password) try: - ################################################# # ADDING EPISODES ################################################# diff --git a/graphiti_core/decorators.py b/graphiti_core/decorators.py index a8548659..9a4fd903 100644 --- a/graphiti_core/decorators.py +++ b/graphiti_core/decorators.py @@ -31,41 +31,47 @@ def handle_multiple_group_ids(func: F) -> F: Decorator for FalkorDB methods that need to handle multiple group_ids. Runs the function for each group_id separately and merges results. """ + @functools.wraps(func) async def wrapper(self, *args, **kwargs): group_ids_func_pos = get_parameter_position(func, 'group_ids') - group_ids_pos = group_ids_func_pos - 1 if group_ids_func_pos is not None else None # Adjust for zero-based index + group_ids_pos = ( + group_ids_func_pos - 1 if group_ids_func_pos is not None else None + ) # Adjust for zero-based index group_ids = kwargs.get('group_ids') # If not in kwargs and position exists, get from args if group_ids is None and group_ids_pos is not None and len(args) > group_ids_pos: group_ids = args[group_ids_pos] - + # Only handle FalkorDB with multiple group_ids - if (hasattr(self, 'clients') and hasattr(self.clients, 'driver') and - self.clients.driver.provider == GraphProvider.FALKORDB and - group_ids and len(group_ids) > 1): - + if ( + hasattr(self, 'clients') + and hasattr(self.clients, 'driver') + and self.clients.driver.provider == GraphProvider.FALKORDB + and group_ids + and len(group_ids) > 1 + ): # Execute for each group_id concurrently driver = self.clients.driver - + async def execute_for_group(gid: str): # Remove group_ids from args if it was passed positionally filtered_args = list(args) if group_ids_pos is not None and len(args) > group_ids_pos: filtered_args.pop(group_ids_pos) - + return await func( self, *filtered_args, - **{**kwargs, "group_ids": [gid], "driver": driver.clone(database=gid)}, + **{**kwargs, 'group_ids': [gid], 'driver': driver.clone(database=gid)}, ) - + results = await semaphore_gather( *[execute_for_group(gid) for gid in group_ids], - max_coroutines=getattr(self, 'max_coroutines', None) + max_coroutines=getattr(self, 'max_coroutines', None), ) - + # Merge results based on type if isinstance(results[0], SearchResults): return SearchResults.merge(results) @@ -77,16 +83,18 @@ def handle_multiple_group_ids(func: F) -> F: for i in range(len(results[0])): component_results = [result[i] for result in results] if isinstance(component_results[0], list): - merged_tuple.append([item for component in component_results for item in component]) + merged_tuple.append( + [item for component in component_results for item in component] + ) else: merged_tuple.append(component_results) return tuple(merged_tuple) else: return results - + # Normal execution return await func(self, *args, **kwargs) - + return wrapper # type: ignore @@ -99,4 +107,4 @@ def get_parameter_position(func: Callable, param_name: str) -> int | None: for idx, (name, _param) in enumerate(sig.parameters.items()): if name == param_name: return idx - return None \ No newline at end of file + return None diff --git a/graphiti_core/driver/falkordb_driver.py b/graphiti_core/driver/falkordb_driver.py index fa608a9b..d134ec41 100644 --- a/graphiti_core/driver/falkordb_driver.py +++ b/graphiti_core/driver/falkordb_driver.py @@ -260,7 +260,7 @@ class FalkorDriver(GraphDriver): else: # Create a new instance of FalkorDriver with the same connection but a different database cloned = FalkorDriver(falkor_db=self.client, database=database) - + return cloned def sanitize(self, query: str) -> str: diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py index 2c07ea04..507468ce 100644 --- a/graphiti_core/driver/neo4j_driver.py +++ b/graphiti_core/driver/neo4j_driver.py @@ -45,9 +45,10 @@ class Neo4jDriver(GraphDriver): auth=(user or '', password or ''), ) self._database = database - + # Schedule the indices and constraints to be built import asyncio + try: # Try to get the current event loop loop = asyncio.get_running_loop() diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 51bc492d..c4fa60c5 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -609,7 +609,7 @@ class Graphiti: """ if driver is None: driver = self.clients.driver - + return await retrieve_episodes(driver, reference_time, last_n, group_ids, source) async def add_episode( @@ -688,7 +688,7 @@ class Graphiti: validate_entity_types(entity_types) validate_excluded_entity_types(excluded_entity_types, entity_types) - + if group_id is None: # if group_id is None, use the default group id by the provider # and the preset database name will be used @@ -1012,8 +1012,7 @@ class Graphiti: @handle_multiple_group_ids async def build_communities( - self, group_ids: list[str] | None = None, - driver: GraphDriver | None = None + self, group_ids: list[str] | None = None, driver: GraphDriver | None = None ) -> tuple[list[CommunityNode], list[CommunityEdge]]: """ Use a community clustering algorithm to find communities of nodes. Create community nodes summarising @@ -1056,7 +1055,7 @@ class Graphiti: group_ids: list[str] | None = None, num_results=DEFAULT_SEARCH_LIMIT, search_filter: SearchFilters | None = None, - driver: GraphDriver | None = None + driver: GraphDriver | None = None, ) -> list[EntityEdge]: """ Perform a hybrid search on the knowledge graph. @@ -1104,7 +1103,7 @@ class Graphiti: search_config, search_filter if search_filter is not None else SearchFilters(), driver=driver, - center_node_uuid=center_node_uuid + center_node_uuid=center_node_uuid, ) ).edges @@ -1133,7 +1132,7 @@ class Graphiti: center_node_uuid: str | None = None, bfs_origin_node_uuids: list[str] | None = None, search_filter: SearchFilters | None = None, - driver: GraphDriver | None = None + driver: GraphDriver | None = None, ) -> SearchResults: """search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and @@ -1150,7 +1149,7 @@ class Graphiti: search_filter if search_filter is not None else SearchFilters(), center_node_uuid, bfs_origin_node_uuids, - driver=driver + driver=driver, ) async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults: diff --git a/graphiti_core/models/nodes/node_db_queries.py b/graphiti_core/models/nodes/node_db_queries.py index 3604025d..34e3d8b8 100644 --- a/graphiti_core/models/nodes/node_db_queries.py +++ b/graphiti_core/models/nodes/node_db_queries.py @@ -328,4 +328,4 @@ COMMUNITY_NODE_RETURN_NEPTUNE = """ n.group_id AS group_id, n.summary AS summary, n.created_at AS created_at -""" \ No newline at end of file +""" diff --git a/graphiti_core/search/search_config.py b/graphiti_core/search/search_config.py index 14220007..7e5714f5 100644 --- a/graphiti_core/search/search_config.py +++ b/graphiti_core/search/search_config.py @@ -132,12 +132,12 @@ class SearchResults(BaseModel): def merge(cls, results_list: list['SearchResults']) -> 'SearchResults': """ Merge multiple SearchResults objects into a single SearchResults object. - + Parameters ---------- results_list : list[SearchResults] List of SearchResults objects to merge - + Returns ------- SearchResults @@ -145,7 +145,7 @@ class SearchResults(BaseModel): """ if not results_list: return cls() - + merged = cls() for result in results_list: merged.edges.extend(result.edges) @@ -156,5 +156,5 @@ class SearchResults(BaseModel): merged.episode_reranker_scores.extend(result.episode_reranker_scores) merged.communities.extend(result.communities) merged.community_reranker_scores.extend(result.community_reranker_scores) - + return merged diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index 4cc15908..7a1b1a69 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -20,8 +20,6 @@ from datetime import datetime from typing_extensions import LiteralString from graphiti_core.driver.driver import GraphDriver, GraphProvider -from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices -from graphiti_core.helpers import semaphore_gather from graphiti_core.models.nodes.node_db_queries import ( EPISODIC_NODE_RETURN, EPISODIC_NODE_RETURN_NEPTUNE,