diff --git a/examples/quickstart/quickstart_falkordb.py b/examples/quickstart/quickstart_falkordb.py index 19b7cce8..e35e958b 100644 --- a/examples/quickstart/quickstart_falkordb.py +++ b/examples/quickstart/quickstart_falkordb.py @@ -78,8 +78,6 @@ async def main(): graphiti = Graphiti(graph_driver=falkor_driver) try: - # Initialize the graph database with graphiti's indices. This only needs to be done once. - await graphiti.build_indices_and_constraints() ################################################# # ADDING EPISODES diff --git a/examples/quickstart/quickstart_neo4j.py b/examples/quickstart/quickstart_neo4j.py index d92a6c5c..e34d6eed 100644 --- a/examples/quickstart/quickstart_neo4j.py +++ b/examples/quickstart/quickstart_neo4j.py @@ -67,8 +67,6 @@ async def main(): graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password) try: - # Initialize the graph database with graphiti's indices. This only needs to be done once. - await graphiti.build_indices_and_constraints() ################################################# # ADDING EPISODES diff --git a/graphiti_core/decorators.py b/graphiti_core/decorators.py new file mode 100644 index 00000000..f39f52ce --- /dev/null +++ b/graphiti_core/decorators.py @@ -0,0 +1,77 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import functools +from typing import Any, Awaitable, Callable, TypeVar + +from graphiti_core.driver.driver import GraphProvider +from graphiti_core.helpers import semaphore_gather +from graphiti_core.search.search_config import SearchResults + +F = TypeVar('F', bound=Callable[..., Awaitable[Any]]) + + +def handle_multiple_group_ids(func: F) -> F: + """ + Decorator for FalkorDB methods that need to handle multiple group_ids. + Runs the function for each group_id separately and merges results. + """ + @functools.wraps(func) + async def wrapper(self, *args, **kwargs): + group_ids = kwargs.get('group_ids') + + # Only handle FalkorDB with multiple group_ids + if (hasattr(self, 'clients') and hasattr(self.clients, 'driver') and + self.clients.driver.provider == GraphProvider.FALKORDB and + group_ids and len(group_ids) > 1): + + # Execute for each group_id concurrently + driver = self.clients.driver + + async def execute_for_group(gid: str): + return await func( + self, + *args, + **{**kwargs, "group_ids": [gid], "driver": driver.clone(database=gid)}, + ) + + results = await semaphore_gather( + *[execute_for_group(gid) for gid in group_ids], + max_coroutines=getattr(self, 'max_coroutines', None) + ) + + # Merge results based on type + if isinstance(results[0], SearchResults): + return SearchResults.merge(results) + elif isinstance(results[0], list): + return [item for result in results for item in result] + elif isinstance(results[0], tuple): + # Handle tuple outputs (like build_communities returning (nodes, edges)) + merged_tuple = [] + for i in range(len(results[0])): + component_results = [result[i] for result in results] + if isinstance(component_results[0], list): + merged_tuple.append([item for component in component_results for item in component]) + else: + merged_tuple.append(component_results) + return tuple(merged_tuple) + else: + return results + + # Normal execution + return await func(self, *args, **kwargs) + + return wrapper # type: ignore diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 6b85d80e..d38a2598 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -57,6 +57,7 @@ class GraphDriver(ABC): '' # Neo4j (default) syntax does not require a prefix for fulltext queries ) _database: str + default_group_id: str = '' @abstractmethod def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine: @@ -74,12 +75,10 @@ class GraphDriver(ABC): def delete_all_indexes(self) -> Coroutine: raise NotImplementedError() - def with_database(self, database: str) -> 'GraphDriver': - """ - Returns a shallow copy of this driver with a different default database. - Reuses the same connection (e.g. FalkorDB, Neo4j). - """ - cloned = copy.copy(self) - cloned._database = database + @abstractmethod + async def build_indices_and_constraints(self, delete_existing: bool = False): + raise NotImplementedError() - return cloned + def clone(self, database: str) -> 'GraphDriver': + """Clone the driver with a different database or graph name.""" + return self diff --git a/graphiti_core/driver/falkordb_driver.py b/graphiti_core/driver/falkordb_driver.py index f121319b..6f684700 100644 --- a/graphiti_core/driver/falkordb_driver.py +++ b/graphiti_core/driver/falkordb_driver.py @@ -18,6 +18,8 @@ import logging from datetime import datetime from typing import TYPE_CHECKING, Any +from typing_extensions import LiteralString + if TYPE_CHECKING: from falkordb import Graph as FalkorGraph from falkordb.asyncio import FalkorDB @@ -33,6 +35,8 @@ else: ) from None from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider +from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices +from graphiti_core.helpers import semaphore_gather logger = logging.getLogger(__name__) @@ -72,6 +76,8 @@ class FalkorDriverSession(GraphDriverSession): class FalkorDriver(GraphDriver): provider = GraphProvider.FALKORDB + default_group_id: str = '\\_' + fulltext_syntax: str = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries def __init__( self, @@ -80,7 +86,7 @@ class FalkorDriver(GraphDriver): username: str | None = None, password: str | None = None, falkor_db: FalkorDB | None = None, - database: str = 'default_db', + database: str = '\\_', ): """ Initialize the FalkorDB driver. @@ -98,7 +104,16 @@ class FalkorDriver(GraphDriver): else: self.client = FalkorDB(host=host, port=port, username=username, password=password) - self.fulltext_syntax = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries see https://redis.io/docs/latest/develop/ai/search-and-query/query/full-text/ + # Schedule the indices and constraints to be built + import asyncio + try: + # Try to get the current event loop + loop = asyncio.get_running_loop() + # Schedule the build_indices_and_constraints to run + loop.create_task(self.build_indices_and_constraints()) + except RuntimeError: + # No event loop running, this will be handled later + pass def _get_graph(self, graph_name: str | None) -> FalkorGraph: # FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db" @@ -152,8 +167,64 @@ class FalkorDriver(GraphDriver): await self.client.connection.close() async def delete_all_indexes(self) -> None: - await self.execute_query( - 'CALL db.indexes() YIELD name DROP INDEX name', + from collections import defaultdict + + result = await self.execute_query('CALL db.indexes()') + if result is None: + return + + records, _, _ = result + + # Organize indexes by type and label + range_indexes = defaultdict(list) + fulltext_indexes = defaultdict(list) + entity_types = {} + + for record in records: + label = record['label'] + entity_types[label] = record['entitytype'] + + for field_name, index_type in record['types'].items(): + if 'RANGE' in index_type: + range_indexes[label].append(field_name) + if 'FULLTEXT' in index_type: + fulltext_indexes[label].append(field_name) + + # Drop all range indexes + for label, fields in range_indexes.items(): + for field in fields: + await self.execute_query(f'DROP INDEX ON :{label}({field})') + + # Drop all fulltext indexes + for label, fields in fulltext_indexes.items(): + entity_type = entity_types[label] + for field in fields: + if entity_type == 'NODE': + await self.execute_query( + f'DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field})' + ) + elif entity_type == 'RELATIONSHIP': + await self.execute_query( + f'DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field})' + ) + + async def build_indices_and_constraints(self, delete_existing: bool = False): + if delete_existing: + await self.delete_all_indexes() + + range_indices: list[LiteralString] = get_range_indices(self.provider) + + fulltext_indices: list[LiteralString] = get_fulltext_indices(self.provider) + + index_queries: list[LiteralString] = range_indices + fulltext_indices + + await semaphore_gather( + *[ + self.execute_query( + query, + ) + for query in index_queries + ] ) def clone(self, database: str) -> 'GraphDriver': @@ -161,8 +232,12 @@ class FalkorDriver(GraphDriver): Returns a shallow copy of this driver with a different default database. Reuses the same connection (e.g. FalkorDB, Neo4j). """ - cloned = FalkorDriver(falkor_db=self.client, database=database) - + if database == self._database: + cloned = self + else: + # Create a new instance of FalkorDriver with the same connection but a different database + cloned = FalkorDriver(falkor_db=self.client, database=database) + return cloned diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py index 7ac9a5a8..0d90b4f1 100644 --- a/graphiti_core/driver/neo4j_driver.py +++ b/graphiti_core/driver/neo4j_driver.py @@ -22,12 +22,15 @@ from neo4j import AsyncGraphDatabase, EagerResult from typing_extensions import LiteralString from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider +from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices +from graphiti_core.helpers import semaphore_gather logger = logging.getLogger(__name__) class Neo4jDriver(GraphDriver): provider = GraphProvider.NEO4J + default_group_id: str = '' def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'): super().__init__() @@ -36,6 +39,17 @@ class Neo4jDriver(GraphDriver): auth=(user or '', password or ''), ) self._database = database + + # Schedule the indices and constraints to be built + import asyncio + try: + # Try to get the current event loop + loop = asyncio.get_running_loop() + # Schedule the build_indices_and_constraints to run + loop.create_task(self.build_indices_and_constraints()) + except RuntimeError: + # No event loop running, this will be handled later + pass async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult: # Check if database_ is provided in kwargs. @@ -64,3 +78,22 @@ class Neo4jDriver(GraphDriver): return self.client.execute_query( 'CALL db.indexes() YIELD name DROP INDEX name', ) + + async def build_indices_and_constraints(self, delete_existing: bool = False): + if delete_existing: + await self.delete_all_indexes() + + range_indices: list[LiteralString] = get_range_indices(self.provider) + + fulltext_indices: list[LiteralString] = get_fulltext_indices(self.provider) + + index_queries: list[LiteralString] = range_indices + fulltext_indices + + await semaphore_gather( + *[ + self.execute_query( + query, + ) + for query in index_queries + ] + ) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 8e58a5ae..8d0e9305 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -24,6 +24,7 @@ from typing_extensions import LiteralString from graphiti_core.cross_encoder.client import CrossEncoderClient from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient +from graphiti_core.decorators import handle_multiple_group_ids from graphiti_core.driver.driver import GraphDriver from graphiti_core.driver.neo4j_driver import Neo4jDriver from graphiti_core.edges import ( @@ -35,7 +36,6 @@ from graphiti_core.edges import ( from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import ( - get_default_group_id, semaphore_gather, validate_excluded_entity_types, validate_group_id, @@ -87,7 +87,6 @@ from graphiti_core.utils.maintenance.edge_operations import ( ) from graphiti_core.utils.maintenance.graph_data_operations import ( EPISODE_WINDOW_LEN, - build_indices_and_constraints, retrieve_episodes, ) from graphiti_core.utils.maintenance.node_operations import ( @@ -320,25 +319,26 @@ class Graphiti: ----- This method should typically be called once during the initial setup of the knowledge graph or when updating the database schema. It uses the - `build_indices_and_constraints` function from the - `graphiti_core.utils.maintenance.graph_data_operations` module to perform + driver's `build_indices_and_constraints` method to perform the actual database operations. The specific indices and constraints created depend on the implementation - of the `build_indices_and_constraints` function. Refer to that function's - documentation for details on the exact database schema modifications. + of the driver's `build_indices_and_constraints` method. Refer to the specific + driver documentation for details on the exact database schema modifications. Caution: Running this method on a large existing database may take some time and could impact database performance during execution. """ - await build_indices_and_constraints(self.driver, delete_existing) + await self.driver.build_indices_and_constraints(delete_existing) + @handle_multiple_group_ids async def retrieve_episodes( self, reference_time: datetime, last_n: int = EPISODE_WINDOW_LEN, group_ids: list[str] | None = None, source: EpisodeType | None = None, + driver: GraphDriver | None = None, ) -> list[EpisodicNode]: """ Retrieve the last n episodic nodes from the graph. @@ -365,7 +365,10 @@ class Graphiti: The actual retrieval is performed by the `retrieve_episodes` function from the `graphiti_core.utils` module. """ - return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source) + if driver is None: + driver = self.clients.driver + + return await retrieve_episodes(driver, reference_time, last_n, group_ids, source) async def add_episode( self, @@ -442,12 +445,18 @@ class Graphiti: start = time() now = utc_now() - # if group_id is None, use the default group id by the provider - group_id = group_id or get_default_group_id(self.driver.provider) - validate_entity_types(entity_types) + if group_id is None: + # if group_id is None, use the default group id by the provider + group_id = self.driver.default_group_id + else: + validate_group_id(group_id) + if group_id != self.driver._database: + # if group_id is provided, use it as the database name + self.driver = self.driver.clone(database=group_id) + self.clients.driver = self.driver + validate_entity_types(entity_types) validate_excluded_entity_types(excluded_entity_types, entity_types) - validate_group_id(group_id) previous_episodes = ( await self.retrieve_episodes( @@ -620,9 +629,15 @@ class Graphiti: start = time() now = utc_now() - # if group_id is None, use the default group id by the provider - group_id = group_id or get_default_group_id(self.driver.provider) - validate_group_id(group_id) + if group_id is None: + # if group_id is None, use the default group id by the provider + group_id = self.driver.default_group_id + else: + validate_group_id(group_id) + if group_id != self.driver._database: + # if group_id is provided, use it as the database name + self.driver = self.driver.clone(database=group_id) + self.clients.driver = self.driver # Create default edge type map edge_type_map_default = ( @@ -850,21 +865,26 @@ class Graphiti: except Exception as e: raise e + @handle_multiple_group_ids async def build_communities( - self, group_ids: list[str] | None = None + self, group_ids: list[str] | None = None, + driver: GraphDriver | None = None ) -> tuple[list[CommunityNode], list[CommunityEdge]]: """ Use a community clustering algorithm to find communities of nodes. Create community nodes summarising the content of these communities. ---------- - query : list[str] | None + group_ids : list[str] | None Optional. Create communities only for the listed group_ids. If blank the entire graph will be used. """ + if driver is None: + driver = self.clients.driver + # Clear existing communities - await remove_communities(self.driver) + await remove_communities(driver) community_nodes, community_edges = await build_communities( - self.driver, self.llm_client, group_ids + driver, self.llm_client, group_ids ) await semaphore_gather( @@ -873,16 +893,17 @@ class Graphiti: ) await semaphore_gather( - *[node.save(self.driver) for node in community_nodes], + *[node.save(driver) for node in community_nodes], max_coroutines=self.max_coroutines, ) await semaphore_gather( - *[edge.save(self.driver) for edge in community_edges], + *[edge.save(driver) for edge in community_edges], max_coroutines=self.max_coroutines, ) return community_nodes, community_edges + @handle_multiple_group_ids async def search( self, query: str, @@ -890,6 +911,7 @@ class Graphiti: group_ids: list[str] | None = None, num_results=DEFAULT_SEARCH_LIMIT, search_filter: SearchFilters | None = None, + driver: GraphDriver | None = None ) -> list[EntityEdge]: """ Perform a hybrid search on the knowledge graph. @@ -936,7 +958,8 @@ class Graphiti: group_ids, search_config, search_filter if search_filter is not None else SearchFilters(), - center_node_uuid, + driver=driver, + center_node_uuid=center_node_uuid ) ).edges @@ -956,6 +979,7 @@ class Graphiti: query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter ) + @handle_multiple_group_ids async def search_( self, query: str, @@ -964,6 +988,7 @@ class Graphiti: center_node_uuid: str | None = None, bfs_origin_node_uuids: list[str] | None = None, search_filter: SearchFilters | None = None, + driver: GraphDriver | None = None ) -> SearchResults: """search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and @@ -980,6 +1005,7 @@ class Graphiti: search_filter if search_filter is not None else SearchFilters(), center_node_uuid, bfs_origin_node_uuids, + driver=driver ) async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults: diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py index 9feb3073..eca6fb8e 100644 --- a/graphiti_core/helpers.py +++ b/graphiti_core/helpers.py @@ -53,17 +53,6 @@ def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None ) -def get_default_group_id(provider: GraphProvider) -> str: - """ - This function differentiates the default group id based on the database type. - For most databases, the default group id is an empty string, while there are database types that require a specific default group id. - """ - if provider == GraphProvider.FALKORDB: - return '_' - else: - return '' - - def lucene_sanitize(query: str) -> str: # Escape special characters from a query before passing into Lucene # + - && || ! ( ) { } [ ] ^ " ~ * ? : \ / diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index 1458def7..a4e3b248 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -72,10 +72,11 @@ async def search( center_node_uuid: str | None = None, bfs_origin_node_uuids: list[str] | None = None, query_vector: list[float] | None = None, + driver: GraphDriver | None = None, ) -> SearchResults: start = time() - driver = clients.driver + driver = driver or clients.driver embedder = clients.embedder cross_encoder = clients.cross_encoder diff --git a/graphiti_core/search/search_config.py b/graphiti_core/search/search_config.py index f24a3f3e..14220007 100644 --- a/graphiti_core/search/search_config.py +++ b/graphiti_core/search/search_config.py @@ -127,3 +127,34 @@ class SearchResults(BaseModel): episode_reranker_scores: list[float] = Field(default_factory=list) communities: list[CommunityNode] = Field(default_factory=list) community_reranker_scores: list[float] = Field(default_factory=list) + + @classmethod + def merge(cls, results_list: list['SearchResults']) -> 'SearchResults': + """ + Merge multiple SearchResults objects into a single SearchResults object. + + Parameters + ---------- + results_list : list[SearchResults] + List of SearchResults objects to merge + + Returns + ------- + SearchResults + A single SearchResults object containing all results + """ + if not results_list: + return cls() + + merged = cls() + for result in results_list: + merged.edges.extend(result.edges) + merged.edge_reranker_scores.extend(result.edge_reranker_scores) + merged.nodes.extend(result.nodes) + merged.node_reranker_scores.extend(result.node_reranker_scores) + merged.episodes.extend(result.episodes) + merged.episode_reranker_scores.extend(result.episode_reranker_scores) + merged.communities.extend(result.communities) + merged.community_reranker_scores.extend(result.community_reranker_scores) + + return merged diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index b866607f..bc7b3af5 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -20,8 +20,6 @@ from datetime import datetime from typing_extensions import LiteralString from graphiti_core.driver.driver import GraphDriver -from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices -from graphiti_core.helpers import semaphore_gather from graphiti_core.models.nodes.node_db_queries import EPISODIC_NODE_RETURN from graphiti_core.nodes import EpisodeType, EpisodicNode, get_episodic_node_from_record @@ -30,39 +28,6 @@ EPISODE_WINDOW_LEN = 3 logger = logging.getLogger(__name__) -async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False): - if delete_existing: - records, _, _ = await driver.execute_query( - """ - SHOW INDEXES YIELD name - """, - ) - index_names = [record['name'] for record in records] - await semaphore_gather( - *[ - driver.execute_query( - """DROP INDEX $name""", - name=name, - ) - for name in index_names - ] - ) - range_indices: list[LiteralString] = get_range_indices(driver.provider) - - fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider) - - index_queries: list[LiteralString] = range_indices + fulltext_indices - - await semaphore_gather( - *[ - driver.execute_query( - query, - ) - for query in index_queries - ] - ) - - async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None): async with driver.session() as session: