diff --git a/examples/quickstart/quickstart_falkordb.py b/examples/quickstart/quickstart_falkordb.py index 19b7cce8..ea101a73 100644 --- a/examples/quickstart/quickstart_falkordb.py +++ b/examples/quickstart/quickstart_falkordb.py @@ -78,9 +78,6 @@ async def main(): graphiti = Graphiti(graph_driver=falkor_driver) try: - # Initialize the graph database with graphiti's indices. This only needs to be done once. - await graphiti.build_indices_and_constraints() - ################################################# # ADDING EPISODES ################################################# diff --git a/examples/quickstart/quickstart_neo4j.py b/examples/quickstart/quickstart_neo4j.py index d92a6c5c..29db600a 100644 --- a/examples/quickstart/quickstart_neo4j.py +++ b/examples/quickstart/quickstart_neo4j.py @@ -67,9 +67,6 @@ async def main(): graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password) try: - # Initialize the graph database with graphiti's indices. This only needs to be done once. - await graphiti.build_indices_and_constraints() - ################################################# # ADDING EPISODES ################################################# diff --git a/graphiti_core/decorators.py b/graphiti_core/decorators.py new file mode 100644 index 00000000..9a4fd903 --- /dev/null +++ b/graphiti_core/decorators.py @@ -0,0 +1,110 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import functools +import inspect +from collections.abc import Awaitable, Callable +from typing import Any, TypeVar + +from graphiti_core.driver.driver import GraphProvider +from graphiti_core.helpers import semaphore_gather +from graphiti_core.search.search_config import SearchResults + +F = TypeVar('F', bound=Callable[..., Awaitable[Any]]) + + +def handle_multiple_group_ids(func: F) -> F: + """ + Decorator for FalkorDB methods that need to handle multiple group_ids. + Runs the function for each group_id separately and merges results. 
+ """ + + @functools.wraps(func) + async def wrapper(self, *args, **kwargs): + group_ids_func_pos = get_parameter_position(func, 'group_ids') + group_ids_pos = ( + group_ids_func_pos - 1 if group_ids_func_pos is not None else None + ) # Adjust for zero-based index + group_ids = kwargs.get('group_ids') + + # If not in kwargs and position exists, get from args + if group_ids is None and group_ids_pos is not None and len(args) > group_ids_pos: + group_ids = args[group_ids_pos] + + # Only handle FalkorDB with multiple group_ids + if ( + hasattr(self, 'clients') + and hasattr(self.clients, 'driver') + and self.clients.driver.provider == GraphProvider.FALKORDB + and group_ids + and len(group_ids) > 1 + ): + # Execute for each group_id concurrently + driver = self.clients.driver + + async def execute_for_group(gid: str): + # Remove group_ids from args if it was passed positionally + filtered_args = list(args) + if group_ids_pos is not None and len(args) > group_ids_pos: + filtered_args.pop(group_ids_pos) + + return await func( + self, + *filtered_args, + **{**kwargs, 'group_ids': [gid], 'driver': driver.clone(database=gid)}, + ) + + results = await semaphore_gather( + *[execute_for_group(gid) for gid in group_ids], + max_coroutines=getattr(self, 'max_coroutines', None), + ) + + # Merge results based on type + if isinstance(results[0], SearchResults): + return SearchResults.merge(results) + elif isinstance(results[0], list): + return [item for result in results for item in result] + elif isinstance(results[0], tuple): + # Handle tuple outputs (like build_communities returning (nodes, edges)) + merged_tuple = [] + for i in range(len(results[0])): + component_results = [result[i] for result in results] + if isinstance(component_results[0], list): + merged_tuple.append( + [item for component in component_results for item in component] + ) + else: + merged_tuple.append(component_results) + return tuple(merged_tuple) + else: + return results + + # Normal execution + return await func(self, *args, **kwargs) + + return wrapper # type: ignore + + +def get_parameter_position(func: Callable, param_name: str) -> int | None: + """ + Returns the positional index of a parameter in the function signature. + If the parameter is not found, returns None. 
+ """ + sig = inspect.signature(func) + for idx, (name, _param) in enumerate(sig.parameters.items()): + if name == param_name: + return idx + return None diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 5b4e0fc3..c1a355f3 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -76,6 +76,7 @@ class GraphDriver(ABC): '' # Neo4j (default) syntax does not require a prefix for fulltext queries ) _database: str + default_group_id: str = '' search_interface: SearchInterface | None = None graph_operations_interface: GraphOperationsInterface | None = None @@ -105,6 +106,14 @@ class GraphDriver(ABC): return cloned + @abstractmethod + async def build_indices_and_constraints(self, delete_existing: bool = False): + raise NotImplementedError() + + def clone(self, database: str) -> 'GraphDriver': + """Clone the driver with a different database or graph name.""" + return self + def build_fulltext_query( self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128 ) -> str: diff --git a/graphiti_core/driver/falkordb_driver.py b/graphiti_core/driver/falkordb_driver.py index d0b4ffe8..de469d53 100644 --- a/graphiti_core/driver/falkordb_driver.py +++ b/graphiti_core/driver/falkordb_driver.py @@ -34,6 +34,7 @@ else: ) from None from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider +from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings logger = logging.getLogger(__name__) @@ -112,6 +113,8 @@ class FalkorDriverSession(GraphDriverSession): class FalkorDriver(GraphDriver): provider = GraphProvider.FALKORDB + default_group_id: str = '\\_' + fulltext_syntax: str = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries aoss_client: None = None def __init__( @@ -129,9 +132,16 @@ class FalkorDriver(GraphDriver): FalkorDB is a multi-tenant graph database. To connect, provide the host and port. The default parameters assume a local (on-premises) FalkorDB instance. + + Args: + host (str): The host where FalkorDB is running. + port (int): The port on which FalkorDB is listening. + username (str | None): The username for authentication (if required). + password (str | None): The password for authentication (if required). + falkor_db (FalkorDB | None): An existing FalkorDB instance to use instead of creating a new one. + database (str): The name of the database to connect to. Defaults to 'default_db'. 
""" super().__init__() - self._database = database if falkor_db is not None: # If a FalkorDB instance is provided, use it directly @@ -139,7 +149,15 @@ class FalkorDriver(GraphDriver): else: self.client = FalkorDB(host=host, port=port, username=username, password=password) - self.fulltext_syntax = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries see https://redis.io/docs/latest/develop/ai/search-and-query/query/full-text/ + # Schedule the indices and constraints to be built + try: + # Try to get the current event loop + loop = asyncio.get_running_loop() + # Schedule the build_indices_and_constraints to run + loop.create_task(self.build_indices_and_constraints()) + except RuntimeError: + # No event loop running, this will be handled later + pass def _get_graph(self, graph_name: str | None) -> FalkorGraph: # FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db" @@ -224,12 +242,25 @@ class FalkorDriver(GraphDriver): if drop_tasks: await asyncio.gather(*drop_tasks) + async def build_indices_and_constraints(self, delete_existing=False): + if delete_existing: + await self.delete_all_indexes() + index_queries = get_range_indices(self.provider) + get_fulltext_indices(self.provider) + for query in index_queries: + await self.execute_query(query) + def clone(self, database: str) -> 'GraphDriver': """ Returns a shallow copy of this driver with a different default database. Reuses the same connection (e.g. FalkorDB, Neo4j). """ - cloned = FalkorDriver(falkor_db=self.client, database=database) + if database == self._database: + cloned = self + elif database == self.default_group_id: + cloned = FalkorDriver(falkor_db=self.client) + else: + # Create a new instance of FalkorDriver with the same connection but a different database + cloned = FalkorDriver(falkor_db=self.client, database=database) return cloned diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py index 4a0baf79..5d85c2b0 100644 --- a/graphiti_core/driver/neo4j_driver.py +++ b/graphiti_core/driver/neo4j_driver.py @@ -22,12 +22,15 @@ from neo4j import AsyncGraphDatabase, EagerResult from typing_extensions import LiteralString from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider +from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices +from graphiti_core.helpers import semaphore_gather logger = logging.getLogger(__name__) class Neo4jDriver(GraphDriver): provider = GraphProvider.NEO4J + default_group_id: str = '' def __init__( self, @@ -43,6 +46,18 @@ class Neo4jDriver(GraphDriver): ) self._database = database + # Schedule the indices and constraints to be built + import asyncio + + try: + # Try to get the current event loop + loop = asyncio.get_running_loop() + # Schedule the build_indices_and_constraints to run + loop.create_task(self.build_indices_and_constraints()) + except RuntimeError: + # No event loop running, this will be handled later + pass + self.aoss_client = None async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult: @@ -73,6 +88,25 @@ class Neo4jDriver(GraphDriver): 'CALL db.indexes() YIELD name DROP INDEX name', ) + async def build_indices_and_constraints(self, delete_existing: bool = False): + if delete_existing: + await self.delete_all_indexes() + + range_indices: list[LiteralString] = get_range_indices(self.provider) + + fulltext_indices: list[LiteralString] = get_fulltext_indices(self.provider) + + index_queries: list[LiteralString] = 
range_indices + fulltext_indices + + await semaphore_gather( + *[ + self.execute_query( + query, + ) + for query in index_queries + ] + ) + async def health_check(self) -> None: """Check Neo4j connectivity by running the driver's verify_connectivity method.""" try: diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index bebbdc7c..c4fa60c5 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -24,6 +24,7 @@ from typing_extensions import LiteralString from graphiti_core.cross_encoder.client import CrossEncoderClient from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient +from graphiti_core.decorators import handle_multiple_group_ids from graphiti_core.driver.driver import GraphDriver from graphiti_core.driver.neo4j_driver import Neo4jDriver from graphiti_core.edges import ( @@ -87,7 +88,6 @@ from graphiti_core.utils.maintenance.edge_operations import ( ) from graphiti_core.utils.maintenance.graph_data_operations import ( EPISODE_WINDOW_LEN, - build_indices_and_constraints, retrieve_episodes, ) from graphiti_core.utils.maintenance.node_operations import ( @@ -340,18 +340,17 @@ class Graphiti: ----- This method should typically be called once during the initial setup of the knowledge graph or when updating the database schema. It uses the - `build_indices_and_constraints` function from the - `graphiti_core.utils.maintenance.graph_data_operations` module to perform + driver's `build_indices_and_constraints` method to perform the actual database operations. The specific indices and constraints created depend on the implementation - of the `build_indices_and_constraints` function. Refer to that function's - documentation for details on the exact database schema modifications. + of the driver's `build_indices_and_constraints` method. Refer to the specific + driver documentation for details on the exact database schema modifications. Caution: Running this method on a large existing database may take some time and could impact database performance during execution. """ - await build_indices_and_constraints(self.driver, delete_existing) + await self.driver.build_indices_and_constraints(delete_existing) async def _extract_and_resolve_nodes( self, @@ -574,12 +573,14 @@ class Graphiti: return final_hydrated_nodes, resolved_edges, invalidated_edges, uuid_map + @handle_multiple_group_ids async def retrieve_episodes( self, reference_time: datetime, last_n: int = EPISODE_WINDOW_LEN, group_ids: list[str] | None = None, source: EpisodeType | None = None, + driver: GraphDriver | None = None, ) -> list[EpisodicNode]: """ Retrieve the last n episodic nodes from the graph. @@ -606,7 +607,10 @@ class Graphiti: The actual retrieval is performed by the `retrieve_episodes` function from the `graphiti_core.utils` module. 
""" - return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source) + if driver is None: + driver = self.clients.driver + + return await retrieve_episodes(driver, reference_time, last_n, group_ids, source) async def add_episode( self, @@ -683,11 +687,18 @@ class Graphiti: now = utc_now() validate_entity_types(entity_types) - validate_excluded_entity_types(excluded_entity_types, entity_types) - validate_group_id(group_id) - # if group_id is None, use the default group id by the provider - group_id = group_id or get_default_group_id(self.driver.provider) + + if group_id is None: + # if group_id is None, use the default group id by the provider + # and the preset database name will be used + group_id = get_default_group_id(self.driver.provider) + else: + validate_group_id(group_id) + if group_id != self.driver._database: + # if group_id is provided, use it as the database name + self.driver = self.driver.clone(database=group_id) + self.clients.driver = self.driver with self.tracer.start_span('add_episode') as span: try: @@ -865,8 +876,14 @@ class Graphiti: now = utc_now() # if group_id is None, use the default group id by the provider - group_id = group_id or get_default_group_id(self.driver.provider) - validate_group_id(group_id) + if group_id is None: + group_id = get_default_group_id(self.driver.provider) + else: + validate_group_id(group_id) + if group_id != self.driver._database: + # if group_id is provided, use it as the database name + self.driver = self.driver.clone(database=group_id) + self.clients.driver = self.driver # Create default edge type map edge_type_map_default = ( @@ -993,21 +1010,25 @@ class Graphiti: bulk_span.record_exception(e) raise e + @handle_multiple_group_ids async def build_communities( - self, group_ids: list[str] | None = None + self, group_ids: list[str] | None = None, driver: GraphDriver | None = None ) -> tuple[list[CommunityNode], list[CommunityEdge]]: """ Use a community clustering algorithm to find communities of nodes. Create community nodes summarising the content of these communities. ---------- - query : list[str] | None + group_ids : list[str] | None Optional. Create communities only for the listed group_ids. If blank the entire graph will be used. """ + if driver is None: + driver = self.clients.driver + # Clear existing communities - await remove_communities(self.driver) + await remove_communities(driver) community_nodes, community_edges = await build_communities( - self.driver, self.llm_client, group_ids + driver, self.llm_client, group_ids ) await semaphore_gather( @@ -1016,16 +1037,17 @@ class Graphiti: ) await semaphore_gather( - *[node.save(self.driver) for node in community_nodes], + *[node.save(driver) for node in community_nodes], max_coroutines=self.max_coroutines, ) await semaphore_gather( - *[edge.save(self.driver) for edge in community_edges], + *[edge.save(driver) for edge in community_edges], max_coroutines=self.max_coroutines, ) return community_nodes, community_edges + @handle_multiple_group_ids async def search( self, query: str, @@ -1033,6 +1055,7 @@ class Graphiti: group_ids: list[str] | None = None, num_results=DEFAULT_SEARCH_LIMIT, search_filter: SearchFilters | None = None, + driver: GraphDriver | None = None, ) -> list[EntityEdge]: """ Perform a hybrid search on the knowledge graph. 
@@ -1079,7 +1102,8 @@ class Graphiti: group_ids, search_config, search_filter if search_filter is not None else SearchFilters(), - center_node_uuid, + driver=driver, + center_node_uuid=center_node_uuid, ) ).edges @@ -1099,6 +1123,7 @@ class Graphiti: query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter ) + @handle_multiple_group_ids async def search_( self, query: str, @@ -1107,6 +1132,7 @@ class Graphiti: center_node_uuid: str | None = None, bfs_origin_node_uuids: list[str] | None = None, search_filter: SearchFilters | None = None, + driver: GraphDriver | None = None, ) -> SearchResults: """search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and @@ -1123,6 +1149,7 @@ class Graphiti: search_filter if search_filter is not None else SearchFilters(), center_node_uuid, bfs_origin_node_uuids, + driver=driver, ) async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults: diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index f30b5144..af98f560 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -74,10 +74,11 @@ async def search( center_node_uuid: str | None = None, bfs_origin_node_uuids: list[str] | None = None, query_vector: list[float] | None = None, + driver: GraphDriver | None = None, ) -> SearchResults: start = time() - driver = clients.driver + driver = driver or clients.driver embedder = clients.embedder cross_encoder = clients.cross_encoder diff --git a/graphiti_core/search/search_config.py b/graphiti_core/search/search_config.py index f24a3f3e..7e5714f5 100644 --- a/graphiti_core/search/search_config.py +++ b/graphiti_core/search/search_config.py @@ -127,3 +127,34 @@ class SearchResults(BaseModel): episode_reranker_scores: list[float] = Field(default_factory=list) communities: list[CommunityNode] = Field(default_factory=list) community_reranker_scores: list[float] = Field(default_factory=list) + + @classmethod + def merge(cls, results_list: list['SearchResults']) -> 'SearchResults': + """ + Merge multiple SearchResults objects into a single SearchResults object. 
+ + Parameters + ---------- + results_list : list[SearchResults] + List of SearchResults objects to merge + + Returns + ------- + SearchResults + A single SearchResults object containing all results + """ + if not results_list: + return cls() + + merged = cls() + for result in results_list: + merged.edges.extend(result.edges) + merged.edge_reranker_scores.extend(result.edge_reranker_scores) + merged.nodes.extend(result.nodes) + merged.node_reranker_scores.extend(result.node_reranker_scores) + merged.episodes.extend(result.episodes) + merged.episode_reranker_scores.extend(result.episode_reranker_scores) + merged.communities.extend(result.communities) + merged.community_reranker_scores.extend(result.community_reranker_scores) + + return merged diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index ee7e6be2..7a1b1a69 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -20,8 +20,6 @@ from datetime import datetime from typing_extensions import LiteralString from graphiti_core.driver.driver import GraphDriver, GraphProvider -from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices -from graphiti_core.helpers import semaphore_gather from graphiti_core.models.nodes.node_db_queries import ( EPISODIC_NODE_RETURN, EPISODIC_NODE_RETURN_NEPTUNE, @@ -33,46 +31,6 @@ EPISODE_WINDOW_LEN = 3 logger = logging.getLogger(__name__) -async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False): - if delete_existing: - await driver.delete_all_indexes() - - range_indices: list[LiteralString] = get_range_indices(driver.provider) - - # Don't create fulltext indices if search_interface is being used - if not driver.search_interface: - fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider) - - if driver.provider == GraphProvider.KUZU: - # Skip creating fulltext indices if they already exist. Need to do this manually - # until Kuzu supports `IF NOT EXISTS` for indices. - result, _, _ = await driver.execute_query('CALL SHOW_INDEXES() RETURN *;') - if len(result) > 0: - fulltext_indices = [] - - # Only load the `fts` extension if it's not already loaded, otherwise throw an error. - result, _, _ = await driver.execute_query('CALL SHOW_LOADED_EXTENSIONS() RETURN *;') - if len(result) == 0: - fulltext_indices.insert( - 0, - """ - INSTALL fts; - LOAD fts; - """, - ) - - index_queries: list[LiteralString] = range_indices + fulltext_indices - - await semaphore_gather( - *[ - driver.execute_query( - query, - ) - for query in index_queries - ] - ) - - async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None): async with driver.session() as session:
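
Usage sketch (illustrative only, not part of the patch): how the new multi-group handling looks from the caller's side. With a FalkorDB-backed Graphiti instance, a search over more than one group_id is intercepted by handle_multiple_group_ids, run once per group against a cloned driver, and the per-group results are merged. The host, port, query text, and the tenant_a/tenant_b group ids below are placeholders; the Graphiti import mirrors the quickstart examples.

    import asyncio

    from graphiti_core import Graphiti
    from graphiti_core.driver.falkordb_driver import FalkorDriver


    async def main():
        # FalkorDB keeps each group_id in its own graph, so multi-group calls
        # are fanned out per group_id by the decorator.
        falkor_driver = FalkorDriver(host='localhost', port=6379)
        graphiti = Graphiti(graph_driver=falkor_driver)
        try:
            # handle_multiple_group_ids clones the driver once per group_id and
            # concatenates the per-group edge lists into a single result.
            edges = await graphiti.search(
                'Who founded the company?',
                group_ids=['tenant_a', 'tenant_b'],
            )
            print(f'{len(edges)} edges across both tenants')
        finally:
            await graphiti.close()


    asyncio.run(main())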
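
SearchResults.merge simply concatenates the component lists; it does not re-rank across groups. A minimal sketch, assuming (as cls() inside merge implies) that every SearchResults field defaults to an empty list:

    from graphiti_core.search.search_config import SearchResults

    # Two per-group result sets, as produced when a search is fanned out
    # across group_ids; fields that are not set fall back to empty defaults.
    group_a = SearchResults(edge_reranker_scores=[0.9, 0.7])
    group_b = SearchResults(edge_reranker_scores=[0.8])

    merged = SearchResults.merge([group_a, group_b])
    assert merged.edge_reranker_scores == [0.9, 0.7, 0.8]
    assert merged.edges == []  # untouched fields stay empty

Because scores are concatenated in group_ids order, relative ranking across groups is not recomputed; a caller that needs a globally ranked list would have to re-sort the merged results itself.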
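
On the quickstart change: the explicit build_indices_and_constraints() call was dropped because both drivers now attempt to schedule index creation from __init__ via loop.create_task(). When a driver is constructed outside a running event loop, that bootstrap is skipped (the RuntimeError branch), so the public method is still available and now delegates to the driver. A sketch with placeholder Neo4j credentials:

    import asyncio

    from graphiti_core import Graphiti


    async def main():
        # Constructed inside a running loop, so the driver already schedules
        # index creation from __init__; the explicit call remains available for
        # cases where the driver is created before any event loop exists.
        graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
        await graphiti.build_indices_and_constraints()
        await graphiti.close()


    asyncio.run(main())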