[Improvement] Add GraphID isolation support for FalkorDB multi-tenant architecture (#835)

* Update node_db_queries.py

* Update node_db_queries.py

* graph-per-graphid

* fix-groupid-usage

* ruff-fix

* rev-driver-changes

* rm-un-changes

* fix lint

---------

Co-authored-by: Naseem Ali <34807727+Naseem77@users.noreply.github.com>
This commit is contained in:
Gal Shubeli 2025-11-03 22:56:53 +07:00 committed by GitHub
parent 8d99984204
commit c144ff5995
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 267 additions and 72 deletions

View file

@ -78,9 +78,6 @@ async def main():
graphiti = Graphiti(graph_driver=falkor_driver)
try:
# Initialize the graph database with graphiti's indices. This only needs to be done once.
await graphiti.build_indices_and_constraints()
#################################################
# ADDING EPISODES
#################################################

View file

@ -67,9 +67,6 @@ async def main():
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
try:
# Initialize the graph database with graphiti's indices. This only needs to be done once.
await graphiti.build_indices_and_constraints()
#################################################
# ADDING EPISODES
#################################################

110
graphiti_core/decorators.py Normal file
View file

@ -0,0 +1,110 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import functools
import inspect
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar
from graphiti_core.driver.driver import GraphProvider
from graphiti_core.helpers import semaphore_gather
from graphiti_core.search.search_config import SearchResults
F = TypeVar('F', bound=Callable[..., Awaitable[Any]])
def handle_multiple_group_ids(func: F) -> F:
    """
    Decorator for FalkorDB methods that need to handle multiple group_ids.

    FalkorDB isolates each group_id in its own graph, so one call covering
    several group_ids cannot be served by a single query.  This wrapper runs
    the decorated coroutine once per group_id — each invocation receives a
    driver clone bound to that group's graph via the ``driver`` kwarg — and
    merges the per-group results.  For other providers, or when zero/one
    group_id is supplied, the function executes unchanged.
    """

    @functools.wraps(func)
    async def wrapper(self, *args, **kwargs):
        # Position of 'group_ids' in func's signature; index 0 is 'self',
        # so subtract 1 to map it onto *args (which excludes self).
        group_ids_func_pos = get_parameter_position(func, 'group_ids')
        group_ids_pos = (
            group_ids_func_pos - 1 if group_ids_func_pos is not None else None
        )  # Adjust for zero-based index
        group_ids = kwargs.get('group_ids')
        # If not in kwargs and position exists, get from args
        if group_ids is None and group_ids_pos is not None and len(args) > group_ids_pos:
            group_ids = args[group_ids_pos]
        # Only handle FalkorDB with multiple group_ids
        if (
            hasattr(self, 'clients')
            and hasattr(self.clients, 'driver')
            and self.clients.driver.provider == GraphProvider.FALKORDB
            and group_ids
            and len(group_ids) > 1
        ):
            # Execute for each group_id concurrently
            driver = self.clients.driver

            async def execute_for_group(gid: str):
                # Remove group_ids from args if it was passed positionally,
                # so the single-group value can be re-injected via kwargs.
                filtered_args = list(args)
                if group_ids_pos is not None and len(args) > group_ids_pos:
                    filtered_args.pop(group_ids_pos)
                # The decorated function must accept a 'driver' kwarg; each
                # call gets a clone targeting that group's graph.
                return await func(
                    self,
                    *filtered_args,
                    **{**kwargs, 'group_ids': [gid], 'driver': driver.clone(database=gid)},
                )

            results = await semaphore_gather(
                *[execute_for_group(gid) for gid in group_ids],
                max_coroutines=getattr(self, 'max_coroutines', None),
            )
            # Merge results based on type
            if isinstance(results[0], SearchResults):
                return SearchResults.merge(results)
            elif isinstance(results[0], list):
                # Flatten list results across groups, preserving group order.
                return [item for result in results for item in result]
            elif isinstance(results[0], tuple):
                # Handle tuple outputs (like build_communities returning (nodes, edges))
                merged_tuple = []
                for i in range(len(results[0])):
                    component_results = [result[i] for result in results]
                    if isinstance(component_results[0], list):
                        merged_tuple.append(
                            [item for component in component_results for item in component]
                        )
                    else:
                        # Non-list components are collected as-is, one per group.
                        merged_tuple.append(component_results)
                return tuple(merged_tuple)
            else:
                # Unknown result type: return the raw per-group list.
                return results
        # Normal execution
        return await func(self, *args, **kwargs)

    return wrapper  # type: ignore
def get_parameter_position(func: Callable, param_name: str) -> int | None:
"""
Returns the positional index of a parameter in the function signature.
If the parameter is not found, returns None.
"""
sig = inspect.signature(func)
for idx, (name, _param) in enumerate(sig.parameters.items()):
if name == param_name:
return idx
return None

View file

@ -76,6 +76,7 @@ class GraphDriver(ABC):
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
)
_database: str
default_group_id: str = ''
search_interface: SearchInterface | None = None
graph_operations_interface: GraphOperationsInterface | None = None
@ -105,6 +106,14 @@ class GraphDriver(ABC):
return cloned
@abstractmethod
async def build_indices_and_constraints(self, delete_existing: bool = False):
raise NotImplementedError()
def clone(self, database: str) -> 'GraphDriver':
"""Clone the driver with a different database or graph name."""
return self
def build_fulltext_query(
self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
) -> str:

View file

@ -34,6 +34,7 @@ else:
) from None
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
logger = logging.getLogger(__name__)
@ -112,6 +113,8 @@ class FalkorDriverSession(GraphDriverSession):
class FalkorDriver(GraphDriver):
provider = GraphProvider.FALKORDB
default_group_id: str = '\\_'
fulltext_syntax: str = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries
aoss_client: None = None
def __init__(
@ -129,9 +132,16 @@ class FalkorDriver(GraphDriver):
FalkorDB is a multi-tenant graph database.
To connect, provide the host and port.
The default parameters assume a local (on-premises) FalkorDB instance.
Args:
host (str): The host where FalkorDB is running.
port (int): The port on which FalkorDB is listening.
username (str | None): The username for authentication (if required).
password (str | None): The password for authentication (if required).
falkor_db (FalkorDB | None): An existing FalkorDB instance to use instead of creating a new one.
database (str): The name of the database to connect to. Defaults to 'default_db'.
"""
super().__init__()
self._database = database
if falkor_db is not None:
# If a FalkorDB instance is provided, use it directly
@ -139,7 +149,15 @@ class FalkorDriver(GraphDriver):
else:
self.client = FalkorDB(host=host, port=port, username=username, password=password)
self.fulltext_syntax = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries see https://redis.io/docs/latest/develop/ai/search-and-query/query/full-text/
# Schedule the indices and constraints to be built
try:
# Try to get the current event loop
loop = asyncio.get_running_loop()
# Schedule the build_indices_and_constraints to run
loop.create_task(self.build_indices_and_constraints())
except RuntimeError:
# No event loop running, this will be handled later
pass
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
@ -224,12 +242,25 @@ class FalkorDriver(GraphDriver):
if drop_tasks:
await asyncio.gather(*drop_tasks)
async def build_indices_and_constraints(self, delete_existing=False):
    """
    Create the range and fulltext indices Graphiti needs on this graph.

    Parameters
    ----------
    delete_existing : bool
        When True, drop all existing indexes before recreating them.
    """
    if delete_existing:
        await self.delete_all_indexes()
    index_queries = get_range_indices(self.provider) + get_fulltext_indices(self.provider)
    # Each creation query is awaited in turn (sequential, not gathered).
    for query in index_queries:
        await self.execute_query(query)
def clone(self, database: str) -> 'GraphDriver':
    """
    Return a driver bound to *database*, reusing the same FalkorDB connection.

    Parameters
    ----------
    database : str
        Target graph name. When it equals the current database this driver
        itself is returned; when it equals ``default_group_id`` a driver for
        the preset default graph is returned; otherwise a new FalkorDriver
        sharing this connection is created for the requested graph.

    Returns
    -------
    GraphDriver
        A driver whose queries target *database*.
    """
    # NOTE: do not construct a FalkorDriver before branching — the original
    # unconditional `FalkorDriver(...)` here was dead code whose __init__
    # side effects (scheduling index creation) ran even when `self` was
    # returned unchanged.
    if database == self._database:
        # Already pointing at the requested graph; no new driver needed.
        cloned = self
    elif database == self.default_group_id:
        # The default group maps to the driver's preset default database.
        cloned = FalkorDriver(falkor_db=self.client)
    else:
        # Create a new instance of FalkorDriver with the same connection but a different database
        cloned = FalkorDriver(falkor_db=self.client, database=database)
    return cloned

View file

@ -22,12 +22,15 @@ from neo4j import AsyncGraphDatabase, EagerResult
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.helpers import semaphore_gather
logger = logging.getLogger(__name__)
class Neo4jDriver(GraphDriver):
provider = GraphProvider.NEO4J
default_group_id: str = ''
def __init__(
self,
@ -43,6 +46,18 @@ class Neo4jDriver(GraphDriver):
)
self._database = database
# Schedule the indices and constraints to be built
import asyncio
try:
# Try to get the current event loop
loop = asyncio.get_running_loop()
# Schedule the build_indices_and_constraints to run
loop.create_task(self.build_indices_and_constraints())
except RuntimeError:
# No event loop running, this will be handled later
pass
self.aoss_client = None
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
@ -73,6 +88,25 @@ class Neo4jDriver(GraphDriver):
'CALL db.indexes() YIELD name DROP INDEX name',
)
async def build_indices_and_constraints(self, delete_existing: bool = False):
    """
    Create the range and fulltext indices Graphiti needs in Neo4j.

    Parameters
    ----------
    delete_existing : bool
        When True, drop all existing indexes before recreating them.
    """
    if delete_existing:
        await self.delete_all_indexes()
    range_indices: list[LiteralString] = get_range_indices(self.provider)
    fulltext_indices: list[LiteralString] = get_fulltext_indices(self.provider)
    index_queries: list[LiteralString] = range_indices + fulltext_indices
    # Creation queries are independent of one another, so run them concurrently.
    await semaphore_gather(
        *[
            self.execute_query(
                query,
            )
            for query in index_queries
        ]
    )
async def health_check(self) -> None:
"""Check Neo4j connectivity by running the driver's verify_connectivity method."""
try:

View file

@ -24,6 +24,7 @@ from typing_extensions import LiteralString
from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.decorators import handle_multiple_group_ids
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.neo4j_driver import Neo4jDriver
from graphiti_core.edges import (
@ -87,7 +88,6 @@ from graphiti_core.utils.maintenance.edge_operations import (
)
from graphiti_core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN,
build_indices_and_constraints,
retrieve_episodes,
)
from graphiti_core.utils.maintenance.node_operations import (
@ -340,18 +340,17 @@ class Graphiti:
-----
This method should typically be called once during the initial setup of the
knowledge graph or when updating the database schema. It uses the
`build_indices_and_constraints` function from the
`graphiti_core.utils.maintenance.graph_data_operations` module to perform
driver's `build_indices_and_constraints` method to perform
the actual database operations.
The specific indices and constraints created depend on the implementation
of the `build_indices_and_constraints` function. Refer to that function's
documentation for details on the exact database schema modifications.
of the driver's `build_indices_and_constraints` method. Refer to the specific
driver documentation for details on the exact database schema modifications.
Caution: Running this method on a large existing database may take some time
and could impact database performance during execution.
"""
await build_indices_and_constraints(self.driver, delete_existing)
await self.driver.build_indices_and_constraints(delete_existing)
async def _extract_and_resolve_nodes(
self,
@ -574,12 +573,14 @@ class Graphiti:
return final_hydrated_nodes, resolved_edges, invalidated_edges, uuid_map
@handle_multiple_group_ids
async def retrieve_episodes(
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
group_ids: list[str] | None = None,
source: EpisodeType | None = None,
driver: GraphDriver | None = None,
) -> list[EpisodicNode]:
"""
Retrieve the last n episodic nodes from the graph.
@ -606,7 +607,10 @@ class Graphiti:
The actual retrieval is performed by the `retrieve_episodes` function
from the `graphiti_core.utils` module.
"""
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source)
if driver is None:
driver = self.clients.driver
return await retrieve_episodes(driver, reference_time, last_n, group_ids, source)
async def add_episode(
self,
@ -683,11 +687,18 @@ class Graphiti:
now = utc_now()
validate_entity_types(entity_types)
validate_excluded_entity_types(excluded_entity_types, entity_types)
validate_group_id(group_id)
# if group_id is None, use the default group id by the provider
group_id = group_id or get_default_group_id(self.driver.provider)
if group_id is None:
# if group_id is None, use the default group id by the provider
# and the preset database name will be used
group_id = get_default_group_id(self.driver.provider)
else:
validate_group_id(group_id)
if group_id != self.driver._database:
# if group_id is provided, use it as the database name
self.driver = self.driver.clone(database=group_id)
self.clients.driver = self.driver
with self.tracer.start_span('add_episode') as span:
try:
@ -865,8 +876,14 @@ class Graphiti:
now = utc_now()
# if group_id is None, use the default group id by the provider
group_id = group_id or get_default_group_id(self.driver.provider)
validate_group_id(group_id)
if group_id is None:
group_id = get_default_group_id(self.driver.provider)
else:
validate_group_id(group_id)
if group_id != self.driver._database:
# if group_id is provided, use it as the database name
self.driver = self.driver.clone(database=group_id)
self.clients.driver = self.driver
# Create default edge type map
edge_type_map_default = (
@ -993,21 +1010,25 @@ class Graphiti:
bulk_span.record_exception(e)
raise e
@handle_multiple_group_ids
async def build_communities(
self, group_ids: list[str] | None = None
self, group_ids: list[str] | None = None, driver: GraphDriver | None = None
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
"""
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
the content of these communities.
----------
query : list[str] | None
group_ids : list[str] | None
Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
"""
if driver is None:
driver = self.clients.driver
# Clear existing communities
await remove_communities(self.driver)
await remove_communities(driver)
community_nodes, community_edges = await build_communities(
self.driver, self.llm_client, group_ids
driver, self.llm_client, group_ids
)
await semaphore_gather(
@ -1016,16 +1037,17 @@ class Graphiti:
)
await semaphore_gather(
*[node.save(self.driver) for node in community_nodes],
*[node.save(driver) for node in community_nodes],
max_coroutines=self.max_coroutines,
)
await semaphore_gather(
*[edge.save(self.driver) for edge in community_edges],
*[edge.save(driver) for edge in community_edges],
max_coroutines=self.max_coroutines,
)
return community_nodes, community_edges
@handle_multiple_group_ids
async def search(
self,
query: str,
@ -1033,6 +1055,7 @@ class Graphiti:
group_ids: list[str] | None = None,
num_results=DEFAULT_SEARCH_LIMIT,
search_filter: SearchFilters | None = None,
driver: GraphDriver | None = None,
) -> list[EntityEdge]:
"""
Perform a hybrid search on the knowledge graph.
@ -1079,7 +1102,8 @@ class Graphiti:
group_ids,
search_config,
search_filter if search_filter is not None else SearchFilters(),
center_node_uuid,
driver=driver,
center_node_uuid=center_node_uuid,
)
).edges
@ -1099,6 +1123,7 @@ class Graphiti:
query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter
)
@handle_multiple_group_ids
async def search_(
self,
query: str,
@ -1107,6 +1132,7 @@ class Graphiti:
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
search_filter: SearchFilters | None = None,
driver: GraphDriver | None = None,
) -> SearchResults:
"""search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
@ -1123,6 +1149,7 @@ class Graphiti:
search_filter if search_filter is not None else SearchFilters(),
center_node_uuid,
bfs_origin_node_uuids,
driver=driver,
)
async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults:

View file

@ -74,10 +74,11 @@ async def search(
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
query_vector: list[float] | None = None,
driver: GraphDriver | None = None,
) -> SearchResults:
start = time()
driver = clients.driver
driver = driver or clients.driver
embedder = clients.embedder
cross_encoder = clients.cross_encoder

View file

@ -127,3 +127,34 @@ class SearchResults(BaseModel):
episode_reranker_scores: list[float] = Field(default_factory=list)
communities: list[CommunityNode] = Field(default_factory=list)
community_reranker_scores: list[float] = Field(default_factory=list)
@classmethod
def merge(cls, results_list: list['SearchResults']) -> 'SearchResults':
    """
    Merge multiple SearchResults objects into a single SearchResults object.

    Parameters
    ----------
    results_list : list[SearchResults]
        List of SearchResults objects to merge

    Returns
    -------
    SearchResults
        A single SearchResults object containing all results
    """
    combined = cls()
    if not results_list:
        return combined
    # Every mergeable field is a list; concatenate each one across the
    # inputs in order, keeping items aligned with their reranker scores.
    list_fields = (
        'edges',
        'edge_reranker_scores',
        'nodes',
        'node_reranker_scores',
        'episodes',
        'episode_reranker_scores',
        'communities',
        'community_reranker_scores',
    )
    for partial in results_list:
        for field_name in list_fields:
            getattr(combined, field_name).extend(getattr(partial, field_name))
    return combined

View file

@ -20,8 +20,6 @@ from datetime import datetime
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.helpers import semaphore_gather
from graphiti_core.models.nodes.node_db_queries import (
EPISODIC_NODE_RETURN,
EPISODIC_NODE_RETURN_NEPTUNE,
@ -33,46 +31,6 @@ EPISODE_WINDOW_LEN = 3
logger = logging.getLogger(__name__)
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
    """
    Build the range and fulltext indices for the given driver's provider.

    Parameters
    ----------
    driver : GraphDriver
        Driver whose provider determines which index queries are run.
    delete_existing : bool
        When True, drop all existing indexes before recreating them.
    """
    if delete_existing:
        await driver.delete_all_indexes()
    range_indices: list[LiteralString] = get_range_indices(driver.provider)
    # Don't create fulltext indices if search_interface is being used
    if not driver.search_interface:
        fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
        if driver.provider == GraphProvider.KUZU:
            # Skip creating fulltext indices if they already exist. Need to do this manually
            # until Kuzu supports `IF NOT EXISTS` for indices.
            result, _, _ = await driver.execute_query('CALL SHOW_INDEXES() RETURN *;')
            if len(result) > 0:
                fulltext_indices = []
            # Only load the `fts` extension if it's not already loaded, otherwise throw an error.
            result, _, _ = await driver.execute_query('CALL SHOW_LOADED_EXTENSIONS() RETURN *;')
            if len(result) == 0:
                fulltext_indices.insert(
                    0,
                    """
                    INSTALL fts;
                    LOAD fts;
                    """,
                )
        # NOTE(review): original indentation was lost in this view; the query
        # execution is placed inside the search_interface guard because
        # fulltext_indices is only bound here — confirm against the repo.
        index_queries: list[LiteralString] = range_indices + fulltext_indices
        await semaphore_gather(
            *[
                driver.execute_query(
                    query,
                )
                for query in index_queries
            ]
        )
async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
async with driver.session() as session: