graph-per-graphid

This commit is contained in:
Gal Shubeli 2025-08-13 18:21:53 +03:00
parent 4f8eb310f2
commit bbf9cc6172
11 changed files with 279 additions and 87 deletions

View file

@ -78,8 +78,6 @@ async def main():
graphiti = Graphiti(graph_driver=falkor_driver) graphiti = Graphiti(graph_driver=falkor_driver)
try: try:
# Initialize the graph database with graphiti's indices. This only needs to be done once.
await graphiti.build_indices_and_constraints()
################################################# #################################################
# ADDING EPISODES # ADDING EPISODES

View file

@ -67,8 +67,6 @@ async def main():
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password) graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
try: try:
# Initialize the graph database with graphiti's indices. This only needs to be done once.
await graphiti.build_indices_and_constraints()
################################################# #################################################
# ADDING EPISODES # ADDING EPISODES

View file

@ -0,0 +1,77 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import functools
from typing import Any, Awaitable, Callable, TypeVar
from graphiti_core.driver.driver import GraphProvider
from graphiti_core.helpers import semaphore_gather
from graphiti_core.search.search_config import SearchResults
F = TypeVar('F', bound=Callable[..., Awaitable[Any]])
def handle_multiple_group_ids(func: F) -> F:
"""
Decorator for FalkorDB methods that need to handle multiple group_ids.
Runs the function for each group_id separately and merges results.
"""
@functools.wraps(func)
async def wrapper(self, *args, **kwargs):
group_ids = kwargs.get('group_ids')
# Only handle FalkorDB with multiple group_ids
if (hasattr(self, 'clients') and hasattr(self.clients, 'driver') and
self.clients.driver.provider == GraphProvider.FALKORDB and
group_ids and len(group_ids) > 1):
# Execute for each group_id concurrently
driver = self.clients.driver
async def execute_for_group(gid: str):
return await func(
self,
*args,
**{**kwargs, "group_ids": [gid], "driver": driver.clone(database=gid)},
)
results = await semaphore_gather(
*[execute_for_group(gid) for gid in group_ids],
max_coroutines=getattr(self, 'max_coroutines', None)
)
# Merge results based on type
if isinstance(results[0], SearchResults):
return SearchResults.merge(results)
elif isinstance(results[0], list):
return [item for result in results for item in result]
elif isinstance(results[0], tuple):
# Handle tuple outputs (like build_communities returning (nodes, edges))
merged_tuple = []
for i in range(len(results[0])):
component_results = [result[i] for result in results]
if isinstance(component_results[0], list):
merged_tuple.append([item for component in component_results for item in component])
else:
merged_tuple.append(component_results)
return tuple(merged_tuple)
else:
return results
# Normal execution
return await func(self, *args, **kwargs)
return wrapper # type: ignore

View file

@ -57,6 +57,7 @@ class GraphDriver(ABC):
'' # Neo4j (default) syntax does not require a prefix for fulltext queries '' # Neo4j (default) syntax does not require a prefix for fulltext queries
) )
_database: str _database: str
default_group_id: str = ''
@abstractmethod @abstractmethod
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine: def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
@ -74,12 +75,10 @@ class GraphDriver(ABC):
def delete_all_indexes(self) -> Coroutine: def delete_all_indexes(self) -> Coroutine:
raise NotImplementedError() raise NotImplementedError()
def with_database(self, database: str) -> 'GraphDriver': @abstractmethod
""" async def build_indices_and_constraints(self, delete_existing: bool = False):
Returns a shallow copy of this driver with a different default database. raise NotImplementedError()
Reuses the same connection (e.g. FalkorDB, Neo4j).
"""
cloned = copy.copy(self)
cloned._database = database
return cloned def clone(self, database: str) -> 'GraphDriver':
"""Clone the driver with a different database or graph name."""
return self

View file

@ -18,6 +18,8 @@ import logging
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from typing_extensions import LiteralString
if TYPE_CHECKING: if TYPE_CHECKING:
from falkordb import Graph as FalkorGraph from falkordb import Graph as FalkorGraph
from falkordb.asyncio import FalkorDB from falkordb.asyncio import FalkorDB
@ -33,6 +35,8 @@ else:
) from None ) from None
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.helpers import semaphore_gather
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -72,6 +76,8 @@ class FalkorDriverSession(GraphDriverSession):
class FalkorDriver(GraphDriver): class FalkorDriver(GraphDriver):
provider = GraphProvider.FALKORDB provider = GraphProvider.FALKORDB
default_group_id: str = '\\_'
fulltext_syntax: str = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries
def __init__( def __init__(
self, self,
@ -80,7 +86,7 @@ class FalkorDriver(GraphDriver):
username: str | None = None, username: str | None = None,
password: str | None = None, password: str | None = None,
falkor_db: FalkorDB | None = None, falkor_db: FalkorDB | None = None,
database: str = 'default_db', database: str = '\\_',
): ):
""" """
Initialize the FalkorDB driver. Initialize the FalkorDB driver.
@ -98,7 +104,16 @@ class FalkorDriver(GraphDriver):
else: else:
self.client = FalkorDB(host=host, port=port, username=username, password=password) self.client = FalkorDB(host=host, port=port, username=username, password=password)
self.fulltext_syntax = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries see https://redis.io/docs/latest/develop/ai/search-and-query/query/full-text/ # Schedule the indices and constraints to be built
import asyncio
try:
# Try to get the current event loop
loop = asyncio.get_running_loop()
# Schedule the build_indices_and_constraints to run
loop.create_task(self.build_indices_and_constraints())
except RuntimeError:
# No event loop running, this will be handled later
pass
def _get_graph(self, graph_name: str | None) -> FalkorGraph: def _get_graph(self, graph_name: str | None) -> FalkorGraph:
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db" # FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
@ -152,8 +167,64 @@ class FalkorDriver(GraphDriver):
await self.client.connection.close() await self.client.connection.close()
async def delete_all_indexes(self) -> None: async def delete_all_indexes(self) -> None:
await self.execute_query( from collections import defaultdict
'CALL db.indexes() YIELD name DROP INDEX name',
result = await self.execute_query('CALL db.indexes()')
if result is None:
return
records, _, _ = result
# Organize indexes by type and label
range_indexes = defaultdict(list)
fulltext_indexes = defaultdict(list)
entity_types = {}
for record in records:
label = record['label']
entity_types[label] = record['entitytype']
for field_name, index_type in record['types'].items():
if 'RANGE' in index_type:
range_indexes[label].append(field_name)
if 'FULLTEXT' in index_type:
fulltext_indexes[label].append(field_name)
# Drop all range indexes
for label, fields in range_indexes.items():
for field in fields:
await self.execute_query(f'DROP INDEX ON :{label}({field})')
# Drop all fulltext indexes
for label, fields in fulltext_indexes.items():
entity_type = entity_types[label]
for field in fields:
if entity_type == 'NODE':
await self.execute_query(
f'DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field})'
)
elif entity_type == 'RELATIONSHIP':
await self.execute_query(
f'DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field})'
)
async def build_indices_and_constraints(self, delete_existing: bool = False):
if delete_existing:
await self.delete_all_indexes()
range_indices: list[LiteralString] = get_range_indices(self.provider)
fulltext_indices: list[LiteralString] = get_fulltext_indices(self.provider)
index_queries: list[LiteralString] = range_indices + fulltext_indices
await semaphore_gather(
*[
self.execute_query(
query,
)
for query in index_queries
]
) )
def clone(self, database: str) -> 'GraphDriver': def clone(self, database: str) -> 'GraphDriver':
@ -161,7 +232,11 @@ class FalkorDriver(GraphDriver):
Returns a shallow copy of this driver with a different default database. Returns a shallow copy of this driver with a different default database.
Reuses the same connection (e.g. FalkorDB, Neo4j). Reuses the same connection (e.g. FalkorDB, Neo4j).
""" """
cloned = FalkorDriver(falkor_db=self.client, database=database) if database == self._database:
cloned = self
else:
# Create a new instance of FalkorDriver with the same connection but a different database
cloned = FalkorDriver(falkor_db=self.client, database=database)
return cloned return cloned

View file

@ -22,12 +22,15 @@ from neo4j import AsyncGraphDatabase, EagerResult
from typing_extensions import LiteralString from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.helpers import semaphore_gather
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Neo4jDriver(GraphDriver): class Neo4jDriver(GraphDriver):
provider = GraphProvider.NEO4J provider = GraphProvider.NEO4J
default_group_id: str = ''
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'): def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
super().__init__() super().__init__()
@ -37,6 +40,17 @@ class Neo4jDriver(GraphDriver):
) )
self._database = database self._database = database
# Schedule the indices and constraints to be built
import asyncio
try:
# Try to get the current event loop
loop = asyncio.get_running_loop()
# Schedule the build_indices_and_constraints to run
loop.create_task(self.build_indices_and_constraints())
except RuntimeError:
# No event loop running, this will be handled later
pass
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult: async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
# Check if database_ is provided in kwargs. # Check if database_ is provided in kwargs.
# If not populated, set the value to retain backwards compatibility # If not populated, set the value to retain backwards compatibility
@ -64,3 +78,22 @@ class Neo4jDriver(GraphDriver):
return self.client.execute_query( return self.client.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name', 'CALL db.indexes() YIELD name DROP INDEX name',
) )
async def build_indices_and_constraints(self, delete_existing: bool = False):
if delete_existing:
await self.delete_all_indexes()
range_indices: list[LiteralString] = get_range_indices(self.provider)
fulltext_indices: list[LiteralString] = get_fulltext_indices(self.provider)
index_queries: list[LiteralString] = range_indices + fulltext_indices
await semaphore_gather(
*[
self.execute_query(
query,
)
for query in index_queries
]
)

View file

@ -24,6 +24,7 @@ from typing_extensions import LiteralString
from graphiti_core.cross_encoder.client import CrossEncoderClient from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.decorators import handle_multiple_group_ids
from graphiti_core.driver.driver import GraphDriver from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.neo4j_driver import Neo4jDriver from graphiti_core.driver.neo4j_driver import Neo4jDriver
from graphiti_core.edges import ( from graphiti_core.edges import (
@ -35,7 +36,6 @@ from graphiti_core.edges import (
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import ( from graphiti_core.helpers import (
get_default_group_id,
semaphore_gather, semaphore_gather,
validate_excluded_entity_types, validate_excluded_entity_types,
validate_group_id, validate_group_id,
@ -87,7 +87,6 @@ from graphiti_core.utils.maintenance.edge_operations import (
) )
from graphiti_core.utils.maintenance.graph_data_operations import ( from graphiti_core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN, EPISODE_WINDOW_LEN,
build_indices_and_constraints,
retrieve_episodes, retrieve_episodes,
) )
from graphiti_core.utils.maintenance.node_operations import ( from graphiti_core.utils.maintenance.node_operations import (
@ -320,25 +319,26 @@ class Graphiti:
----- -----
This method should typically be called once during the initial setup of the This method should typically be called once during the initial setup of the
knowledge graph or when updating the database schema. It uses the knowledge graph or when updating the database schema. It uses the
`build_indices_and_constraints` function from the driver's `build_indices_and_constraints` method to perform
`graphiti_core.utils.maintenance.graph_data_operations` module to perform
the actual database operations. the actual database operations.
The specific indices and constraints created depend on the implementation The specific indices and constraints created depend on the implementation
of the `build_indices_and_constraints` function. Refer to that function's of the driver's `build_indices_and_constraints` method. Refer to the specific
documentation for details on the exact database schema modifications. driver documentation for details on the exact database schema modifications.
Caution: Running this method on a large existing database may take some time Caution: Running this method on a large existing database may take some time
and could impact database performance during execution. and could impact database performance during execution.
""" """
await build_indices_and_constraints(self.driver, delete_existing) await self.driver.build_indices_and_constraints(delete_existing)
@handle_multiple_group_ids
async def retrieve_episodes( async def retrieve_episodes(
self, self,
reference_time: datetime, reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN, last_n: int = EPISODE_WINDOW_LEN,
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
source: EpisodeType | None = None, source: EpisodeType | None = None,
driver: GraphDriver | None = None,
) -> list[EpisodicNode]: ) -> list[EpisodicNode]:
""" """
Retrieve the last n episodic nodes from the graph. Retrieve the last n episodic nodes from the graph.
@ -365,7 +365,10 @@ class Graphiti:
The actual retrieval is performed by the `retrieve_episodes` function The actual retrieval is performed by the `retrieve_episodes` function
from the `graphiti_core.utils` module. from the `graphiti_core.utils` module.
""" """
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source) if driver is None:
driver = self.clients.driver
return await retrieve_episodes(driver, reference_time, last_n, group_ids, source)
async def add_episode( async def add_episode(
self, self,
@ -442,12 +445,18 @@ class Graphiti:
start = time() start = time()
now = utc_now() now = utc_now()
# if group_id is None, use the default group id by the provider if group_id is None:
group_id = group_id or get_default_group_id(self.driver.provider) # if group_id is None, use the default group id by the provider
validate_entity_types(entity_types) group_id = self.driver.default_group_id
else:
validate_group_id(group_id)
if group_id != self.driver._database:
# if group_id is provided, use it as the database name
self.driver = self.driver.clone(database=group_id)
self.clients.driver = self.driver
validate_entity_types(entity_types)
validate_excluded_entity_types(excluded_entity_types, entity_types) validate_excluded_entity_types(excluded_entity_types, entity_types)
validate_group_id(group_id)
previous_episodes = ( previous_episodes = (
await self.retrieve_episodes( await self.retrieve_episodes(
@ -620,9 +629,15 @@ class Graphiti:
start = time() start = time()
now = utc_now() now = utc_now()
# if group_id is None, use the default group id by the provider if group_id is None:
group_id = group_id or get_default_group_id(self.driver.provider) # if group_id is None, use the default group id by the provider
validate_group_id(group_id) group_id = self.driver.default_group_id
else:
validate_group_id(group_id)
if group_id != self.driver._database:
# if group_id is provided, use it as the database name
self.driver = self.driver.clone(database=group_id)
self.clients.driver = self.driver
# Create default edge type map # Create default edge type map
edge_type_map_default = ( edge_type_map_default = (
@ -850,21 +865,26 @@ class Graphiti:
except Exception as e: except Exception as e:
raise e raise e
@handle_multiple_group_ids
async def build_communities( async def build_communities(
self, group_ids: list[str] | None = None self, group_ids: list[str] | None = None,
driver: GraphDriver | None = None
) -> tuple[list[CommunityNode], list[CommunityEdge]]: ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
""" """
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
the content of these communities. the content of these communities.
---------- ----------
query : list[str] | None group_ids : list[str] | None
Optional. Create communities only for the listed group_ids. If blank the entire graph will be used. Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
""" """
if driver is None:
driver = self.clients.driver
# Clear existing communities # Clear existing communities
await remove_communities(self.driver) await remove_communities(driver)
community_nodes, community_edges = await build_communities( community_nodes, community_edges = await build_communities(
self.driver, self.llm_client, group_ids driver, self.llm_client, group_ids
) )
await semaphore_gather( await semaphore_gather(
@ -873,16 +893,17 @@ class Graphiti:
) )
await semaphore_gather( await semaphore_gather(
*[node.save(self.driver) for node in community_nodes], *[node.save(driver) for node in community_nodes],
max_coroutines=self.max_coroutines, max_coroutines=self.max_coroutines,
) )
await semaphore_gather( await semaphore_gather(
*[edge.save(self.driver) for edge in community_edges], *[edge.save(driver) for edge in community_edges],
max_coroutines=self.max_coroutines, max_coroutines=self.max_coroutines,
) )
return community_nodes, community_edges return community_nodes, community_edges
@handle_multiple_group_ids
async def search( async def search(
self, self,
query: str, query: str,
@ -890,6 +911,7 @@ class Graphiti:
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
num_results=DEFAULT_SEARCH_LIMIT, num_results=DEFAULT_SEARCH_LIMIT,
search_filter: SearchFilters | None = None, search_filter: SearchFilters | None = None,
driver: GraphDriver | None = None
) -> list[EntityEdge]: ) -> list[EntityEdge]:
""" """
Perform a hybrid search on the knowledge graph. Perform a hybrid search on the knowledge graph.
@ -936,7 +958,8 @@ class Graphiti:
group_ids, group_ids,
search_config, search_config,
search_filter if search_filter is not None else SearchFilters(), search_filter if search_filter is not None else SearchFilters(),
center_node_uuid, driver=driver,
center_node_uuid=center_node_uuid
) )
).edges ).edges
@ -956,6 +979,7 @@ class Graphiti:
query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter
) )
@handle_multiple_group_ids
async def search_( async def search_(
self, self,
query: str, query: str,
@ -964,6 +988,7 @@ class Graphiti:
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None, bfs_origin_node_uuids: list[str] | None = None,
search_filter: SearchFilters | None = None, search_filter: SearchFilters | None = None,
driver: GraphDriver | None = None
) -> SearchResults: ) -> SearchResults:
"""search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather """search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
@ -980,6 +1005,7 @@ class Graphiti:
search_filter if search_filter is not None else SearchFilters(), search_filter if search_filter is not None else SearchFilters(),
center_node_uuid, center_node_uuid,
bfs_origin_node_uuids, bfs_origin_node_uuids,
driver=driver
) )
async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults: async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults:

View file

@ -53,17 +53,6 @@ def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None
) )
def get_default_group_id(provider: GraphProvider) -> str:
"""
This function differentiates the default group id based on the database type.
For most databases, the default group id is an empty string, while there are database types that require a specific default group id.
"""
if provider == GraphProvider.FALKORDB:
return '_'
else:
return ''
def lucene_sanitize(query: str) -> str: def lucene_sanitize(query: str) -> str:
# Escape special characters from a query before passing into Lucene # Escape special characters from a query before passing into Lucene
# + - && || ! ( ) { } [ ] ^ " ~ * ? : \ / # + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /

View file

@ -72,10 +72,11 @@ async def search(
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None, bfs_origin_node_uuids: list[str] | None = None,
query_vector: list[float] | None = None, query_vector: list[float] | None = None,
driver: GraphDriver | None = None,
) -> SearchResults: ) -> SearchResults:
start = time() start = time()
driver = clients.driver driver = driver or clients.driver
embedder = clients.embedder embedder = clients.embedder
cross_encoder = clients.cross_encoder cross_encoder = clients.cross_encoder

View file

@ -127,3 +127,34 @@ class SearchResults(BaseModel):
episode_reranker_scores: list[float] = Field(default_factory=list) episode_reranker_scores: list[float] = Field(default_factory=list)
communities: list[CommunityNode] = Field(default_factory=list) communities: list[CommunityNode] = Field(default_factory=list)
community_reranker_scores: list[float] = Field(default_factory=list) community_reranker_scores: list[float] = Field(default_factory=list)
@classmethod
def merge(cls, results_list: list['SearchResults']) -> 'SearchResults':
"""
Merge multiple SearchResults objects into a single SearchResults object.
Parameters
----------
results_list : list[SearchResults]
List of SearchResults objects to merge
Returns
-------
SearchResults
A single SearchResults object containing all results
"""
if not results_list:
return cls()
merged = cls()
for result in results_list:
merged.edges.extend(result.edges)
merged.edge_reranker_scores.extend(result.edge_reranker_scores)
merged.nodes.extend(result.nodes)
merged.node_reranker_scores.extend(result.node_reranker_scores)
merged.episodes.extend(result.episodes)
merged.episode_reranker_scores.extend(result.episode_reranker_scores)
merged.communities.extend(result.communities)
merged.community_reranker_scores.extend(result.community_reranker_scores)
return merged

View file

@ -20,8 +20,6 @@ from datetime import datetime
from typing_extensions import LiteralString from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver from graphiti_core.driver.driver import GraphDriver
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.helpers import semaphore_gather
from graphiti_core.models.nodes.node_db_queries import EPISODIC_NODE_RETURN from graphiti_core.models.nodes.node_db_queries import EPISODIC_NODE_RETURN
from graphiti_core.nodes import EpisodeType, EpisodicNode, get_episodic_node_from_record from graphiti_core.nodes import EpisodeType, EpisodicNode, get_episodic_node_from_record
@ -30,39 +28,6 @@ EPISODE_WINDOW_LEN = 3
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
if delete_existing:
records, _, _ = await driver.execute_query(
"""
SHOW INDEXES YIELD name
""",
)
index_names = [record['name'] for record in records]
await semaphore_gather(
*[
driver.execute_query(
"""DROP INDEX $name""",
name=name,
)
for name in index_names
]
)
range_indices: list[LiteralString] = get_range_indices(driver.provider)
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
index_queries: list[LiteralString] = range_indices + fulltext_indices
await semaphore_gather(
*[
driver.execute_query(
query,
)
for query in index_queries
]
)
async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None): async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
async with driver.session() as session: async with driver.session() as session: