[Improvement] Add GraphID isolation support for FalkorDB multi-tenant architecture (#835)
* Update node_db_queries.py
* Use a separate FalkorDB graph per group_id (graph-per-graphid)
* Fix group_id usage
* Apply ruff formatting fixes
* Revert unrelated driver changes
* Remove unrelated changes
* Fix lint

Co-authored-by: Naseem Ali <34807727+Naseem77@users.noreply.github.com>
This commit is contained in:
parent
8d99984204
commit
c144ff5995
10 changed files with 267 additions and 72 deletions
|
|
@ -78,9 +78,6 @@ async def main():
|
||||||
graphiti = Graphiti(graph_driver=falkor_driver)
|
graphiti = Graphiti(graph_driver=falkor_driver)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize the graph database with graphiti's indices. This only needs to be done once.
|
|
||||||
await graphiti.build_indices_and_constraints()
|
|
||||||
|
|
||||||
#################################################
|
#################################################
|
||||||
# ADDING EPISODES
|
# ADDING EPISODES
|
||||||
#################################################
|
#################################################
|
||||||
|
|
|
||||||
|
|
@ -67,9 +67,6 @@ async def main():
|
||||||
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize the graph database with graphiti's indices. This only needs to be done once.
|
|
||||||
await graphiti.build_indices_and_constraints()
|
|
||||||
|
|
||||||
#################################################
|
#################################################
|
||||||
# ADDING EPISODES
|
# ADDING EPISODES
|
||||||
#################################################
|
#################################################
|
||||||
|
|
|
||||||
110
graphiti_core/decorators.py
Normal file
110
graphiti_core/decorators.py
Normal file
|
|
@ -0,0 +1,110 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import inspect
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
|
from graphiti_core.driver.driver import GraphProvider
|
||||||
|
from graphiti_core.helpers import semaphore_gather
|
||||||
|
from graphiti_core.search.search_config import SearchResults
|
||||||
|
|
||||||
|
F = TypeVar('F', bound=Callable[..., Awaitable[Any]])
|
||||||
|
|
||||||
|
|
||||||
|
def handle_multiple_group_ids(func: F) -> F:
    """
    Decorator for FalkorDB methods that need to handle multiple group_ids.

    FalkorDB isolates each group_id in its own graph, so a single query cannot
    span several group_ids. When the wrapped method is called with more than one
    group_id on a FalkorDB-backed instance, this decorator fans the call out —
    one invocation per group_id, each against a driver cloned onto that
    group_id's graph — and merges the per-group results back into one value.
    For any other provider, or with zero/one group_id, the call passes through
    unchanged.

    NOTE(review): the fan-out path injects a ``driver`` keyword argument, so
    every decorated method is assumed to accept ``driver`` — confirm on each
    new call site.
    """

    @functools.wraps(func)
    async def wrapper(self, *args, **kwargs):
        # Position of 'group_ids' in the wrapped function's signature
        # (includes 'self', hence the -1 adjustment below for *args indexing).
        group_ids_func_pos = get_parameter_position(func, 'group_ids')
        group_ids_pos = (
            group_ids_func_pos - 1 if group_ids_func_pos is not None else None
        )  # Adjust for zero-based index
        group_ids = kwargs.get('group_ids')

        # If not in kwargs and position exists, get from args
        if group_ids is None and group_ids_pos is not None and len(args) > group_ids_pos:
            group_ids = args[group_ids_pos]

        # Only handle FalkorDB with multiple group_ids
        if (
            hasattr(self, 'clients')
            and hasattr(self.clients, 'driver')
            and self.clients.driver.provider == GraphProvider.FALKORDB
            and group_ids
            and len(group_ids) > 1
        ):
            # Execute for each group_id concurrently
            driver = self.clients.driver

            async def execute_for_group(gid: str):
                # Remove group_ids from args if it was passed positionally;
                # it is re-supplied below as a single-element kwarg instead.
                filtered_args = list(args)
                if group_ids_pos is not None and len(args) > group_ids_pos:
                    filtered_args.pop(group_ids_pos)

                # Each group runs against a driver clone bound to that
                # group's graph (FalkorDB graph == group_id).
                return await func(
                    self,
                    *filtered_args,
                    **{**kwargs, 'group_ids': [gid], 'driver': driver.clone(database=gid)},
                )

            # Concurrency is bounded by the instance's max_coroutines when set.
            results = await semaphore_gather(
                *[execute_for_group(gid) for gid in group_ids],
                max_coroutines=getattr(self, 'max_coroutines', None),
            )

            # Merge results based on type
            if isinstance(results[0], SearchResults):
                return SearchResults.merge(results)
            elif isinstance(results[0], list):
                # Flatten list-of-lists into a single list.
                return [item for result in results for item in result]
            elif isinstance(results[0], tuple):
                # Handle tuple outputs (like build_communities returning (nodes, edges)):
                # merge the tuples component-wise, flattening list components.
                merged_tuple = []
                for i in range(len(results[0])):
                    component_results = [result[i] for result in results]
                    if isinstance(component_results[0], list):
                        merged_tuple.append(
                            [item for component in component_results for item in component]
                        )
                    else:
                        # Non-list components are collected as-is, one per group.
                        merged_tuple.append(component_results)
                return tuple(merged_tuple)
            else:
                # Unknown result type: return the raw per-group results.
                return results

        # Normal execution
        return await func(self, *args, **kwargs)

    return wrapper  # type: ignore
|
def get_parameter_position(func: Callable, param_name: str) -> int | None:
|
||||||
|
"""
|
||||||
|
Returns the positional index of a parameter in the function signature.
|
||||||
|
If the parameter is not found, returns None.
|
||||||
|
"""
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
for idx, (name, _param) in enumerate(sig.parameters.items()):
|
||||||
|
if name == param_name:
|
||||||
|
return idx
|
||||||
|
return None
|
||||||
|
|
@ -76,6 +76,7 @@ class GraphDriver(ABC):
|
||||||
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
||||||
)
|
)
|
||||||
_database: str
|
_database: str
|
||||||
|
default_group_id: str = ''
|
||||||
search_interface: SearchInterface | None = None
|
search_interface: SearchInterface | None = None
|
||||||
graph_operations_interface: GraphOperationsInterface | None = None
|
graph_operations_interface: GraphOperationsInterface | None = None
|
||||||
|
|
||||||
|
|
@ -105,6 +106,14 @@ class GraphDriver(ABC):
|
||||||
|
|
||||||
return cloned
|
return cloned
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def build_indices_and_constraints(self, delete_existing: bool = False):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def clone(self, database: str) -> 'GraphDriver':
|
||||||
|
"""Clone the driver with a different database or graph name."""
|
||||||
|
return self
|
||||||
|
|
||||||
def build_fulltext_query(
|
def build_fulltext_query(
|
||||||
self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
|
self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ else:
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||||
|
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
||||||
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
|
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -112,6 +113,8 @@ class FalkorDriverSession(GraphDriverSession):
|
||||||
|
|
||||||
class FalkorDriver(GraphDriver):
|
class FalkorDriver(GraphDriver):
|
||||||
provider = GraphProvider.FALKORDB
|
provider = GraphProvider.FALKORDB
|
||||||
|
default_group_id: str = '\\_'
|
||||||
|
fulltext_syntax: str = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries
|
||||||
aoss_client: None = None
|
aoss_client: None = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -129,9 +132,16 @@ class FalkorDriver(GraphDriver):
|
||||||
FalkorDB is a multi-tenant graph database.
|
FalkorDB is a multi-tenant graph database.
|
||||||
To connect, provide the host and port.
|
To connect, provide the host and port.
|
||||||
The default parameters assume a local (on-premises) FalkorDB instance.
|
The default parameters assume a local (on-premises) FalkorDB instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
host (str): The host where FalkorDB is running.
|
||||||
|
port (int): The port on which FalkorDB is listening.
|
||||||
|
username (str | None): The username for authentication (if required).
|
||||||
|
password (str | None): The password for authentication (if required).
|
||||||
|
falkor_db (FalkorDB | None): An existing FalkorDB instance to use instead of creating a new one.
|
||||||
|
database (str): The name of the database to connect to. Defaults to 'default_db'.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self._database = database
|
self._database = database
|
||||||
if falkor_db is not None:
|
if falkor_db is not None:
|
||||||
# If a FalkorDB instance is provided, use it directly
|
# If a FalkorDB instance is provided, use it directly
|
||||||
|
|
@ -139,7 +149,15 @@ class FalkorDriver(GraphDriver):
|
||||||
else:
|
else:
|
||||||
self.client = FalkorDB(host=host, port=port, username=username, password=password)
|
self.client = FalkorDB(host=host, port=port, username=username, password=password)
|
||||||
|
|
||||||
self.fulltext_syntax = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries see https://redis.io/docs/latest/develop/ai/search-and-query/query/full-text/
|
# Schedule the indices and constraints to be built
|
||||||
|
try:
|
||||||
|
# Try to get the current event loop
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
# Schedule the build_indices_and_constraints to run
|
||||||
|
loop.create_task(self.build_indices_and_constraints())
|
||||||
|
except RuntimeError:
|
||||||
|
# No event loop running, this will be handled later
|
||||||
|
pass
|
||||||
|
|
||||||
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
|
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
|
||||||
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
|
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
|
||||||
|
|
@ -224,12 +242,25 @@ class FalkorDriver(GraphDriver):
|
||||||
if drop_tasks:
|
if drop_tasks:
|
||||||
await asyncio.gather(*drop_tasks)
|
await asyncio.gather(*drop_tasks)
|
||||||
|
|
||||||
|
async def build_indices_and_constraints(self, delete_existing=False):
|
||||||
|
if delete_existing:
|
||||||
|
await self.delete_all_indexes()
|
||||||
|
index_queries = get_range_indices(self.provider) + get_fulltext_indices(self.provider)
|
||||||
|
for query in index_queries:
|
||||||
|
await self.execute_query(query)
|
||||||
|
|
||||||
def clone(self, database: str) -> 'GraphDriver':
|
def clone(self, database: str) -> 'GraphDriver':
|
||||||
"""
|
"""
|
||||||
Returns a shallow copy of this driver with a different default database.
|
Returns a shallow copy of this driver with a different default database.
|
||||||
Reuses the same connection (e.g. FalkorDB, Neo4j).
|
Reuses the same connection (e.g. FalkorDB, Neo4j).
|
||||||
"""
|
"""
|
||||||
cloned = FalkorDriver(falkor_db=self.client, database=database)
|
if database == self._database:
|
||||||
|
cloned = self
|
||||||
|
elif database == self.default_group_id:
|
||||||
|
cloned = FalkorDriver(falkor_db=self.client)
|
||||||
|
else:
|
||||||
|
# Create a new instance of FalkorDriver with the same connection but a different database
|
||||||
|
cloned = FalkorDriver(falkor_db=self.client, database=database)
|
||||||
|
|
||||||
return cloned
|
return cloned
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,12 +22,15 @@ from neo4j import AsyncGraphDatabase, EagerResult
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||||
|
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
||||||
|
from graphiti_core.helpers import semaphore_gather
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Neo4jDriver(GraphDriver):
|
class Neo4jDriver(GraphDriver):
|
||||||
provider = GraphProvider.NEO4J
|
provider = GraphProvider.NEO4J
|
||||||
|
default_group_id: str = ''
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -43,6 +46,18 @@ class Neo4jDriver(GraphDriver):
|
||||||
)
|
)
|
||||||
self._database = database
|
self._database = database
|
||||||
|
|
||||||
|
# Schedule the indices and constraints to be built
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try to get the current event loop
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
# Schedule the build_indices_and_constraints to run
|
||||||
|
loop.create_task(self.build_indices_and_constraints())
|
||||||
|
except RuntimeError:
|
||||||
|
# No event loop running, this will be handled later
|
||||||
|
pass
|
||||||
|
|
||||||
self.aoss_client = None
|
self.aoss_client = None
|
||||||
|
|
||||||
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
|
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
|
||||||
|
|
@ -73,6 +88,25 @@ class Neo4jDriver(GraphDriver):
|
||||||
'CALL db.indexes() YIELD name DROP INDEX name',
|
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def build_indices_and_constraints(self, delete_existing: bool = False):
|
||||||
|
if delete_existing:
|
||||||
|
await self.delete_all_indexes()
|
||||||
|
|
||||||
|
range_indices: list[LiteralString] = get_range_indices(self.provider)
|
||||||
|
|
||||||
|
fulltext_indices: list[LiteralString] = get_fulltext_indices(self.provider)
|
||||||
|
|
||||||
|
index_queries: list[LiteralString] = range_indices + fulltext_indices
|
||||||
|
|
||||||
|
await semaphore_gather(
|
||||||
|
*[
|
||||||
|
self.execute_query(
|
||||||
|
query,
|
||||||
|
)
|
||||||
|
for query in index_queries
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
async def health_check(self) -> None:
|
async def health_check(self) -> None:
|
||||||
"""Check Neo4j connectivity by running the driver's verify_connectivity method."""
|
"""Check Neo4j connectivity by running the driver's verify_connectivity method."""
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
||||||
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
||||||
|
from graphiti_core.decorators import handle_multiple_group_ids
|
||||||
from graphiti_core.driver.driver import GraphDriver
|
from graphiti_core.driver.driver import GraphDriver
|
||||||
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
||||||
from graphiti_core.edges import (
|
from graphiti_core.edges import (
|
||||||
|
|
@ -87,7 +88,6 @@ from graphiti_core.utils.maintenance.edge_operations import (
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.maintenance.graph_data_operations import (
|
from graphiti_core.utils.maintenance.graph_data_operations import (
|
||||||
EPISODE_WINDOW_LEN,
|
EPISODE_WINDOW_LEN,
|
||||||
build_indices_and_constraints,
|
|
||||||
retrieve_episodes,
|
retrieve_episodes,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.maintenance.node_operations import (
|
from graphiti_core.utils.maintenance.node_operations import (
|
||||||
|
|
@ -340,18 +340,17 @@ class Graphiti:
|
||||||
-----
|
-----
|
||||||
This method should typically be called once during the initial setup of the
|
This method should typically be called once during the initial setup of the
|
||||||
knowledge graph or when updating the database schema. It uses the
|
knowledge graph or when updating the database schema. It uses the
|
||||||
`build_indices_and_constraints` function from the
|
driver's `build_indices_and_constraints` method to perform
|
||||||
`graphiti_core.utils.maintenance.graph_data_operations` module to perform
|
|
||||||
the actual database operations.
|
the actual database operations.
|
||||||
|
|
||||||
The specific indices and constraints created depend on the implementation
|
The specific indices and constraints created depend on the implementation
|
||||||
of the `build_indices_and_constraints` function. Refer to that function's
|
of the driver's `build_indices_and_constraints` method. Refer to the specific
|
||||||
documentation for details on the exact database schema modifications.
|
driver documentation for details on the exact database schema modifications.
|
||||||
|
|
||||||
Caution: Running this method on a large existing database may take some time
|
Caution: Running this method on a large existing database may take some time
|
||||||
and could impact database performance during execution.
|
and could impact database performance during execution.
|
||||||
"""
|
"""
|
||||||
await build_indices_and_constraints(self.driver, delete_existing)
|
await self.driver.build_indices_and_constraints(delete_existing)
|
||||||
|
|
||||||
async def _extract_and_resolve_nodes(
|
async def _extract_and_resolve_nodes(
|
||||||
self,
|
self,
|
||||||
|
|
@ -574,12 +573,14 @@ class Graphiti:
|
||||||
|
|
||||||
return final_hydrated_nodes, resolved_edges, invalidated_edges, uuid_map
|
return final_hydrated_nodes, resolved_edges, invalidated_edges, uuid_map
|
||||||
|
|
||||||
|
@handle_multiple_group_ids
|
||||||
async def retrieve_episodes(
|
async def retrieve_episodes(
|
||||||
self,
|
self,
|
||||||
reference_time: datetime,
|
reference_time: datetime,
|
||||||
last_n: int = EPISODE_WINDOW_LEN,
|
last_n: int = EPISODE_WINDOW_LEN,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
source: EpisodeType | None = None,
|
source: EpisodeType | None = None,
|
||||||
|
driver: GraphDriver | None = None,
|
||||||
) -> list[EpisodicNode]:
|
) -> list[EpisodicNode]:
|
||||||
"""
|
"""
|
||||||
Retrieve the last n episodic nodes from the graph.
|
Retrieve the last n episodic nodes from the graph.
|
||||||
|
|
@ -606,7 +607,10 @@ class Graphiti:
|
||||||
The actual retrieval is performed by the `retrieve_episodes` function
|
The actual retrieval is performed by the `retrieve_episodes` function
|
||||||
from the `graphiti_core.utils` module.
|
from the `graphiti_core.utils` module.
|
||||||
"""
|
"""
|
||||||
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source)
|
if driver is None:
|
||||||
|
driver = self.clients.driver
|
||||||
|
|
||||||
|
return await retrieve_episodes(driver, reference_time, last_n, group_ids, source)
|
||||||
|
|
||||||
async def add_episode(
|
async def add_episode(
|
||||||
self,
|
self,
|
||||||
|
|
@ -683,11 +687,18 @@ class Graphiti:
|
||||||
now = utc_now()
|
now = utc_now()
|
||||||
|
|
||||||
validate_entity_types(entity_types)
|
validate_entity_types(entity_types)
|
||||||
|
|
||||||
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
||||||
validate_group_id(group_id)
|
|
||||||
# if group_id is None, use the default group id by the provider
|
if group_id is None:
|
||||||
group_id = group_id or get_default_group_id(self.driver.provider)
|
# if group_id is None, use the default group id by the provider
|
||||||
|
# and the preset database name will be used
|
||||||
|
group_id = get_default_group_id(self.driver.provider)
|
||||||
|
else:
|
||||||
|
validate_group_id(group_id)
|
||||||
|
if group_id != self.driver._database:
|
||||||
|
# if group_id is provided, use it as the database name
|
||||||
|
self.driver = self.driver.clone(database=group_id)
|
||||||
|
self.clients.driver = self.driver
|
||||||
|
|
||||||
with self.tracer.start_span('add_episode') as span:
|
with self.tracer.start_span('add_episode') as span:
|
||||||
try:
|
try:
|
||||||
|
|
@ -865,8 +876,14 @@ class Graphiti:
|
||||||
now = utc_now()
|
now = utc_now()
|
||||||
|
|
||||||
# if group_id is None, use the default group id by the provider
|
# if group_id is None, use the default group id by the provider
|
||||||
group_id = group_id or get_default_group_id(self.driver.provider)
|
if group_id is None:
|
||||||
validate_group_id(group_id)
|
group_id = get_default_group_id(self.driver.provider)
|
||||||
|
else:
|
||||||
|
validate_group_id(group_id)
|
||||||
|
if group_id != self.driver._database:
|
||||||
|
# if group_id is provided, use it as the database name
|
||||||
|
self.driver = self.driver.clone(database=group_id)
|
||||||
|
self.clients.driver = self.driver
|
||||||
|
|
||||||
# Create default edge type map
|
# Create default edge type map
|
||||||
edge_type_map_default = (
|
edge_type_map_default = (
|
||||||
|
|
@ -993,21 +1010,25 @@ class Graphiti:
|
||||||
bulk_span.record_exception(e)
|
bulk_span.record_exception(e)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
@handle_multiple_group_ids
|
||||||
async def build_communities(
|
async def build_communities(
|
||||||
self, group_ids: list[str] | None = None
|
self, group_ids: list[str] | None = None, driver: GraphDriver | None = None
|
||||||
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
||||||
"""
|
"""
|
||||||
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
|
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
|
||||||
the content of these communities.
|
the content of these communities.
|
||||||
----------
|
----------
|
||||||
query : list[str] | None
|
group_ids : list[str] | None
|
||||||
Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
|
Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
|
||||||
"""
|
"""
|
||||||
|
if driver is None:
|
||||||
|
driver = self.clients.driver
|
||||||
|
|
||||||
# Clear existing communities
|
# Clear existing communities
|
||||||
await remove_communities(self.driver)
|
await remove_communities(driver)
|
||||||
|
|
||||||
community_nodes, community_edges = await build_communities(
|
community_nodes, community_edges = await build_communities(
|
||||||
self.driver, self.llm_client, group_ids
|
driver, self.llm_client, group_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
|
|
@ -1016,16 +1037,17 @@ class Graphiti:
|
||||||
)
|
)
|
||||||
|
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
*[node.save(self.driver) for node in community_nodes],
|
*[node.save(driver) for node in community_nodes],
|
||||||
max_coroutines=self.max_coroutines,
|
max_coroutines=self.max_coroutines,
|
||||||
)
|
)
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
*[edge.save(self.driver) for edge in community_edges],
|
*[edge.save(driver) for edge in community_edges],
|
||||||
max_coroutines=self.max_coroutines,
|
max_coroutines=self.max_coroutines,
|
||||||
)
|
)
|
||||||
|
|
||||||
return community_nodes, community_edges
|
return community_nodes, community_edges
|
||||||
|
|
||||||
|
@handle_multiple_group_ids
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
|
@ -1033,6 +1055,7 @@ class Graphiti:
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
num_results=DEFAULT_SEARCH_LIMIT,
|
num_results=DEFAULT_SEARCH_LIMIT,
|
||||||
search_filter: SearchFilters | None = None,
|
search_filter: SearchFilters | None = None,
|
||||||
|
driver: GraphDriver | None = None,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
"""
|
"""
|
||||||
Perform a hybrid search on the knowledge graph.
|
Perform a hybrid search on the knowledge graph.
|
||||||
|
|
@ -1079,7 +1102,8 @@ class Graphiti:
|
||||||
group_ids,
|
group_ids,
|
||||||
search_config,
|
search_config,
|
||||||
search_filter if search_filter is not None else SearchFilters(),
|
search_filter if search_filter is not None else SearchFilters(),
|
||||||
center_node_uuid,
|
driver=driver,
|
||||||
|
center_node_uuid=center_node_uuid,
|
||||||
)
|
)
|
||||||
).edges
|
).edges
|
||||||
|
|
||||||
|
|
@ -1099,6 +1123,7 @@ class Graphiti:
|
||||||
query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter
|
query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@handle_multiple_group_ids
|
||||||
async def search_(
|
async def search_(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
|
@ -1107,6 +1132,7 @@ class Graphiti:
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
bfs_origin_node_uuids: list[str] | None = None,
|
bfs_origin_node_uuids: list[str] | None = None,
|
||||||
search_filter: SearchFilters | None = None,
|
search_filter: SearchFilters | None = None,
|
||||||
|
driver: GraphDriver | None = None,
|
||||||
) -> SearchResults:
|
) -> SearchResults:
|
||||||
"""search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
|
"""search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
|
||||||
than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
|
than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
|
||||||
|
|
@ -1123,6 +1149,7 @@ class Graphiti:
|
||||||
search_filter if search_filter is not None else SearchFilters(),
|
search_filter if search_filter is not None else SearchFilters(),
|
||||||
center_node_uuid,
|
center_node_uuid,
|
||||||
bfs_origin_node_uuids,
|
bfs_origin_node_uuids,
|
||||||
|
driver=driver,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults:
|
async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults:
|
||||||
|
|
|
||||||
|
|
@ -74,10 +74,11 @@ async def search(
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
bfs_origin_node_uuids: list[str] | None = None,
|
bfs_origin_node_uuids: list[str] | None = None,
|
||||||
query_vector: list[float] | None = None,
|
query_vector: list[float] | None = None,
|
||||||
|
driver: GraphDriver | None = None,
|
||||||
) -> SearchResults:
|
) -> SearchResults:
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
driver = clients.driver
|
driver = driver or clients.driver
|
||||||
embedder = clients.embedder
|
embedder = clients.embedder
|
||||||
cross_encoder = clients.cross_encoder
|
cross_encoder = clients.cross_encoder
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -127,3 +127,34 @@ class SearchResults(BaseModel):
|
||||||
episode_reranker_scores: list[float] = Field(default_factory=list)
|
episode_reranker_scores: list[float] = Field(default_factory=list)
|
||||||
communities: list[CommunityNode] = Field(default_factory=list)
|
communities: list[CommunityNode] = Field(default_factory=list)
|
||||||
community_reranker_scores: list[float] = Field(default_factory=list)
|
community_reranker_scores: list[float] = Field(default_factory=list)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def merge(cls, results_list: list['SearchResults']) -> 'SearchResults':
|
||||||
|
"""
|
||||||
|
Merge multiple SearchResults objects into a single SearchResults object.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
results_list : list[SearchResults]
|
||||||
|
List of SearchResults objects to merge
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
SearchResults
|
||||||
|
A single SearchResults object containing all results
|
||||||
|
"""
|
||||||
|
if not results_list:
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
merged = cls()
|
||||||
|
for result in results_list:
|
||||||
|
merged.edges.extend(result.edges)
|
||||||
|
merged.edge_reranker_scores.extend(result.edge_reranker_scores)
|
||||||
|
merged.nodes.extend(result.nodes)
|
||||||
|
merged.node_reranker_scores.extend(result.node_reranker_scores)
|
||||||
|
merged.episodes.extend(result.episodes)
|
||||||
|
merged.episode_reranker_scores.extend(result.episode_reranker_scores)
|
||||||
|
merged.communities.extend(result.communities)
|
||||||
|
merged.community_reranker_scores.extend(result.community_reranker_scores)
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
|
||||||
|
|
@ -20,8 +20,6 @@ from datetime import datetime
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
||||||
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
|
||||||
from graphiti_core.helpers import semaphore_gather
|
|
||||||
from graphiti_core.models.nodes.node_db_queries import (
|
from graphiti_core.models.nodes.node_db_queries import (
|
||||||
EPISODIC_NODE_RETURN,
|
EPISODIC_NODE_RETURN,
|
||||||
EPISODIC_NODE_RETURN_NEPTUNE,
|
EPISODIC_NODE_RETURN_NEPTUNE,
|
||||||
|
|
@ -33,46 +31,6 @@ EPISODE_WINDOW_LEN = 3
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
|
|
||||||
if delete_existing:
|
|
||||||
await driver.delete_all_indexes()
|
|
||||||
|
|
||||||
range_indices: list[LiteralString] = get_range_indices(driver.provider)
|
|
||||||
|
|
||||||
# Don't create fulltext indices if search_interface is being used
|
|
||||||
if not driver.search_interface:
|
|
||||||
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
|
|
||||||
|
|
||||||
if driver.provider == GraphProvider.KUZU:
|
|
||||||
# Skip creating fulltext indices if they already exist. Need to do this manually
|
|
||||||
# until Kuzu supports `IF NOT EXISTS` for indices.
|
|
||||||
result, _, _ = await driver.execute_query('CALL SHOW_INDEXES() RETURN *;')
|
|
||||||
if len(result) > 0:
|
|
||||||
fulltext_indices = []
|
|
||||||
|
|
||||||
# Only load the `fts` extension if it's not already loaded, otherwise throw an error.
|
|
||||||
result, _, _ = await driver.execute_query('CALL SHOW_LOADED_EXTENSIONS() RETURN *;')
|
|
||||||
if len(result) == 0:
|
|
||||||
fulltext_indices.insert(
|
|
||||||
0,
|
|
||||||
"""
|
|
||||||
INSTALL fts;
|
|
||||||
LOAD fts;
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
index_queries: list[LiteralString] = range_indices + fulltext_indices
|
|
||||||
|
|
||||||
await semaphore_gather(
|
|
||||||
*[
|
|
||||||
driver.execute_query(
|
|
||||||
query,
|
|
||||||
)
|
|
||||||
for query in index_queries
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
|
async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
|
||||||
async with driver.session() as session:
|
async with driver.session() as session:
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue