graph-per-graphid
This commit is contained in:
parent
4f8eb310f2
commit
bbf9cc6172
11 changed files with 279 additions and 87 deletions
|
|
@ -78,8 +78,6 @@ async def main():
|
||||||
graphiti = Graphiti(graph_driver=falkor_driver)
|
graphiti = Graphiti(graph_driver=falkor_driver)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize the graph database with graphiti's indices. This only needs to be done once.
|
|
||||||
await graphiti.build_indices_and_constraints()
|
|
||||||
|
|
||||||
#################################################
|
#################################################
|
||||||
# ADDING EPISODES
|
# ADDING EPISODES
|
||||||
|
|
|
||||||
|
|
@ -67,8 +67,6 @@ async def main():
|
||||||
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize the graph database with graphiti's indices. This only needs to be done once.
|
|
||||||
await graphiti.build_indices_and_constraints()
|
|
||||||
|
|
||||||
#################################################
|
#################################################
|
||||||
# ADDING EPISODES
|
# ADDING EPISODES
|
||||||
|
|
|
||||||
77
graphiti_core/decorators.py
Normal file
77
graphiti_core/decorators.py
Normal file
|
|
@ -0,0 +1,77 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from typing import Any, Awaitable, Callable, TypeVar
|
||||||
|
|
||||||
|
from graphiti_core.driver.driver import GraphProvider
|
||||||
|
from graphiti_core.helpers import semaphore_gather
|
||||||
|
from graphiti_core.search.search_config import SearchResults
|
||||||
|
|
||||||
|
F = TypeVar('F', bound=Callable[..., Awaitable[Any]])
|
||||||
|
|
||||||
|
|
||||||
|
def handle_multiple_group_ids(func: F) -> F:
|
||||||
|
"""
|
||||||
|
Decorator for FalkorDB methods that need to handle multiple group_ids.
|
||||||
|
Runs the function for each group_id separately and merges results.
|
||||||
|
"""
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def wrapper(self, *args, **kwargs):
|
||||||
|
group_ids = kwargs.get('group_ids')
|
||||||
|
|
||||||
|
# Only handle FalkorDB with multiple group_ids
|
||||||
|
if (hasattr(self, 'clients') and hasattr(self.clients, 'driver') and
|
||||||
|
self.clients.driver.provider == GraphProvider.FALKORDB and
|
||||||
|
group_ids and len(group_ids) > 1):
|
||||||
|
|
||||||
|
# Execute for each group_id concurrently
|
||||||
|
driver = self.clients.driver
|
||||||
|
|
||||||
|
async def execute_for_group(gid: str):
|
||||||
|
return await func(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
**{**kwargs, "group_ids": [gid], "driver": driver.clone(database=gid)},
|
||||||
|
)
|
||||||
|
|
||||||
|
results = await semaphore_gather(
|
||||||
|
*[execute_for_group(gid) for gid in group_ids],
|
||||||
|
max_coroutines=getattr(self, 'max_coroutines', None)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Merge results based on type
|
||||||
|
if isinstance(results[0], SearchResults):
|
||||||
|
return SearchResults.merge(results)
|
||||||
|
elif isinstance(results[0], list):
|
||||||
|
return [item for result in results for item in result]
|
||||||
|
elif isinstance(results[0], tuple):
|
||||||
|
# Handle tuple outputs (like build_communities returning (nodes, edges))
|
||||||
|
merged_tuple = []
|
||||||
|
for i in range(len(results[0])):
|
||||||
|
component_results = [result[i] for result in results]
|
||||||
|
if isinstance(component_results[0], list):
|
||||||
|
merged_tuple.append([item for component in component_results for item in component])
|
||||||
|
else:
|
||||||
|
merged_tuple.append(component_results)
|
||||||
|
return tuple(merged_tuple)
|
||||||
|
else:
|
||||||
|
return results
|
||||||
|
|
||||||
|
# Normal execution
|
||||||
|
return await func(self, *args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper # type: ignore
|
||||||
|
|
@ -57,6 +57,7 @@ class GraphDriver(ABC):
|
||||||
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
||||||
)
|
)
|
||||||
_database: str
|
_database: str
|
||||||
|
default_group_id: str = ''
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
||||||
|
|
@ -74,12 +75,10 @@ class GraphDriver(ABC):
|
||||||
def delete_all_indexes(self) -> Coroutine:
|
def delete_all_indexes(self) -> Coroutine:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def with_database(self, database: str) -> 'GraphDriver':
|
@abstractmethod
|
||||||
"""
|
async def build_indices_and_constraints(self, delete_existing: bool = False):
|
||||||
Returns a shallow copy of this driver with a different default database.
|
raise NotImplementedError()
|
||||||
Reuses the same connection (e.g. FalkorDB, Neo4j).
|
|
||||||
"""
|
|
||||||
cloned = copy.copy(self)
|
|
||||||
cloned._database = database
|
|
||||||
|
|
||||||
return cloned
|
def clone(self, database: str) -> 'GraphDriver':
|
||||||
|
"""Clone the driver with a different database or graph name."""
|
||||||
|
return self
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,8 @@ import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from falkordb import Graph as FalkorGraph
|
from falkordb import Graph as FalkorGraph
|
||||||
from falkordb.asyncio import FalkorDB
|
from falkordb.asyncio import FalkorDB
|
||||||
|
|
@ -33,6 +35,8 @@ else:
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||||
|
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
||||||
|
from graphiti_core.helpers import semaphore_gather
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -72,6 +76,8 @@ class FalkorDriverSession(GraphDriverSession):
|
||||||
|
|
||||||
class FalkorDriver(GraphDriver):
|
class FalkorDriver(GraphDriver):
|
||||||
provider = GraphProvider.FALKORDB
|
provider = GraphProvider.FALKORDB
|
||||||
|
default_group_id: str = '\\_'
|
||||||
|
fulltext_syntax: str = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -80,7 +86,7 @@ class FalkorDriver(GraphDriver):
|
||||||
username: str | None = None,
|
username: str | None = None,
|
||||||
password: str | None = None,
|
password: str | None = None,
|
||||||
falkor_db: FalkorDB | None = None,
|
falkor_db: FalkorDB | None = None,
|
||||||
database: str = 'default_db',
|
database: str = '\\_',
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the FalkorDB driver.
|
Initialize the FalkorDB driver.
|
||||||
|
|
@ -98,7 +104,16 @@ class FalkorDriver(GraphDriver):
|
||||||
else:
|
else:
|
||||||
self.client = FalkorDB(host=host, port=port, username=username, password=password)
|
self.client = FalkorDB(host=host, port=port, username=username, password=password)
|
||||||
|
|
||||||
self.fulltext_syntax = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries see https://redis.io/docs/latest/develop/ai/search-and-query/query/full-text/
|
# Schedule the indices and constraints to be built
|
||||||
|
import asyncio
|
||||||
|
try:
|
||||||
|
# Try to get the current event loop
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
# Schedule the build_indices_and_constraints to run
|
||||||
|
loop.create_task(self.build_indices_and_constraints())
|
||||||
|
except RuntimeError:
|
||||||
|
# No event loop running, this will be handled later
|
||||||
|
pass
|
||||||
|
|
||||||
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
|
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
|
||||||
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
|
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
|
||||||
|
|
@ -152,8 +167,64 @@ class FalkorDriver(GraphDriver):
|
||||||
await self.client.connection.close()
|
await self.client.connection.close()
|
||||||
|
|
||||||
async def delete_all_indexes(self) -> None:
|
async def delete_all_indexes(self) -> None:
|
||||||
await self.execute_query(
|
from collections import defaultdict
|
||||||
'CALL db.indexes() YIELD name DROP INDEX name',
|
|
||||||
|
result = await self.execute_query('CALL db.indexes()')
|
||||||
|
if result is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
records, _, _ = result
|
||||||
|
|
||||||
|
# Organize indexes by type and label
|
||||||
|
range_indexes = defaultdict(list)
|
||||||
|
fulltext_indexes = defaultdict(list)
|
||||||
|
entity_types = {}
|
||||||
|
|
||||||
|
for record in records:
|
||||||
|
label = record['label']
|
||||||
|
entity_types[label] = record['entitytype']
|
||||||
|
|
||||||
|
for field_name, index_type in record['types'].items():
|
||||||
|
if 'RANGE' in index_type:
|
||||||
|
range_indexes[label].append(field_name)
|
||||||
|
if 'FULLTEXT' in index_type:
|
||||||
|
fulltext_indexes[label].append(field_name)
|
||||||
|
|
||||||
|
# Drop all range indexes
|
||||||
|
for label, fields in range_indexes.items():
|
||||||
|
for field in fields:
|
||||||
|
await self.execute_query(f'DROP INDEX ON :{label}({field})')
|
||||||
|
|
||||||
|
# Drop all fulltext indexes
|
||||||
|
for label, fields in fulltext_indexes.items():
|
||||||
|
entity_type = entity_types[label]
|
||||||
|
for field in fields:
|
||||||
|
if entity_type == 'NODE':
|
||||||
|
await self.execute_query(
|
||||||
|
f'DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field})'
|
||||||
|
)
|
||||||
|
elif entity_type == 'RELATIONSHIP':
|
||||||
|
await self.execute_query(
|
||||||
|
f'DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field})'
|
||||||
|
)
|
||||||
|
|
||||||
|
async def build_indices_and_constraints(self, delete_existing: bool = False):
|
||||||
|
if delete_existing:
|
||||||
|
await self.delete_all_indexes()
|
||||||
|
|
||||||
|
range_indices: list[LiteralString] = get_range_indices(self.provider)
|
||||||
|
|
||||||
|
fulltext_indices: list[LiteralString] = get_fulltext_indices(self.provider)
|
||||||
|
|
||||||
|
index_queries: list[LiteralString] = range_indices + fulltext_indices
|
||||||
|
|
||||||
|
await semaphore_gather(
|
||||||
|
*[
|
||||||
|
self.execute_query(
|
||||||
|
query,
|
||||||
|
)
|
||||||
|
for query in index_queries
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def clone(self, database: str) -> 'GraphDriver':
|
def clone(self, database: str) -> 'GraphDriver':
|
||||||
|
|
@ -161,8 +232,12 @@ class FalkorDriver(GraphDriver):
|
||||||
Returns a shallow copy of this driver with a different default database.
|
Returns a shallow copy of this driver with a different default database.
|
||||||
Reuses the same connection (e.g. FalkorDB, Neo4j).
|
Reuses the same connection (e.g. FalkorDB, Neo4j).
|
||||||
"""
|
"""
|
||||||
cloned = FalkorDriver(falkor_db=self.client, database=database)
|
if database == self._database:
|
||||||
|
cloned = self
|
||||||
|
else:
|
||||||
|
# Create a new instance of FalkorDriver with the same connection but a different database
|
||||||
|
cloned = FalkorDriver(falkor_db=self.client, database=database)
|
||||||
|
|
||||||
return cloned
|
return cloned
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,12 +22,15 @@ from neo4j import AsyncGraphDatabase, EagerResult
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||||
|
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
||||||
|
from graphiti_core.helpers import semaphore_gather
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Neo4jDriver(GraphDriver):
|
class Neo4jDriver(GraphDriver):
|
||||||
provider = GraphProvider.NEO4J
|
provider = GraphProvider.NEO4J
|
||||||
|
default_group_id: str = ''
|
||||||
|
|
||||||
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
|
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
@ -36,6 +39,17 @@ class Neo4jDriver(GraphDriver):
|
||||||
auth=(user or '', password or ''),
|
auth=(user or '', password or ''),
|
||||||
)
|
)
|
||||||
self._database = database
|
self._database = database
|
||||||
|
|
||||||
|
# Schedule the indices and constraints to be built
|
||||||
|
import asyncio
|
||||||
|
try:
|
||||||
|
# Try to get the current event loop
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
# Schedule the build_indices_and_constraints to run
|
||||||
|
loop.create_task(self.build_indices_and_constraints())
|
||||||
|
except RuntimeError:
|
||||||
|
# No event loop running, this will be handled later
|
||||||
|
pass
|
||||||
|
|
||||||
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
|
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
|
||||||
# Check if database_ is provided in kwargs.
|
# Check if database_ is provided in kwargs.
|
||||||
|
|
@ -64,3 +78,22 @@ class Neo4jDriver(GraphDriver):
|
||||||
return self.client.execute_query(
|
return self.client.execute_query(
|
||||||
'CALL db.indexes() YIELD name DROP INDEX name',
|
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def build_indices_and_constraints(self, delete_existing: bool = False):
|
||||||
|
if delete_existing:
|
||||||
|
await self.delete_all_indexes()
|
||||||
|
|
||||||
|
range_indices: list[LiteralString] = get_range_indices(self.provider)
|
||||||
|
|
||||||
|
fulltext_indices: list[LiteralString] = get_fulltext_indices(self.provider)
|
||||||
|
|
||||||
|
index_queries: list[LiteralString] = range_indices + fulltext_indices
|
||||||
|
|
||||||
|
await semaphore_gather(
|
||||||
|
*[
|
||||||
|
self.execute_query(
|
||||||
|
query,
|
||||||
|
)
|
||||||
|
for query in index_queries
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
||||||
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
||||||
|
from graphiti_core.decorators import handle_multiple_group_ids
|
||||||
from graphiti_core.driver.driver import GraphDriver
|
from graphiti_core.driver.driver import GraphDriver
|
||||||
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
||||||
from graphiti_core.edges import (
|
from graphiti_core.edges import (
|
||||||
|
|
@ -35,7 +36,6 @@ from graphiti_core.edges import (
|
||||||
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
||||||
from graphiti_core.graphiti_types import GraphitiClients
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
from graphiti_core.helpers import (
|
from graphiti_core.helpers import (
|
||||||
get_default_group_id,
|
|
||||||
semaphore_gather,
|
semaphore_gather,
|
||||||
validate_excluded_entity_types,
|
validate_excluded_entity_types,
|
||||||
validate_group_id,
|
validate_group_id,
|
||||||
|
|
@ -87,7 +87,6 @@ from graphiti_core.utils.maintenance.edge_operations import (
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.maintenance.graph_data_operations import (
|
from graphiti_core.utils.maintenance.graph_data_operations import (
|
||||||
EPISODE_WINDOW_LEN,
|
EPISODE_WINDOW_LEN,
|
||||||
build_indices_and_constraints,
|
|
||||||
retrieve_episodes,
|
retrieve_episodes,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.maintenance.node_operations import (
|
from graphiti_core.utils.maintenance.node_operations import (
|
||||||
|
|
@ -320,25 +319,26 @@ class Graphiti:
|
||||||
-----
|
-----
|
||||||
This method should typically be called once during the initial setup of the
|
This method should typically be called once during the initial setup of the
|
||||||
knowledge graph or when updating the database schema. It uses the
|
knowledge graph or when updating the database schema. It uses the
|
||||||
`build_indices_and_constraints` function from the
|
driver's `build_indices_and_constraints` method to perform
|
||||||
`graphiti_core.utils.maintenance.graph_data_operations` module to perform
|
|
||||||
the actual database operations.
|
the actual database operations.
|
||||||
|
|
||||||
The specific indices and constraints created depend on the implementation
|
The specific indices and constraints created depend on the implementation
|
||||||
of the `build_indices_and_constraints` function. Refer to that function's
|
of the driver's `build_indices_and_constraints` method. Refer to the specific
|
||||||
documentation for details on the exact database schema modifications.
|
driver documentation for details on the exact database schema modifications.
|
||||||
|
|
||||||
Caution: Running this method on a large existing database may take some time
|
Caution: Running this method on a large existing database may take some time
|
||||||
and could impact database performance during execution.
|
and could impact database performance during execution.
|
||||||
"""
|
"""
|
||||||
await build_indices_and_constraints(self.driver, delete_existing)
|
await self.driver.build_indices_and_constraints(delete_existing)
|
||||||
|
|
||||||
|
@handle_multiple_group_ids
|
||||||
async def retrieve_episodes(
|
async def retrieve_episodes(
|
||||||
self,
|
self,
|
||||||
reference_time: datetime,
|
reference_time: datetime,
|
||||||
last_n: int = EPISODE_WINDOW_LEN,
|
last_n: int = EPISODE_WINDOW_LEN,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
source: EpisodeType | None = None,
|
source: EpisodeType | None = None,
|
||||||
|
driver: GraphDriver | None = None,
|
||||||
) -> list[EpisodicNode]:
|
) -> list[EpisodicNode]:
|
||||||
"""
|
"""
|
||||||
Retrieve the last n episodic nodes from the graph.
|
Retrieve the last n episodic nodes from the graph.
|
||||||
|
|
@ -365,7 +365,10 @@ class Graphiti:
|
||||||
The actual retrieval is performed by the `retrieve_episodes` function
|
The actual retrieval is performed by the `retrieve_episodes` function
|
||||||
from the `graphiti_core.utils` module.
|
from the `graphiti_core.utils` module.
|
||||||
"""
|
"""
|
||||||
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source)
|
if driver is None:
|
||||||
|
driver = self.clients.driver
|
||||||
|
|
||||||
|
return await retrieve_episodes(driver, reference_time, last_n, group_ids, source)
|
||||||
|
|
||||||
async def add_episode(
|
async def add_episode(
|
||||||
self,
|
self,
|
||||||
|
|
@ -442,12 +445,18 @@ class Graphiti:
|
||||||
start = time()
|
start = time()
|
||||||
now = utc_now()
|
now = utc_now()
|
||||||
|
|
||||||
# if group_id is None, use the default group id by the provider
|
if group_id is None:
|
||||||
group_id = group_id or get_default_group_id(self.driver.provider)
|
# if group_id is None, use the default group id by the provider
|
||||||
validate_entity_types(entity_types)
|
group_id = self.driver.default_group_id
|
||||||
|
else:
|
||||||
|
validate_group_id(group_id)
|
||||||
|
if group_id != self.driver._database:
|
||||||
|
# if group_id is provided, use it as the database name
|
||||||
|
self.driver = self.driver.clone(database=group_id)
|
||||||
|
self.clients.driver = self.driver
|
||||||
|
|
||||||
|
validate_entity_types(entity_types)
|
||||||
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
||||||
validate_group_id(group_id)
|
|
||||||
|
|
||||||
previous_episodes = (
|
previous_episodes = (
|
||||||
await self.retrieve_episodes(
|
await self.retrieve_episodes(
|
||||||
|
|
@ -620,9 +629,15 @@ class Graphiti:
|
||||||
start = time()
|
start = time()
|
||||||
now = utc_now()
|
now = utc_now()
|
||||||
|
|
||||||
# if group_id is None, use the default group id by the provider
|
if group_id is None:
|
||||||
group_id = group_id or get_default_group_id(self.driver.provider)
|
# if group_id is None, use the default group id by the provider
|
||||||
validate_group_id(group_id)
|
group_id = self.driver.default_group_id
|
||||||
|
else:
|
||||||
|
validate_group_id(group_id)
|
||||||
|
if group_id != self.driver._database:
|
||||||
|
# if group_id is provided, use it as the database name
|
||||||
|
self.driver = self.driver.clone(database=group_id)
|
||||||
|
self.clients.driver = self.driver
|
||||||
|
|
||||||
# Create default edge type map
|
# Create default edge type map
|
||||||
edge_type_map_default = (
|
edge_type_map_default = (
|
||||||
|
|
@ -850,21 +865,26 @@ class Graphiti:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
@handle_multiple_group_ids
|
||||||
async def build_communities(
|
async def build_communities(
|
||||||
self, group_ids: list[str] | None = None
|
self, group_ids: list[str] | None = None,
|
||||||
|
driver: GraphDriver | None = None
|
||||||
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
||||||
"""
|
"""
|
||||||
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
|
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
|
||||||
the content of these communities.
|
the content of these communities.
|
||||||
----------
|
----------
|
||||||
query : list[str] | None
|
group_ids : list[str] | None
|
||||||
Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
|
Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
|
||||||
"""
|
"""
|
||||||
|
if driver is None:
|
||||||
|
driver = self.clients.driver
|
||||||
|
|
||||||
# Clear existing communities
|
# Clear existing communities
|
||||||
await remove_communities(self.driver)
|
await remove_communities(driver)
|
||||||
|
|
||||||
community_nodes, community_edges = await build_communities(
|
community_nodes, community_edges = await build_communities(
|
||||||
self.driver, self.llm_client, group_ids
|
driver, self.llm_client, group_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
|
|
@ -873,16 +893,17 @@ class Graphiti:
|
||||||
)
|
)
|
||||||
|
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
*[node.save(self.driver) for node in community_nodes],
|
*[node.save(driver) for node in community_nodes],
|
||||||
max_coroutines=self.max_coroutines,
|
max_coroutines=self.max_coroutines,
|
||||||
)
|
)
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
*[edge.save(self.driver) for edge in community_edges],
|
*[edge.save(driver) for edge in community_edges],
|
||||||
max_coroutines=self.max_coroutines,
|
max_coroutines=self.max_coroutines,
|
||||||
)
|
)
|
||||||
|
|
||||||
return community_nodes, community_edges
|
return community_nodes, community_edges
|
||||||
|
|
||||||
|
@handle_multiple_group_ids
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
|
@ -890,6 +911,7 @@ class Graphiti:
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
num_results=DEFAULT_SEARCH_LIMIT,
|
num_results=DEFAULT_SEARCH_LIMIT,
|
||||||
search_filter: SearchFilters | None = None,
|
search_filter: SearchFilters | None = None,
|
||||||
|
driver: GraphDriver | None = None
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
"""
|
"""
|
||||||
Perform a hybrid search on the knowledge graph.
|
Perform a hybrid search on the knowledge graph.
|
||||||
|
|
@ -936,7 +958,8 @@ class Graphiti:
|
||||||
group_ids,
|
group_ids,
|
||||||
search_config,
|
search_config,
|
||||||
search_filter if search_filter is not None else SearchFilters(),
|
search_filter if search_filter is not None else SearchFilters(),
|
||||||
center_node_uuid,
|
driver=driver,
|
||||||
|
center_node_uuid=center_node_uuid
|
||||||
)
|
)
|
||||||
).edges
|
).edges
|
||||||
|
|
||||||
|
|
@ -956,6 +979,7 @@ class Graphiti:
|
||||||
query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter
|
query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@handle_multiple_group_ids
|
||||||
async def search_(
|
async def search_(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
|
@ -964,6 +988,7 @@ class Graphiti:
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
bfs_origin_node_uuids: list[str] | None = None,
|
bfs_origin_node_uuids: list[str] | None = None,
|
||||||
search_filter: SearchFilters | None = None,
|
search_filter: SearchFilters | None = None,
|
||||||
|
driver: GraphDriver | None = None
|
||||||
) -> SearchResults:
|
) -> SearchResults:
|
||||||
"""search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
|
"""search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
|
||||||
than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
|
than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
|
||||||
|
|
@ -980,6 +1005,7 @@ class Graphiti:
|
||||||
search_filter if search_filter is not None else SearchFilters(),
|
search_filter if search_filter is not None else SearchFilters(),
|
||||||
center_node_uuid,
|
center_node_uuid,
|
||||||
bfs_origin_node_uuids,
|
bfs_origin_node_uuids,
|
||||||
|
driver=driver
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults:
|
async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults:
|
||||||
|
|
|
||||||
|
|
@ -53,17 +53,6 @@ def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_default_group_id(provider: GraphProvider) -> str:
|
|
||||||
"""
|
|
||||||
This function differentiates the default group id based on the database type.
|
|
||||||
For most databases, the default group id is an empty string, while there are database types that require a specific default group id.
|
|
||||||
"""
|
|
||||||
if provider == GraphProvider.FALKORDB:
|
|
||||||
return '_'
|
|
||||||
else:
|
|
||||||
return ''
|
|
||||||
|
|
||||||
|
|
||||||
def lucene_sanitize(query: str) -> str:
|
def lucene_sanitize(query: str) -> str:
|
||||||
# Escape special characters from a query before passing into Lucene
|
# Escape special characters from a query before passing into Lucene
|
||||||
# + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
|
# + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
|
||||||
|
|
|
||||||
|
|
@ -72,10 +72,11 @@ async def search(
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
bfs_origin_node_uuids: list[str] | None = None,
|
bfs_origin_node_uuids: list[str] | None = None,
|
||||||
query_vector: list[float] | None = None,
|
query_vector: list[float] | None = None,
|
||||||
|
driver: GraphDriver | None = None,
|
||||||
) -> SearchResults:
|
) -> SearchResults:
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
driver = clients.driver
|
driver = driver or clients.driver
|
||||||
embedder = clients.embedder
|
embedder = clients.embedder
|
||||||
cross_encoder = clients.cross_encoder
|
cross_encoder = clients.cross_encoder
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -127,3 +127,34 @@ class SearchResults(BaseModel):
|
||||||
episode_reranker_scores: list[float] = Field(default_factory=list)
|
episode_reranker_scores: list[float] = Field(default_factory=list)
|
||||||
communities: list[CommunityNode] = Field(default_factory=list)
|
communities: list[CommunityNode] = Field(default_factory=list)
|
||||||
community_reranker_scores: list[float] = Field(default_factory=list)
|
community_reranker_scores: list[float] = Field(default_factory=list)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def merge(cls, results_list: list['SearchResults']) -> 'SearchResults':
|
||||||
|
"""
|
||||||
|
Merge multiple SearchResults objects into a single SearchResults object.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
results_list : list[SearchResults]
|
||||||
|
List of SearchResults objects to merge
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
SearchResults
|
||||||
|
A single SearchResults object containing all results
|
||||||
|
"""
|
||||||
|
if not results_list:
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
merged = cls()
|
||||||
|
for result in results_list:
|
||||||
|
merged.edges.extend(result.edges)
|
||||||
|
merged.edge_reranker_scores.extend(result.edge_reranker_scores)
|
||||||
|
merged.nodes.extend(result.nodes)
|
||||||
|
merged.node_reranker_scores.extend(result.node_reranker_scores)
|
||||||
|
merged.episodes.extend(result.episodes)
|
||||||
|
merged.episode_reranker_scores.extend(result.episode_reranker_scores)
|
||||||
|
merged.communities.extend(result.communities)
|
||||||
|
merged.community_reranker_scores.extend(result.community_reranker_scores)
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
|
||||||
|
|
@ -20,8 +20,6 @@ from datetime import datetime
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver
|
from graphiti_core.driver.driver import GraphDriver
|
||||||
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
|
||||||
from graphiti_core.helpers import semaphore_gather
|
|
||||||
from graphiti_core.models.nodes.node_db_queries import EPISODIC_NODE_RETURN
|
from graphiti_core.models.nodes.node_db_queries import EPISODIC_NODE_RETURN
|
||||||
from graphiti_core.nodes import EpisodeType, EpisodicNode, get_episodic_node_from_record
|
from graphiti_core.nodes import EpisodeType, EpisodicNode, get_episodic_node_from_record
|
||||||
|
|
||||||
|
|
@ -30,39 +28,6 @@ EPISODE_WINDOW_LEN = 3
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
|
|
||||||
if delete_existing:
|
|
||||||
records, _, _ = await driver.execute_query(
|
|
||||||
"""
|
|
||||||
SHOW INDEXES YIELD name
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
index_names = [record['name'] for record in records]
|
|
||||||
await semaphore_gather(
|
|
||||||
*[
|
|
||||||
driver.execute_query(
|
|
||||||
"""DROP INDEX $name""",
|
|
||||||
name=name,
|
|
||||||
)
|
|
||||||
for name in index_names
|
|
||||||
]
|
|
||||||
)
|
|
||||||
range_indices: list[LiteralString] = get_range_indices(driver.provider)
|
|
||||||
|
|
||||||
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
|
|
||||||
|
|
||||||
index_queries: list[LiteralString] = range_indices + fulltext_indices
|
|
||||||
|
|
||||||
await semaphore_gather(
|
|
||||||
*[
|
|
||||||
driver.execute_query(
|
|
||||||
query,
|
|
||||||
)
|
|
||||||
for query in index_queries
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
|
async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
|
||||||
async with driver.session() as session:
|
async with driver.session() as session:
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue