graph-per-graphid

This commit is contained in:
Gal Shubeli 2025-08-13 18:21:53 +03:00
parent 4f8eb310f2
commit bbf9cc6172
11 changed files with 279 additions and 87 deletions

View file

@ -78,8 +78,6 @@ async def main():
graphiti = Graphiti(graph_driver=falkor_driver)
try:
# Initialize the graph database with graphiti's indices. This only needs to be done once.
await graphiti.build_indices_and_constraints()
#################################################
# ADDING EPISODES

View file

@ -67,8 +67,6 @@ async def main():
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
try:
# Initialize the graph database with graphiti's indices. This only needs to be done once.
await graphiti.build_indices_and_constraints()
#################################################
# ADDING EPISODES

View file

@ -0,0 +1,77 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import functools
from typing import Any, Awaitable, Callable, TypeVar
from graphiti_core.driver.driver import GraphProvider
from graphiti_core.helpers import semaphore_gather
from graphiti_core.search.search_config import SearchResults
F = TypeVar('F', bound=Callable[..., Awaitable[Any]])
def handle_multiple_group_ids(func: F) -> F:
    """
    Decorator for FalkorDB methods that need to handle multiple group_ids.

    FalkorDB stores each group_id in its own graph (database), so a call that
    spans several group_ids cannot be served by a single query. This decorator
    detects that case, re-invokes the wrapped coroutine once per group_id
    against a driver clone bound to that group's graph, and merges the
    per-group results back into a single return value.

    For any other provider, or when zero/one group_id is given, the wrapped
    coroutine runs unchanged.

    Notes
    -----
    - group_ids is read from **kwargs only; a caller passing it positionally
      bypasses the fan-out and falls through to normal execution.
    - The wrapped method must accept a ``driver`` keyword argument, since the
      per-group clone is injected via kwargs.
    """
    @functools.wraps(func)
    async def wrapper(self, *args, **kwargs):
        group_ids = kwargs.get('group_ids')
        # Only handle FalkorDB with multiple group_ids; the duck-typed
        # hasattr checks let this decorate any object exposing
        # `.clients.driver` (e.g. Graphiti) without a hard dependency.
        if (hasattr(self, 'clients') and hasattr(self.clients, 'driver') and
                self.clients.driver.provider == GraphProvider.FALKORDB and
                group_ids and len(group_ids) > 1):
            # Execute for each group_id concurrently
            driver = self.clients.driver

            async def execute_for_group(gid: str):
                # Override group_ids with the single gid and route the call to
                # that group's dedicated graph via a cloned driver.
                return await func(
                    self,
                    *args,
                    **{**kwargs, "group_ids": [gid], "driver": driver.clone(database=gid)},
                )

            results = await semaphore_gather(
                *[execute_for_group(gid) for gid in group_ids],
                max_coroutines=getattr(self, 'max_coroutines', None)
            )
            # Merge results based on type of the first result; all per-group
            # calls invoke the same function, so the types are assumed uniform.
            if isinstance(results[0], SearchResults):
                return SearchResults.merge(results)
            elif isinstance(results[0], list):
                # Flatten list-of-lists into one list.
                return [item for result in results for item in result]
            elif isinstance(results[0], tuple):
                # Handle tuple outputs (like build_communities returning (nodes, edges)):
                # merge component-wise, flattening list components. Non-list
                # components are collected into a list of per-group values —
                # NOTE(review): this changes that component's type; presumably
                # no current caller returns non-list tuple components.
                merged_tuple = []
                for i in range(len(results[0])):
                    component_results = [result[i] for result in results]
                    if isinstance(component_results[0], list):
                        merged_tuple.append([item for component in component_results for item in component])
                    else:
                        merged_tuple.append(component_results)
                return tuple(merged_tuple)
            else:
                # Unknown result type: return the raw per-group results.
                return results
        # Normal execution
        return await func(self, *args, **kwargs)
    return wrapper  # type: ignore

View file

@ -57,6 +57,7 @@ class GraphDriver(ABC):
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
)
_database: str
default_group_id: str = ''
@abstractmethod
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
@ -74,12 +75,10 @@ class GraphDriver(ABC):
def delete_all_indexes(self) -> Coroutine:
raise NotImplementedError()
def with_database(self, database: str) -> 'GraphDriver':
"""
Returns a shallow copy of this driver with a different default database.
Reuses the same connection (e.g. FalkorDB, Neo4j).
"""
cloned = copy.copy(self)
cloned._database = database
@abstractmethod
async def build_indices_and_constraints(self, delete_existing: bool = False):
raise NotImplementedError()
return cloned
def clone(self, database: str) -> 'GraphDriver':
"""Clone the driver with a different database or graph name."""
return self

View file

@ -18,6 +18,8 @@ import logging
from datetime import datetime
from typing import TYPE_CHECKING, Any
from typing_extensions import LiteralString
if TYPE_CHECKING:
from falkordb import Graph as FalkorGraph
from falkordb.asyncio import FalkorDB
@ -33,6 +35,8 @@ else:
) from None
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.helpers import semaphore_gather
logger = logging.getLogger(__name__)
@ -72,6 +76,8 @@ class FalkorDriverSession(GraphDriverSession):
class FalkorDriver(GraphDriver):
provider = GraphProvider.FALKORDB
default_group_id: str = '\\_'
fulltext_syntax: str = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries
def __init__(
self,
@ -80,7 +86,7 @@ class FalkorDriver(GraphDriver):
username: str | None = None,
password: str | None = None,
falkor_db: FalkorDB | None = None,
database: str = 'default_db',
database: str = '\\_',
):
"""
Initialize the FalkorDB driver.
@ -98,7 +104,16 @@ class FalkorDriver(GraphDriver):
else:
self.client = FalkorDB(host=host, port=port, username=username, password=password)
self.fulltext_syntax = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries see https://redis.io/docs/latest/develop/ai/search-and-query/query/full-text/
# Schedule the indices and constraints to be built
import asyncio
try:
# Try to get the current event loop
loop = asyncio.get_running_loop()
# Schedule the build_indices_and_constraints to run
loop.create_task(self.build_indices_and_constraints())
except RuntimeError:
# No event loop running, this will be handled later
pass
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "\_"
@ -152,8 +167,64 @@ class FalkorDriver(GraphDriver):
await self.client.connection.close()
async def delete_all_indexes(self) -> None:
await self.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name',
from collections import defaultdict
result = await self.execute_query('CALL db.indexes()')
if result is None:
return
records, _, _ = result
# Organize indexes by type and label
range_indexes = defaultdict(list)
fulltext_indexes = defaultdict(list)
entity_types = {}
for record in records:
label = record['label']
entity_types[label] = record['entitytype']
for field_name, index_type in record['types'].items():
if 'RANGE' in index_type:
range_indexes[label].append(field_name)
if 'FULLTEXT' in index_type:
fulltext_indexes[label].append(field_name)
# Drop all range indexes
for label, fields in range_indexes.items():
for field in fields:
await self.execute_query(f'DROP INDEX ON :{label}({field})')
# Drop all fulltext indexes
for label, fields in fulltext_indexes.items():
entity_type = entity_types[label]
for field in fields:
if entity_type == 'NODE':
await self.execute_query(
f'DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field})'
)
elif entity_type == 'RELATIONSHIP':
await self.execute_query(
f'DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field})'
)
async def build_indices_and_constraints(self, delete_existing: bool = False):
    """Create the range and fulltext indices graphiti needs in FalkorDB.

    Parameters
    ----------
    delete_existing : bool
        When True, drop every current index (via ``delete_all_indexes``)
        before recreating the full set.
    """
    if delete_existing:
        await self.delete_all_indexes()
    # Collect the provider-specific DDL: range indices first, then fulltext.
    index_queries: list[LiteralString] = [
        *get_range_indices(self.provider),
        *get_fulltext_indices(self.provider),
    ]
    # Issue all index-creation statements concurrently.
    tasks = [self.execute_query(query) for query in index_queries]
    await semaphore_gather(*tasks)
def clone(self, database: str) -> 'GraphDriver':
@ -161,8 +232,12 @@ class FalkorDriver(GraphDriver):
Returns a shallow copy of this driver with a different default database.
Reuses the same connection (e.g. FalkorDB, Neo4j).
"""
cloned = FalkorDriver(falkor_db=self.client, database=database)
if database == self._database:
cloned = self
else:
# Create a new instance of FalkorDriver with the same connection but a different database
cloned = FalkorDriver(falkor_db=self.client, database=database)
return cloned

View file

@ -22,12 +22,15 @@ from neo4j import AsyncGraphDatabase, EagerResult
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.helpers import semaphore_gather
logger = logging.getLogger(__name__)
class Neo4jDriver(GraphDriver):
provider = GraphProvider.NEO4J
default_group_id: str = ''
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
super().__init__()
@ -36,6 +39,17 @@ class Neo4jDriver(GraphDriver):
auth=(user or '', password or ''),
)
self._database = database
# Schedule the indices and constraints to be built
import asyncio
try:
# Try to get the current event loop
loop = asyncio.get_running_loop()
# Schedule the build_indices_and_constraints to run
loop.create_task(self.build_indices_and_constraints())
except RuntimeError:
# No event loop running, this will be handled later
pass
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
# Check if database_ is provided in kwargs.
@ -64,3 +78,22 @@ class Neo4jDriver(GraphDriver):
return self.client.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name',
)
async def build_indices_and_constraints(self, delete_existing: bool = False):
    """Create the range and fulltext indices graphiti needs in Neo4j.

    Parameters
    ----------
    delete_existing : bool
        When True, all existing indexes are dropped before rebuilding.
    """
    if delete_existing:
        await self.delete_all_indexes()
    # Provider-specific index DDL, created concurrently.
    range_queries: list[LiteralString] = get_range_indices(self.provider)
    fulltext_queries: list[LiteralString] = get_fulltext_indices(self.provider)
    await semaphore_gather(
        *(self.execute_query(query) for query in range_queries + fulltext_queries)
    )

View file

@ -24,6 +24,7 @@ from typing_extensions import LiteralString
from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.decorators import handle_multiple_group_ids
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.neo4j_driver import Neo4jDriver
from graphiti_core.edges import (
@ -35,7 +36,6 @@ from graphiti_core.edges import (
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import (
get_default_group_id,
semaphore_gather,
validate_excluded_entity_types,
validate_group_id,
@ -87,7 +87,6 @@ from graphiti_core.utils.maintenance.edge_operations import (
)
from graphiti_core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN,
build_indices_and_constraints,
retrieve_episodes,
)
from graphiti_core.utils.maintenance.node_operations import (
@ -320,25 +319,26 @@ class Graphiti:
-----
This method should typically be called once during the initial setup of the
knowledge graph or when updating the database schema. It uses the
`build_indices_and_constraints` function from the
`graphiti_core.utils.maintenance.graph_data_operations` module to perform
driver's `build_indices_and_constraints` method to perform
the actual database operations.
The specific indices and constraints created depend on the implementation
of the `build_indices_and_constraints` function. Refer to that function's
documentation for details on the exact database schema modifications.
of the driver's `build_indices_and_constraints` method. Refer to the specific
driver documentation for details on the exact database schema modifications.
Caution: Running this method on a large existing database may take some time
and could impact database performance during execution.
"""
await build_indices_and_constraints(self.driver, delete_existing)
await self.driver.build_indices_and_constraints(delete_existing)
@handle_multiple_group_ids
async def retrieve_episodes(
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
group_ids: list[str] | None = None,
source: EpisodeType | None = None,
driver: GraphDriver | None = None,
) -> list[EpisodicNode]:
"""
Retrieve the last n episodic nodes from the graph.
@ -365,7 +365,10 @@ class Graphiti:
The actual retrieval is performed by the `retrieve_episodes` function
from the `graphiti_core.utils` module.
"""
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source)
if driver is None:
driver = self.clients.driver
return await retrieve_episodes(driver, reference_time, last_n, group_ids, source)
async def add_episode(
self,
@ -442,12 +445,18 @@ class Graphiti:
start = time()
now = utc_now()
# if group_id is None, use the default group id by the provider
group_id = group_id or get_default_group_id(self.driver.provider)
validate_entity_types(entity_types)
if group_id is None:
# if group_id is None, use the default group id by the provider
group_id = self.driver.default_group_id
else:
validate_group_id(group_id)
if group_id != self.driver._database:
# if group_id is provided, use it as the database name
self.driver = self.driver.clone(database=group_id)
self.clients.driver = self.driver
validate_entity_types(entity_types)
validate_excluded_entity_types(excluded_entity_types, entity_types)
validate_group_id(group_id)
previous_episodes = (
await self.retrieve_episodes(
@ -620,9 +629,15 @@ class Graphiti:
start = time()
now = utc_now()
# if group_id is None, use the default group id by the provider
group_id = group_id or get_default_group_id(self.driver.provider)
validate_group_id(group_id)
if group_id is None:
# if group_id is None, use the default group id by the provider
group_id = self.driver.default_group_id
else:
validate_group_id(group_id)
if group_id != self.driver._database:
# if group_id is provided, use it as the database name
self.driver = self.driver.clone(database=group_id)
self.clients.driver = self.driver
# Create default edge type map
edge_type_map_default = (
@ -850,21 +865,26 @@ class Graphiti:
except Exception as e:
raise e
@handle_multiple_group_ids
async def build_communities(
self, group_ids: list[str] | None = None
self, group_ids: list[str] | None = None,
driver: GraphDriver | None = None
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
"""
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
the content of these communities.
----------
query : list[str] | None
group_ids : list[str] | None
Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
"""
if driver is None:
driver = self.clients.driver
# Clear existing communities
await remove_communities(self.driver)
await remove_communities(driver)
community_nodes, community_edges = await build_communities(
self.driver, self.llm_client, group_ids
driver, self.llm_client, group_ids
)
await semaphore_gather(
@ -873,16 +893,17 @@ class Graphiti:
)
await semaphore_gather(
*[node.save(self.driver) for node in community_nodes],
*[node.save(driver) for node in community_nodes],
max_coroutines=self.max_coroutines,
)
await semaphore_gather(
*[edge.save(self.driver) for edge in community_edges],
*[edge.save(driver) for edge in community_edges],
max_coroutines=self.max_coroutines,
)
return community_nodes, community_edges
@handle_multiple_group_ids
async def search(
self,
query: str,
@ -890,6 +911,7 @@ class Graphiti:
group_ids: list[str] | None = None,
num_results=DEFAULT_SEARCH_LIMIT,
search_filter: SearchFilters | None = None,
driver: GraphDriver | None = None
) -> list[EntityEdge]:
"""
Perform a hybrid search on the knowledge graph.
@ -936,7 +958,8 @@ class Graphiti:
group_ids,
search_config,
search_filter if search_filter is not None else SearchFilters(),
center_node_uuid,
driver=driver,
center_node_uuid=center_node_uuid
)
).edges
@ -956,6 +979,7 @@ class Graphiti:
query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter
)
@handle_multiple_group_ids
async def search_(
self,
query: str,
@ -964,6 +988,7 @@ class Graphiti:
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
search_filter: SearchFilters | None = None,
driver: GraphDriver | None = None
) -> SearchResults:
"""search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
@ -980,6 +1005,7 @@ class Graphiti:
search_filter if search_filter is not None else SearchFilters(),
center_node_uuid,
bfs_origin_node_uuids,
driver=driver
)
async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults:

View file

@ -53,17 +53,6 @@ def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None
)
def get_default_group_id(provider: GraphProvider) -> str:
"""
This function differentiates the default group id based on the database type.
For most databases, the default group id is an empty string, while there are database types that require a specific default group id.
"""
if provider == GraphProvider.FALKORDB:
return '_'
else:
return ''
def lucene_sanitize(query: str) -> str:
# Escape special characters from a query before passing into Lucene
# + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /

View file

@ -72,10 +72,11 @@ async def search(
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
query_vector: list[float] | None = None,
driver: GraphDriver | None = None,
) -> SearchResults:
start = time()
driver = clients.driver
driver = driver or clients.driver
embedder = clients.embedder
cross_encoder = clients.cross_encoder

View file

@ -127,3 +127,34 @@ class SearchResults(BaseModel):
episode_reranker_scores: list[float] = Field(default_factory=list)
communities: list[CommunityNode] = Field(default_factory=list)
community_reranker_scores: list[float] = Field(default_factory=list)
@classmethod
def merge(cls, results_list: list['SearchResults']) -> 'SearchResults':
    """
    Merge multiple SearchResults objects into a single SearchResults object.

    Parameters
    ----------
    results_list : list[SearchResults]
        List of SearchResults objects to merge. An empty list yields an
        empty SearchResults.

    Returns
    -------
    SearchResults
        A single SearchResults object containing all results, with each
        list field concatenated in input order.
    """
    merged = cls()
    # Every list field is concatenated the same way, so drive the merge
    # from the field names instead of repeating eight extend() calls.
    list_fields = (
        'edges', 'edge_reranker_scores',
        'nodes', 'node_reranker_scores',
        'episodes', 'episode_reranker_scores',
        'communities', 'community_reranker_scores',
    )
    for result in results_list:
        for field_name in list_fields:
            getattr(merged, field_name).extend(getattr(result, field_name))
    return merged

View file

@ -20,8 +20,6 @@ from datetime import datetime
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.helpers import semaphore_gather
from graphiti_core.models.nodes.node_db_queries import EPISODIC_NODE_RETURN
from graphiti_core.nodes import EpisodeType, EpisodicNode, get_episodic_node_from_record
@ -30,39 +28,6 @@ EPISODE_WINDOW_LEN = 3
logger = logging.getLogger(__name__)
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
if delete_existing:
records, _, _ = await driver.execute_query(
"""
SHOW INDEXES YIELD name
""",
)
index_names = [record['name'] for record in records]
await semaphore_gather(
*[
driver.execute_query(
"""DROP INDEX $name""",
name=name,
)
for name in index_names
]
)
range_indices: list[LiteralString] = get_range_indices(driver.provider)
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
index_queries: list[LiteralString] = range_indices + fulltext_indices
await semaphore_gather(
*[
driver.execute_query(
query,
)
for query in index_queries
]
)
async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
async with driver.session() as session: