[Improvement] Add GraphID isolation support for FalkorDB multi-tenant architecture (#835)
* Update node_db_queries.py * Update node_db_queries.py * graph-per-graphid * fix-groupid-usage * ruff-fix * rev-driver-changes * rm-un-changes * fix lint --------- Co-authored-by: Naseem Ali <34807727+Naseem77@users.noreply.github.com>
This commit is contained in:
parent
8d99984204
commit
c144ff5995
10 changed files with 267 additions and 72 deletions
|
|
@ -78,9 +78,6 @@ async def main():
|
|||
graphiti = Graphiti(graph_driver=falkor_driver)
|
||||
|
||||
try:
|
||||
# Initialize the graph database with graphiti's indices. This only needs to be done once.
|
||||
await graphiti.build_indices_and_constraints()
|
||||
|
||||
#################################################
|
||||
# ADDING EPISODES
|
||||
#################################################
|
||||
|
|
|
|||
|
|
@ -67,9 +67,6 @@ async def main():
|
|||
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
||||
|
||||
try:
|
||||
# Initialize the graph database with graphiti's indices. This only needs to be done once.
|
||||
await graphiti.build_indices_and_constraints()
|
||||
|
||||
#################################################
|
||||
# ADDING EPISODES
|
||||
#################################################
|
||||
|
|
|
|||
110
graphiti_core/decorators.py
Normal file
110
graphiti_core/decorators.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from graphiti_core.driver.driver import GraphProvider
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
from graphiti_core.search.search_config import SearchResults
|
||||
|
||||
F = TypeVar('F', bound=Callable[..., Awaitable[Any]])
|
||||
|
||||
|
||||
def handle_multiple_group_ids(func: F) -> F:
    """
    Decorator for FalkorDB methods that need to handle multiple group_ids.
    Runs the function for each group_id separately and merges results.

    FalkorDB isolates each group_id in its own graph, so one query cannot
    span several group_ids. When the decorated coroutine method is called
    with more than one group_id on a FalkorDB-backed Graphiti instance, the
    wrapper re-invokes it once per group_id — each call receives
    ``group_ids=[gid]`` and a ``driver`` kwarg cloned onto that group's
    graph — and merges the per-group results. In every other case the
    wrapped coroutine is awaited unchanged.

    NOTE(review): the multi-group path injects a ``driver`` keyword
    argument, so every decorated method must accept one — confirm on each
    decorated method.
    """

    @functools.wraps(func)
    async def wrapper(self, *args, **kwargs):
        # Position of 'group_ids' in the unbound function's signature;
        # index 0 is 'self', hence the -1 adjustment for *args below.
        group_ids_func_pos = get_parameter_position(func, 'group_ids')
        group_ids_pos = (
            group_ids_func_pos - 1 if group_ids_func_pos is not None else None
        )  # Adjust for zero-based index
        group_ids = kwargs.get('group_ids')

        # If not in kwargs and position exists, get from args
        if group_ids is None and group_ids_pos is not None and len(args) > group_ids_pos:
            group_ids = args[group_ids_pos]

        # Only handle FalkorDB with multiple group_ids
        if (
            hasattr(self, 'clients')
            and hasattr(self.clients, 'driver')
            and self.clients.driver.provider == GraphProvider.FALKORDB
            and group_ids
            and len(group_ids) > 1
        ):
            # Execute for each group_id concurrently
            driver = self.clients.driver

            async def execute_for_group(gid: str):
                # Remove group_ids from args if it was passed positionally
                # (it is re-supplied as a single-element kwarg below).
                filtered_args = list(args)
                if group_ids_pos is not None and len(args) > group_ids_pos:
                    filtered_args.pop(group_ids_pos)

                # Each per-group call sees exactly one group_id and a driver
                # cloned onto that group's dedicated graph.
                return await func(
                    self,
                    *filtered_args,
                    **{**kwargs, 'group_ids': [gid], 'driver': driver.clone(database=gid)},
                )

            results = await semaphore_gather(
                *[execute_for_group(gid) for gid in group_ids],
                max_coroutines=getattr(self, 'max_coroutines', None),
            )

            # Merge results based on type
            if isinstance(results[0], SearchResults):
                return SearchResults.merge(results)
            elif isinstance(results[0], list):
                # Flatten list-of-lists into a single list, preserving
                # per-group order.
                return [item for result in results for item in result]
            elif isinstance(results[0], tuple):
                # Handle tuple outputs (like build_communities returning (nodes, edges))
                merged_tuple = []
                for i in range(len(results[0])):
                    component_results = [result[i] for result in results]
                    if isinstance(component_results[0], list):
                        merged_tuple.append(
                            [item for component in component_results for item in component]
                        )
                    else:
                        # Non-list components are kept as one entry per group.
                        merged_tuple.append(component_results)
                return tuple(merged_tuple)
            else:
                # Unknown result type: hand back the raw per-group list.
                return results

        # Normal execution
        return await func(self, *args, **kwargs)

    return wrapper  # type: ignore
|
||||
|
||||
|
||||
def get_parameter_position(func: Callable, param_name: str) -> int | None:
|
||||
"""
|
||||
Returns the positional index of a parameter in the function signature.
|
||||
If the parameter is not found, returns None.
|
||||
"""
|
||||
sig = inspect.signature(func)
|
||||
for idx, (name, _param) in enumerate(sig.parameters.items()):
|
||||
if name == param_name:
|
||||
return idx
|
||||
return None
|
||||
|
|
@ -76,6 +76,7 @@ class GraphDriver(ABC):
|
|||
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
||||
)
|
||||
_database: str
|
||||
default_group_id: str = ''
|
||||
search_interface: SearchInterface | None = None
|
||||
graph_operations_interface: GraphOperationsInterface | None = None
|
||||
|
||||
|
|
@ -105,6 +106,14 @@ class GraphDriver(ABC):
|
|||
|
||||
return cloned
|
||||
|
||||
@abstractmethod
|
||||
async def build_indices_and_constraints(self, delete_existing: bool = False):
|
||||
raise NotImplementedError()
|
||||
|
||||
def clone(self, database: str) -> 'GraphDriver':
|
||||
"""Clone the driver with a different database or graph name."""
|
||||
return self
|
||||
|
||||
def build_fulltext_query(
|
||||
self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
|
||||
) -> str:
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ else:
|
|||
) from None
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
||||
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -112,6 +113,8 @@ class FalkorDriverSession(GraphDriverSession):
|
|||
|
||||
class FalkorDriver(GraphDriver):
|
||||
provider = GraphProvider.FALKORDB
|
||||
default_group_id: str = '\\_'
|
||||
fulltext_syntax: str = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries
|
||||
aoss_client: None = None
|
||||
|
||||
def __init__(
|
||||
|
|
@ -129,9 +132,16 @@ class FalkorDriver(GraphDriver):
|
|||
FalkorDB is a multi-tenant graph database.
|
||||
To connect, provide the host and port.
|
||||
The default parameters assume a local (on-premises) FalkorDB instance.
|
||||
|
||||
Args:
|
||||
host (str): The host where FalkorDB is running.
|
||||
port (int): The port on which FalkorDB is listening.
|
||||
username (str | None): The username for authentication (if required).
|
||||
password (str | None): The password for authentication (if required).
|
||||
falkor_db (FalkorDB | None): An existing FalkorDB instance to use instead of creating a new one.
|
||||
database (str): The name of the database to connect to. Defaults to 'default_db'.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self._database = database
|
||||
if falkor_db is not None:
|
||||
# If a FalkorDB instance is provided, use it directly
|
||||
|
|
@ -139,7 +149,15 @@ class FalkorDriver(GraphDriver):
|
|||
else:
|
||||
self.client = FalkorDB(host=host, port=port, username=username, password=password)
|
||||
|
||||
self.fulltext_syntax = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries see https://redis.io/docs/latest/develop/ai/search-and-query/query/full-text/
|
||||
# Schedule the indices and constraints to be built
|
||||
try:
|
||||
# Try to get the current event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
# Schedule the build_indices_and_constraints to run
|
||||
loop.create_task(self.build_indices_and_constraints())
|
||||
except RuntimeError:
|
||||
# No event loop running, this will be handled later
|
||||
pass
|
||||
|
||||
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
|
||||
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
|
||||
|
|
@ -224,12 +242,25 @@ class FalkorDriver(GraphDriver):
|
|||
if drop_tasks:
|
||||
await asyncio.gather(*drop_tasks)
|
||||
|
||||
async def build_indices_and_constraints(self, delete_existing=False):
    """
    Create the range and fulltext indices Graphiti uses on this graph.

    Parameters
    ----------
    delete_existing : bool
        When True, drop all existing indexes (via ``delete_all_indexes``)
        before recreating them.
    """
    if delete_existing:
        await self.delete_all_indexes()
    # NOTE(review): queries run sequentially here rather than gathered —
    # confirm whether FalkorDB requires serial index creation.
    index_queries = get_range_indices(self.provider) + get_fulltext_indices(self.provider)
    for query in index_queries:
        await self.execute_query(query)
|
||||
|
||||
def clone(self, database: str) -> 'GraphDriver':
    """
    Return a driver bound to a different database (graph) name.

    Reuses the same underlying FalkorDB connection (``self.client``); only
    the targeted graph changes.

    Fix: the original unconditionally built a ``FalkorDriver`` before
    branching and then discarded it in two of the three branches —
    ``FalkorDriver.__init__`` schedules a ``build_indices_and_constraints``
    task, so the throwaway instance caused needless index builds. The
    construction now happens only on the branch that needs it.

    Parameters
    ----------
    database : str
        The database/graph name the returned driver should target.
    """
    if database == self._database:
        # Already pointed at this database: reuse this driver as-is.
        return self
    if database == self.default_group_id:
        # The default group id maps to the driver's preset default database.
        return FalkorDriver(falkor_db=self.client)
    # New instance sharing the connection but targeting a different graph.
    return FalkorDriver(falkor_db=self.client, database=database)
|
||||
|
||||
|
|
|
|||
|
|
@ -22,12 +22,15 @@ from neo4j import AsyncGraphDatabase, EagerResult
|
|||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Neo4jDriver(GraphDriver):
|
||||
provider = GraphProvider.NEO4J
|
||||
default_group_id: str = ''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -43,6 +46,18 @@ class Neo4jDriver(GraphDriver):
|
|||
)
|
||||
self._database = database
|
||||
|
||||
# Schedule the indices and constraints to be built
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
# Try to get the current event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
# Schedule the build_indices_and_constraints to run
|
||||
loop.create_task(self.build_indices_and_constraints())
|
||||
except RuntimeError:
|
||||
# No event loop running, this will be handled later
|
||||
pass
|
||||
|
||||
self.aoss_client = None
|
||||
|
||||
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
|
||||
|
|
@ -73,6 +88,25 @@ class Neo4jDriver(GraphDriver):
|
|||
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||
)
|
||||
|
||||
async def build_indices_and_constraints(self, delete_existing: bool = False):
    """
    Create the range and fulltext indices Graphiti uses in Neo4j.

    Parameters
    ----------
    delete_existing : bool
        When True, drop all existing indexes (via ``delete_all_indexes``)
        before recreating them.
    """
    if delete_existing:
        await self.delete_all_indexes()

    # Collect every index statement for this provider, then issue them
    # concurrently through semaphore_gather.
    index_queries: list[LiteralString] = get_range_indices(self.provider) + get_fulltext_indices(
        self.provider
    )

    await semaphore_gather(*(self.execute_query(query) for query in index_queries))
|
||||
|
||||
async def health_check(self) -> None:
|
||||
"""Check Neo4j connectivity by running the driver's verify_connectivity method."""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from typing_extensions import LiteralString
|
|||
|
||||
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
||||
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
||||
from graphiti_core.decorators import handle_multiple_group_ids
|
||||
from graphiti_core.driver.driver import GraphDriver
|
||||
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
||||
from graphiti_core.edges import (
|
||||
|
|
@ -87,7 +88,6 @@ from graphiti_core.utils.maintenance.edge_operations import (
|
|||
)
|
||||
from graphiti_core.utils.maintenance.graph_data_operations import (
|
||||
EPISODE_WINDOW_LEN,
|
||||
build_indices_and_constraints,
|
||||
retrieve_episodes,
|
||||
)
|
||||
from graphiti_core.utils.maintenance.node_operations import (
|
||||
|
|
@ -340,18 +340,17 @@ class Graphiti:
|
|||
-----
|
||||
This method should typically be called once during the initial setup of the
|
||||
knowledge graph or when updating the database schema. It uses the
|
||||
`build_indices_and_constraints` function from the
|
||||
`graphiti_core.utils.maintenance.graph_data_operations` module to perform
|
||||
driver's `build_indices_and_constraints` method to perform
|
||||
the actual database operations.
|
||||
|
||||
The specific indices and constraints created depend on the implementation
|
||||
of the `build_indices_and_constraints` function. Refer to that function's
|
||||
documentation for details on the exact database schema modifications.
|
||||
of the driver's `build_indices_and_constraints` method. Refer to the specific
|
||||
driver documentation for details on the exact database schema modifications.
|
||||
|
||||
Caution: Running this method on a large existing database may take some time
|
||||
and could impact database performance during execution.
|
||||
"""
|
||||
await build_indices_and_constraints(self.driver, delete_existing)
|
||||
await self.driver.build_indices_and_constraints(delete_existing)
|
||||
|
||||
async def _extract_and_resolve_nodes(
|
||||
self,
|
||||
|
|
@ -574,12 +573,14 @@ class Graphiti:
|
|||
|
||||
return final_hydrated_nodes, resolved_edges, invalidated_edges, uuid_map
|
||||
|
||||
@handle_multiple_group_ids
|
||||
async def retrieve_episodes(
|
||||
self,
|
||||
reference_time: datetime,
|
||||
last_n: int = EPISODE_WINDOW_LEN,
|
||||
group_ids: list[str] | None = None,
|
||||
source: EpisodeType | None = None,
|
||||
driver: GraphDriver | None = None,
|
||||
) -> list[EpisodicNode]:
|
||||
"""
|
||||
Retrieve the last n episodic nodes from the graph.
|
||||
|
|
@ -606,7 +607,10 @@ class Graphiti:
|
|||
The actual retrieval is performed by the `retrieve_episodes` function
|
||||
from the `graphiti_core.utils` module.
|
||||
"""
|
||||
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source)
|
||||
if driver is None:
|
||||
driver = self.clients.driver
|
||||
|
||||
return await retrieve_episodes(driver, reference_time, last_n, group_ids, source)
|
||||
|
||||
async def add_episode(
|
||||
self,
|
||||
|
|
@ -683,11 +687,18 @@ class Graphiti:
|
|||
now = utc_now()
|
||||
|
||||
validate_entity_types(entity_types)
|
||||
|
||||
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
||||
validate_group_id(group_id)
|
||||
# if group_id is None, use the default group id by the provider
|
||||
group_id = group_id or get_default_group_id(self.driver.provider)
|
||||
|
||||
if group_id is None:
|
||||
# if group_id is None, use the default group id by the provider
|
||||
# and the preset database name will be used
|
||||
group_id = get_default_group_id(self.driver.provider)
|
||||
else:
|
||||
validate_group_id(group_id)
|
||||
if group_id != self.driver._database:
|
||||
# if group_id is provided, use it as the database name
|
||||
self.driver = self.driver.clone(database=group_id)
|
||||
self.clients.driver = self.driver
|
||||
|
||||
with self.tracer.start_span('add_episode') as span:
|
||||
try:
|
||||
|
|
@ -865,8 +876,14 @@ class Graphiti:
|
|||
now = utc_now()
|
||||
|
||||
# if group_id is None, use the default group id by the provider
|
||||
group_id = group_id or get_default_group_id(self.driver.provider)
|
||||
validate_group_id(group_id)
|
||||
if group_id is None:
|
||||
group_id = get_default_group_id(self.driver.provider)
|
||||
else:
|
||||
validate_group_id(group_id)
|
||||
if group_id != self.driver._database:
|
||||
# if group_id is provided, use it as the database name
|
||||
self.driver = self.driver.clone(database=group_id)
|
||||
self.clients.driver = self.driver
|
||||
|
||||
# Create default edge type map
|
||||
edge_type_map_default = (
|
||||
|
|
@ -993,21 +1010,25 @@ class Graphiti:
|
|||
bulk_span.record_exception(e)
|
||||
raise e
|
||||
|
||||
@handle_multiple_group_ids
|
||||
async def build_communities(
|
||||
self, group_ids: list[str] | None = None
|
||||
self, group_ids: list[str] | None = None, driver: GraphDriver | None = None
|
||||
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
||||
"""
|
||||
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
|
||||
the content of these communities.
|
||||
----------
|
||||
query : list[str] | None
|
||||
group_ids : list[str] | None
|
||||
Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
|
||||
"""
|
||||
if driver is None:
|
||||
driver = self.clients.driver
|
||||
|
||||
# Clear existing communities
|
||||
await remove_communities(self.driver)
|
||||
await remove_communities(driver)
|
||||
|
||||
community_nodes, community_edges = await build_communities(
|
||||
self.driver, self.llm_client, group_ids
|
||||
driver, self.llm_client, group_ids
|
||||
)
|
||||
|
||||
await semaphore_gather(
|
||||
|
|
@ -1016,16 +1037,17 @@ class Graphiti:
|
|||
)
|
||||
|
||||
await semaphore_gather(
|
||||
*[node.save(self.driver) for node in community_nodes],
|
||||
*[node.save(driver) for node in community_nodes],
|
||||
max_coroutines=self.max_coroutines,
|
||||
)
|
||||
await semaphore_gather(
|
||||
*[edge.save(self.driver) for edge in community_edges],
|
||||
*[edge.save(driver) for edge in community_edges],
|
||||
max_coroutines=self.max_coroutines,
|
||||
)
|
||||
|
||||
return community_nodes, community_edges
|
||||
|
||||
@handle_multiple_group_ids
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
|
|
@ -1033,6 +1055,7 @@ class Graphiti:
|
|||
group_ids: list[str] | None = None,
|
||||
num_results=DEFAULT_SEARCH_LIMIT,
|
||||
search_filter: SearchFilters | None = None,
|
||||
driver: GraphDriver | None = None,
|
||||
) -> list[EntityEdge]:
|
||||
"""
|
||||
Perform a hybrid search on the knowledge graph.
|
||||
|
|
@ -1079,7 +1102,8 @@ class Graphiti:
|
|||
group_ids,
|
||||
search_config,
|
||||
search_filter if search_filter is not None else SearchFilters(),
|
||||
center_node_uuid,
|
||||
driver=driver,
|
||||
center_node_uuid=center_node_uuid,
|
||||
)
|
||||
).edges
|
||||
|
||||
|
|
@ -1099,6 +1123,7 @@ class Graphiti:
|
|||
query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter
|
||||
)
|
||||
|
||||
@handle_multiple_group_ids
|
||||
async def search_(
|
||||
self,
|
||||
query: str,
|
||||
|
|
@ -1107,6 +1132,7 @@ class Graphiti:
|
|||
center_node_uuid: str | None = None,
|
||||
bfs_origin_node_uuids: list[str] | None = None,
|
||||
search_filter: SearchFilters | None = None,
|
||||
driver: GraphDriver | None = None,
|
||||
) -> SearchResults:
|
||||
"""search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
|
||||
than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
|
||||
|
|
@ -1123,6 +1149,7 @@ class Graphiti:
|
|||
search_filter if search_filter is not None else SearchFilters(),
|
||||
center_node_uuid,
|
||||
bfs_origin_node_uuids,
|
||||
driver=driver,
|
||||
)
|
||||
|
||||
async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults:
|
||||
|
|
|
|||
|
|
@ -74,10 +74,11 @@ async def search(
|
|||
center_node_uuid: str | None = None,
|
||||
bfs_origin_node_uuids: list[str] | None = None,
|
||||
query_vector: list[float] | None = None,
|
||||
driver: GraphDriver | None = None,
|
||||
) -> SearchResults:
|
||||
start = time()
|
||||
|
||||
driver = clients.driver
|
||||
driver = driver or clients.driver
|
||||
embedder = clients.embedder
|
||||
cross_encoder = clients.cross_encoder
|
||||
|
||||
|
|
|
|||
|
|
@ -127,3 +127,34 @@ class SearchResults(BaseModel):
|
|||
episode_reranker_scores: list[float] = Field(default_factory=list)
|
||||
communities: list[CommunityNode] = Field(default_factory=list)
|
||||
community_reranker_scores: list[float] = Field(default_factory=list)
|
||||
|
||||
@classmethod
def merge(cls, results_list: list['SearchResults']) -> 'SearchResults':
    """
    Merge multiple SearchResults objects into a single SearchResults object.

    Parameters
    ----------
    results_list : list[SearchResults]
        List of SearchResults objects to merge

    Returns
    -------
    SearchResults
        A single SearchResults object containing all results
    """
    merged = cls()
    if not results_list:
        return merged

    # Every list-typed field is concatenated in input order.
    list_fields = (
        'edges',
        'edge_reranker_scores',
        'nodes',
        'node_reranker_scores',
        'episodes',
        'episode_reranker_scores',
        'communities',
        'community_reranker_scores',
    )
    for result in results_list:
        for field_name in list_fields:
            getattr(merged, field_name).extend(getattr(result, field_name))

    return merged
|
||||
|
|
|
|||
|
|
@ -20,8 +20,6 @@ from datetime import datetime
|
|||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
||||
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
from graphiti_core.models.nodes.node_db_queries import (
|
||||
EPISODIC_NODE_RETURN,
|
||||
EPISODIC_NODE_RETURN_NEPTUNE,
|
||||
|
|
@ -33,46 +31,6 @@ EPISODE_WINDOW_LEN = 3
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
    """
    Create the range and fulltext indices for the given driver's provider.

    Parameters
    ----------
    driver : GraphDriver
        Driver whose provider determines which index statements are built.
    delete_existing : bool
        When True, drop all existing indexes before recreating them.

    Fix: ``fulltext_indices`` was only bound inside the
    ``if not driver.search_interface`` branch but read unconditionally when
    building ``index_queries``, raising NameError whenever a custom
    search_interface was configured. It now defaults to an empty list.
    """
    if delete_existing:
        await driver.delete_all_indexes()

    range_indices: list[LiteralString] = get_range_indices(driver.provider)

    # Don't create fulltext indices if search_interface is being used
    fulltext_indices: list[LiteralString] = []
    if not driver.search_interface:
        fulltext_indices = get_fulltext_indices(driver.provider)

        if driver.provider == GraphProvider.KUZU:
            # Skip creating fulltext indices if they already exist. Need to do this manually
            # until Kuzu supports `IF NOT EXISTS` for indices.
            result, _, _ = await driver.execute_query('CALL SHOW_INDEXES() RETURN *;')
            if len(result) > 0:
                fulltext_indices = []

            # Only load the `fts` extension if it's not already loaded, otherwise throw an error.
            result, _, _ = await driver.execute_query('CALL SHOW_LOADED_EXTENSIONS() RETURN *;')
            if len(result) == 0:
                fulltext_indices.insert(
                    0,
                    """
                    INSTALL fts;
                    LOAD fts;
                    """,
                )

    index_queries: list[LiteralString] = range_indices + fulltext_indices

    await semaphore_gather(
        *[
            driver.execute_query(
                query,
            )
            for query in index_queries
        ]
    )
|
||||
|
||||
|
||||
async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
|
||||
async with driver.session() as session:
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue