graph-per-graphid
This commit is contained in:
parent
4f8eb310f2
commit
bbf9cc6172
11 changed files with 279 additions and 87 deletions
|
|
@ -78,8 +78,6 @@ async def main():
|
|||
graphiti = Graphiti(graph_driver=falkor_driver)
|
||||
|
||||
try:
|
||||
# Initialize the graph database with graphiti's indices. This only needs to be done once.
|
||||
await graphiti.build_indices_and_constraints()
|
||||
|
||||
#################################################
|
||||
# ADDING EPISODES
|
||||
|
|
|
|||
|
|
@ -67,8 +67,6 @@ async def main():
|
|||
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
||||
|
||||
try:
|
||||
# Initialize the graph database with graphiti's indices. This only needs to be done once.
|
||||
await graphiti.build_indices_and_constraints()
|
||||
|
||||
#################################################
|
||||
# ADDING EPISODES
|
||||
|
|
|
|||
77
graphiti_core/decorators.py
Normal file
77
graphiti_core/decorators.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import functools
|
||||
from typing import Any, Awaitable, Callable, TypeVar
|
||||
|
||||
from graphiti_core.driver.driver import GraphProvider
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
from graphiti_core.search.search_config import SearchResults
|
||||
|
||||
F = TypeVar('F', bound=Callable[..., Awaitable[Any]])
|
||||
|
||||
|
||||
def handle_multiple_group_ids(func: F) -> F:
|
||||
"""
|
||||
Decorator for FalkorDB methods that need to handle multiple group_ids.
|
||||
Runs the function for each group_id separately and merges results.
|
||||
"""
|
||||
@functools.wraps(func)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
group_ids = kwargs.get('group_ids')
|
||||
|
||||
# Only handle FalkorDB with multiple group_ids
|
||||
if (hasattr(self, 'clients') and hasattr(self.clients, 'driver') and
|
||||
self.clients.driver.provider == GraphProvider.FALKORDB and
|
||||
group_ids and len(group_ids) > 1):
|
||||
|
||||
# Execute for each group_id concurrently
|
||||
driver = self.clients.driver
|
||||
|
||||
async def execute_for_group(gid: str):
|
||||
return await func(
|
||||
self,
|
||||
*args,
|
||||
**{**kwargs, "group_ids": [gid], "driver": driver.clone(database=gid)},
|
||||
)
|
||||
|
||||
results = await semaphore_gather(
|
||||
*[execute_for_group(gid) for gid in group_ids],
|
||||
max_coroutines=getattr(self, 'max_coroutines', None)
|
||||
)
|
||||
|
||||
# Merge results based on type
|
||||
if isinstance(results[0], SearchResults):
|
||||
return SearchResults.merge(results)
|
||||
elif isinstance(results[0], list):
|
||||
return [item for result in results for item in result]
|
||||
elif isinstance(results[0], tuple):
|
||||
# Handle tuple outputs (like build_communities returning (nodes, edges))
|
||||
merged_tuple = []
|
||||
for i in range(len(results[0])):
|
||||
component_results = [result[i] for result in results]
|
||||
if isinstance(component_results[0], list):
|
||||
merged_tuple.append([item for component in component_results for item in component])
|
||||
else:
|
||||
merged_tuple.append(component_results)
|
||||
return tuple(merged_tuple)
|
||||
else:
|
||||
return results
|
||||
|
||||
# Normal execution
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore
|
||||
|
|
@ -57,6 +57,7 @@ class GraphDriver(ABC):
|
|||
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
||||
)
|
||||
_database: str
|
||||
default_group_id: str = ''
|
||||
|
||||
@abstractmethod
|
||||
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
||||
|
|
@ -74,12 +75,10 @@ class GraphDriver(ABC):
|
|||
def delete_all_indexes(self) -> Coroutine:
|
||||
raise NotImplementedError()
|
||||
|
||||
def with_database(self, database: str) -> 'GraphDriver':
|
||||
"""
|
||||
Returns a shallow copy of this driver with a different default database.
|
||||
Reuses the same connection (e.g. FalkorDB, Neo4j).
|
||||
"""
|
||||
cloned = copy.copy(self)
|
||||
cloned._database = database
|
||||
@abstractmethod
|
||||
async def build_indices_and_constraints(self, delete_existing: bool = False):
|
||||
raise NotImplementedError()
|
||||
|
||||
return cloned
|
||||
def clone(self, database: str) -> 'GraphDriver':
|
||||
"""Clone the driver with a different database or graph name."""
|
||||
return self
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ import logging
|
|||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from falkordb import Graph as FalkorGraph
|
||||
from falkordb.asyncio import FalkorDB
|
||||
|
|
@ -33,6 +35,8 @@ else:
|
|||
) from None
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -72,6 +76,8 @@ class FalkorDriverSession(GraphDriverSession):
|
|||
|
||||
class FalkorDriver(GraphDriver):
|
||||
provider = GraphProvider.FALKORDB
|
||||
default_group_id: str = '\\_'
|
||||
fulltext_syntax: str = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -80,7 +86,7 @@ class FalkorDriver(GraphDriver):
|
|||
username: str | None = None,
|
||||
password: str | None = None,
|
||||
falkor_db: FalkorDB | None = None,
|
||||
database: str = 'default_db',
|
||||
database: str = '\\_',
|
||||
):
|
||||
"""
|
||||
Initialize the FalkorDB driver.
|
||||
|
|
@ -98,7 +104,16 @@ class FalkorDriver(GraphDriver):
|
|||
else:
|
||||
self.client = FalkorDB(host=host, port=port, username=username, password=password)
|
||||
|
||||
self.fulltext_syntax = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries see https://redis.io/docs/latest/develop/ai/search-and-query/query/full-text/
|
||||
# Schedule the indices and constraints to be built
|
||||
import asyncio
|
||||
try:
|
||||
# Try to get the current event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
# Schedule the build_indices_and_constraints to run
|
||||
loop.create_task(self.build_indices_and_constraints())
|
||||
except RuntimeError:
|
||||
# No event loop running, this will be handled later
|
||||
pass
|
||||
|
||||
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
|
||||
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
|
||||
|
|
@ -152,8 +167,64 @@ class FalkorDriver(GraphDriver):
|
|||
await self.client.connection.close()
|
||||
|
||||
async def delete_all_indexes(self) -> None:
|
||||
await self.execute_query(
|
||||
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||
from collections import defaultdict
|
||||
|
||||
result = await self.execute_query('CALL db.indexes()')
|
||||
if result is None:
|
||||
return
|
||||
|
||||
records, _, _ = result
|
||||
|
||||
# Organize indexes by type and label
|
||||
range_indexes = defaultdict(list)
|
||||
fulltext_indexes = defaultdict(list)
|
||||
entity_types = {}
|
||||
|
||||
for record in records:
|
||||
label = record['label']
|
||||
entity_types[label] = record['entitytype']
|
||||
|
||||
for field_name, index_type in record['types'].items():
|
||||
if 'RANGE' in index_type:
|
||||
range_indexes[label].append(field_name)
|
||||
if 'FULLTEXT' in index_type:
|
||||
fulltext_indexes[label].append(field_name)
|
||||
|
||||
# Drop all range indexes
|
||||
for label, fields in range_indexes.items():
|
||||
for field in fields:
|
||||
await self.execute_query(f'DROP INDEX ON :{label}({field})')
|
||||
|
||||
# Drop all fulltext indexes
|
||||
for label, fields in fulltext_indexes.items():
|
||||
entity_type = entity_types[label]
|
||||
for field in fields:
|
||||
if entity_type == 'NODE':
|
||||
await self.execute_query(
|
||||
f'DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field})'
|
||||
)
|
||||
elif entity_type == 'RELATIONSHIP':
|
||||
await self.execute_query(
|
||||
f'DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field})'
|
||||
)
|
||||
|
||||
async def build_indices_and_constraints(self, delete_existing: bool = False):
|
||||
if delete_existing:
|
||||
await self.delete_all_indexes()
|
||||
|
||||
range_indices: list[LiteralString] = get_range_indices(self.provider)
|
||||
|
||||
fulltext_indices: list[LiteralString] = get_fulltext_indices(self.provider)
|
||||
|
||||
index_queries: list[LiteralString] = range_indices + fulltext_indices
|
||||
|
||||
await semaphore_gather(
|
||||
*[
|
||||
self.execute_query(
|
||||
query,
|
||||
)
|
||||
for query in index_queries
|
||||
]
|
||||
)
|
||||
|
||||
def clone(self, database: str) -> 'GraphDriver':
|
||||
|
|
@ -161,8 +232,12 @@ class FalkorDriver(GraphDriver):
|
|||
Returns a shallow copy of this driver with a different default database.
|
||||
Reuses the same connection (e.g. FalkorDB, Neo4j).
|
||||
"""
|
||||
cloned = FalkorDriver(falkor_db=self.client, database=database)
|
||||
|
||||
if database == self._database:
|
||||
cloned = self
|
||||
else:
|
||||
# Create a new instance of FalkorDriver with the same connection but a different database
|
||||
cloned = FalkorDriver(falkor_db=self.client, database=database)
|
||||
|
||||
return cloned
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -22,12 +22,15 @@ from neo4j import AsyncGraphDatabase, EagerResult
|
|||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Neo4jDriver(GraphDriver):
|
||||
provider = GraphProvider.NEO4J
|
||||
default_group_id: str = ''
|
||||
|
||||
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
|
||||
super().__init__()
|
||||
|
|
@ -36,6 +39,17 @@ class Neo4jDriver(GraphDriver):
|
|||
auth=(user or '', password or ''),
|
||||
)
|
||||
self._database = database
|
||||
|
||||
# Schedule the indices and constraints to be built
|
||||
import asyncio
|
||||
try:
|
||||
# Try to get the current event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
# Schedule the build_indices_and_constraints to run
|
||||
loop.create_task(self.build_indices_and_constraints())
|
||||
except RuntimeError:
|
||||
# No event loop running, this will be handled later
|
||||
pass
|
||||
|
||||
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
|
||||
# Check if database_ is provided in kwargs.
|
||||
|
|
@ -64,3 +78,22 @@ class Neo4jDriver(GraphDriver):
|
|||
return self.client.execute_query(
|
||||
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||
)
|
||||
|
||||
async def build_indices_and_constraints(self, delete_existing: bool = False):
|
||||
if delete_existing:
|
||||
await self.delete_all_indexes()
|
||||
|
||||
range_indices: list[LiteralString] = get_range_indices(self.provider)
|
||||
|
||||
fulltext_indices: list[LiteralString] = get_fulltext_indices(self.provider)
|
||||
|
||||
index_queries: list[LiteralString] = range_indices + fulltext_indices
|
||||
|
||||
await semaphore_gather(
|
||||
*[
|
||||
self.execute_query(
|
||||
query,
|
||||
)
|
||||
for query in index_queries
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from typing_extensions import LiteralString
|
|||
|
||||
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
||||
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
||||
from graphiti_core.decorators import handle_multiple_group_ids
|
||||
from graphiti_core.driver.driver import GraphDriver
|
||||
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
||||
from graphiti_core.edges import (
|
||||
|
|
@ -35,7 +36,6 @@ from graphiti_core.edges import (
|
|||
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
||||
from graphiti_core.graphiti_types import GraphitiClients
|
||||
from graphiti_core.helpers import (
|
||||
get_default_group_id,
|
||||
semaphore_gather,
|
||||
validate_excluded_entity_types,
|
||||
validate_group_id,
|
||||
|
|
@ -87,7 +87,6 @@ from graphiti_core.utils.maintenance.edge_operations import (
|
|||
)
|
||||
from graphiti_core.utils.maintenance.graph_data_operations import (
|
||||
EPISODE_WINDOW_LEN,
|
||||
build_indices_and_constraints,
|
||||
retrieve_episodes,
|
||||
)
|
||||
from graphiti_core.utils.maintenance.node_operations import (
|
||||
|
|
@ -320,25 +319,26 @@ class Graphiti:
|
|||
-----
|
||||
This method should typically be called once during the initial setup of the
|
||||
knowledge graph or when updating the database schema. It uses the
|
||||
`build_indices_and_constraints` function from the
|
||||
`graphiti_core.utils.maintenance.graph_data_operations` module to perform
|
||||
driver's `build_indices_and_constraints` method to perform
|
||||
the actual database operations.
|
||||
|
||||
The specific indices and constraints created depend on the implementation
|
||||
of the `build_indices_and_constraints` function. Refer to that function's
|
||||
documentation for details on the exact database schema modifications.
|
||||
of the driver's `build_indices_and_constraints` method. Refer to the specific
|
||||
driver documentation for details on the exact database schema modifications.
|
||||
|
||||
Caution: Running this method on a large existing database may take some time
|
||||
and could impact database performance during execution.
|
||||
"""
|
||||
await build_indices_and_constraints(self.driver, delete_existing)
|
||||
await self.driver.build_indices_and_constraints(delete_existing)
|
||||
|
||||
@handle_multiple_group_ids
|
||||
async def retrieve_episodes(
|
||||
self,
|
||||
reference_time: datetime,
|
||||
last_n: int = EPISODE_WINDOW_LEN,
|
||||
group_ids: list[str] | None = None,
|
||||
source: EpisodeType | None = None,
|
||||
driver: GraphDriver | None = None,
|
||||
) -> list[EpisodicNode]:
|
||||
"""
|
||||
Retrieve the last n episodic nodes from the graph.
|
||||
|
|
@ -365,7 +365,10 @@ class Graphiti:
|
|||
The actual retrieval is performed by the `retrieve_episodes` function
|
||||
from the `graphiti_core.utils` module.
|
||||
"""
|
||||
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source)
|
||||
if driver is None:
|
||||
driver = self.clients.driver
|
||||
|
||||
return await retrieve_episodes(driver, reference_time, last_n, group_ids, source)
|
||||
|
||||
async def add_episode(
|
||||
self,
|
||||
|
|
@ -442,12 +445,18 @@ class Graphiti:
|
|||
start = time()
|
||||
now = utc_now()
|
||||
|
||||
# if group_id is None, use the default group id by the provider
|
||||
group_id = group_id or get_default_group_id(self.driver.provider)
|
||||
validate_entity_types(entity_types)
|
||||
if group_id is None:
|
||||
# if group_id is None, use the default group id by the provider
|
||||
group_id = self.driver.default_group_id
|
||||
else:
|
||||
validate_group_id(group_id)
|
||||
if group_id != self.driver._database:
|
||||
# if group_id is provided, use it as the database name
|
||||
self.driver = self.driver.clone(database=group_id)
|
||||
self.clients.driver = self.driver
|
||||
|
||||
validate_entity_types(entity_types)
|
||||
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
||||
validate_group_id(group_id)
|
||||
|
||||
previous_episodes = (
|
||||
await self.retrieve_episodes(
|
||||
|
|
@ -620,9 +629,15 @@ class Graphiti:
|
|||
start = time()
|
||||
now = utc_now()
|
||||
|
||||
# if group_id is None, use the default group id by the provider
|
||||
group_id = group_id or get_default_group_id(self.driver.provider)
|
||||
validate_group_id(group_id)
|
||||
if group_id is None:
|
||||
# if group_id is None, use the default group id by the provider
|
||||
group_id = self.driver.default_group_id
|
||||
else:
|
||||
validate_group_id(group_id)
|
||||
if group_id != self.driver._database:
|
||||
# if group_id is provided, use it as the database name
|
||||
self.driver = self.driver.clone(database=group_id)
|
||||
self.clients.driver = self.driver
|
||||
|
||||
# Create default edge type map
|
||||
edge_type_map_default = (
|
||||
|
|
@ -850,21 +865,26 @@ class Graphiti:
|
|||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@handle_multiple_group_ids
|
||||
async def build_communities(
|
||||
self, group_ids: list[str] | None = None
|
||||
self, group_ids: list[str] | None = None,
|
||||
driver: GraphDriver | None = None
|
||||
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
||||
"""
|
||||
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
|
||||
the content of these communities.
|
||||
----------
|
||||
query : list[str] | None
|
||||
group_ids : list[str] | None
|
||||
Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
|
||||
"""
|
||||
if driver is None:
|
||||
driver = self.clients.driver
|
||||
|
||||
# Clear existing communities
|
||||
await remove_communities(self.driver)
|
||||
await remove_communities(driver)
|
||||
|
||||
community_nodes, community_edges = await build_communities(
|
||||
self.driver, self.llm_client, group_ids
|
||||
driver, self.llm_client, group_ids
|
||||
)
|
||||
|
||||
await semaphore_gather(
|
||||
|
|
@ -873,16 +893,17 @@ class Graphiti:
|
|||
)
|
||||
|
||||
await semaphore_gather(
|
||||
*[node.save(self.driver) for node in community_nodes],
|
||||
*[node.save(driver) for node in community_nodes],
|
||||
max_coroutines=self.max_coroutines,
|
||||
)
|
||||
await semaphore_gather(
|
||||
*[edge.save(self.driver) for edge in community_edges],
|
||||
*[edge.save(driver) for edge in community_edges],
|
||||
max_coroutines=self.max_coroutines,
|
||||
)
|
||||
|
||||
return community_nodes, community_edges
|
||||
|
||||
@handle_multiple_group_ids
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
|
|
@ -890,6 +911,7 @@ class Graphiti:
|
|||
group_ids: list[str] | None = None,
|
||||
num_results=DEFAULT_SEARCH_LIMIT,
|
||||
search_filter: SearchFilters | None = None,
|
||||
driver: GraphDriver | None = None
|
||||
) -> list[EntityEdge]:
|
||||
"""
|
||||
Perform a hybrid search on the knowledge graph.
|
||||
|
|
@ -936,7 +958,8 @@ class Graphiti:
|
|||
group_ids,
|
||||
search_config,
|
||||
search_filter if search_filter is not None else SearchFilters(),
|
||||
center_node_uuid,
|
||||
driver=driver,
|
||||
center_node_uuid=center_node_uuid
|
||||
)
|
||||
).edges
|
||||
|
||||
|
|
@ -956,6 +979,7 @@ class Graphiti:
|
|||
query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter
|
||||
)
|
||||
|
||||
@handle_multiple_group_ids
|
||||
async def search_(
|
||||
self,
|
||||
query: str,
|
||||
|
|
@ -964,6 +988,7 @@ class Graphiti:
|
|||
center_node_uuid: str | None = None,
|
||||
bfs_origin_node_uuids: list[str] | None = None,
|
||||
search_filter: SearchFilters | None = None,
|
||||
driver: GraphDriver | None = None
|
||||
) -> SearchResults:
|
||||
"""search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
|
||||
than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
|
||||
|
|
@ -980,6 +1005,7 @@ class Graphiti:
|
|||
search_filter if search_filter is not None else SearchFilters(),
|
||||
center_node_uuid,
|
||||
bfs_origin_node_uuids,
|
||||
driver=driver
|
||||
)
|
||||
|
||||
async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults:
|
||||
|
|
|
|||
|
|
@ -53,17 +53,6 @@ def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None
|
|||
)
|
||||
|
||||
|
||||
def get_default_group_id(provider: GraphProvider) -> str:
|
||||
"""
|
||||
This function differentiates the default group id based on the database type.
|
||||
For most databases, the default group id is an empty string, while there are database types that require a specific default group id.
|
||||
"""
|
||||
if provider == GraphProvider.FALKORDB:
|
||||
return '_'
|
||||
else:
|
||||
return ''
|
||||
|
||||
|
||||
def lucene_sanitize(query: str) -> str:
|
||||
# Escape special characters from a query before passing into Lucene
|
||||
# + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
|
||||
|
|
|
|||
|
|
@ -72,10 +72,11 @@ async def search(
|
|||
center_node_uuid: str | None = None,
|
||||
bfs_origin_node_uuids: list[str] | None = None,
|
||||
query_vector: list[float] | None = None,
|
||||
driver: GraphDriver | None = None,
|
||||
) -> SearchResults:
|
||||
start = time()
|
||||
|
||||
driver = clients.driver
|
||||
driver = driver or clients.driver
|
||||
embedder = clients.embedder
|
||||
cross_encoder = clients.cross_encoder
|
||||
|
||||
|
|
|
|||
|
|
@ -127,3 +127,34 @@ class SearchResults(BaseModel):
|
|||
episode_reranker_scores: list[float] = Field(default_factory=list)
|
||||
communities: list[CommunityNode] = Field(default_factory=list)
|
||||
community_reranker_scores: list[float] = Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def merge(cls, results_list: list['SearchResults']) -> 'SearchResults':
|
||||
"""
|
||||
Merge multiple SearchResults objects into a single SearchResults object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
results_list : list[SearchResults]
|
||||
List of SearchResults objects to merge
|
||||
|
||||
Returns
|
||||
-------
|
||||
SearchResults
|
||||
A single SearchResults object containing all results
|
||||
"""
|
||||
if not results_list:
|
||||
return cls()
|
||||
|
||||
merged = cls()
|
||||
for result in results_list:
|
||||
merged.edges.extend(result.edges)
|
||||
merged.edge_reranker_scores.extend(result.edge_reranker_scores)
|
||||
merged.nodes.extend(result.nodes)
|
||||
merged.node_reranker_scores.extend(result.node_reranker_scores)
|
||||
merged.episodes.extend(result.episodes)
|
||||
merged.episode_reranker_scores.extend(result.episode_reranker_scores)
|
||||
merged.communities.extend(result.communities)
|
||||
merged.community_reranker_scores.extend(result.community_reranker_scores)
|
||||
|
||||
return merged
|
||||
|
|
|
|||
|
|
@ -20,8 +20,6 @@ from datetime import datetime
|
|||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver
|
||||
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
from graphiti_core.models.nodes.node_db_queries import EPISODIC_NODE_RETURN
|
||||
from graphiti_core.nodes import EpisodeType, EpisodicNode, get_episodic_node_from_record
|
||||
|
||||
|
|
@ -30,39 +28,6 @@ EPISODE_WINDOW_LEN = 3
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
|
||||
if delete_existing:
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
SHOW INDEXES YIELD name
|
||||
""",
|
||||
)
|
||||
index_names = [record['name'] for record in records]
|
||||
await semaphore_gather(
|
||||
*[
|
||||
driver.execute_query(
|
||||
"""DROP INDEX $name""",
|
||||
name=name,
|
||||
)
|
||||
for name in index_names
|
||||
]
|
||||
)
|
||||
range_indices: list[LiteralString] = get_range_indices(driver.provider)
|
||||
|
||||
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
|
||||
|
||||
index_queries: list[LiteralString] = range_indices + fulltext_indices
|
||||
|
||||
await semaphore_gather(
|
||||
*[
|
||||
driver.execute_query(
|
||||
query,
|
||||
)
|
||||
for query in index_queries
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
|
||||
async with driver.session() as session:
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue