Add support for falkordb (#575)

* [wip] add support for falkordb

* updates

* fix-async

* progress

* fix-issues

* rm-date-handler

* red-code

* rm-uns-try

* fix-exm

* rm-un-lines

* fix-comments

* fix-se-utils

* fix-falkor-readme

* fix-falkor-cosine-score

* update-falkor-ver

* fix-vec-sim

* min-updates

* make format

* update graph driver abstraction

* poetry lock

* updates

* linter

* Update graphiti_core/search/search_utils.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

---------

Co-authored-by: Dudi Zimberknopf <zimber.dudi@gmail.com>
Co-authored-by: Gal Shubeli <galshubeli93@gmail.com>
Co-authored-by: Gal Shubeli <124919062+galshubeli@users.noreply.github.com>
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
Preston Rasmussen 2025-06-13 12:06:57 -04:00 committed by GitHub
parent 3d7e1a4b79
commit 14146dc46f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 1131 additions and 348 deletions


@@ -1,8 +1,17 @@
OPENAI_API_KEY=
# Neo4j database connection
NEO4J_URI=
NEO4J_PORT=
NEO4J_USER=
NEO4J_PASSWORD=
# FalkorDB database connection
FALKORDB_URI=
FALKORDB_PORT=
FALKORDB_USER=
FALKORDB_PASSWORD=
DEFAULT_DATABASE=
USE_PARALLEL_RUNTIME=
SEMAPHORE_LIMIT=
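The examples read these variables with environment-lookup fallbacks. A minimal sketch of that pattern (the helper names here are illustrative, not part of the library):

```python
import os

# Illustrative helpers mirroring the quickstart pattern: read a connection
# URI from the environment, falling back to the local default.
def get_falkordb_uri() -> str:
    return os.environ.get('FALKORDB_URI', 'falkor://localhost:6379')

def get_neo4j_uri() -> str:
    return os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
```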


@@ -64,10 +64,11 @@ Once you've found an issue tagged with "good first issue" or "help wanted," or p
export TEST_OPENAI_API_KEY=...
export TEST_OPENAI_MODEL=...
export TEST_ANTHROPIC_API_KEY=...
export NEO4J_URI=neo4j://...
export NEO4J_USER=...
export NEO4J_PASSWORD=...
# For Neo4j
export TEST_URI=neo4j://...
export TEST_USER=...
export TEST_PASSWORD=...
```
## Making Changes


@@ -105,7 +105,7 @@ Graphiti is specifically designed to address the challenges of dynamic and frequ
Requirements:
- Python 3.10 or higher
- Neo4j 5.26 or higher (serves as the embeddings storage backend)
- Neo4j 5.26 / FalkorDB 1.1.2 or higher (serves as the embeddings storage backend)
- OpenAI API key (for LLM inference and embedding)
> [!IMPORTANT]


@@ -76,9 +76,7 @@ async def main():
group_id = str(uuid4())
for i, message in enumerate(messages[3:14]):
episodes = await client.retrieve_episodes(
message.actual_timestamp, 3, group_ids=['podcast']
)
episodes = await client.retrieve_episodes(message.actual_timestamp, 3, group_ids=[group_id])
episode_uuids = [episode.uuid for episode in episodes]
await client.add_episode(


@@ -2,7 +2,7 @@
This example demonstrates the basic functionality of Graphiti, including:
1. Connecting to a Neo4j database
1. Connecting to a Neo4j or FalkorDB database
2. Initializing Graphiti indices and constraints
3. Adding episodes to the graph
4. Searching the graph with semantic and keyword matching
@@ -11,10 +11,14 @@ This example demonstrates the basic functionality of Graphiti, including:
## Prerequisites
- Neo4j Desktop installed and running
- A local DBMS created and started in Neo4j Desktop
- Python 3.9+
- OpenAI API key (set as `OPENAI_API_KEY` environment variable)
- Python 3.9+
- OpenAI API key (set as `OPENAI_API_KEY` environment variable)
- **For Neo4j**:
- Neo4j Desktop installed and running
- A local DBMS created and started in Neo4j Desktop
- **For FalkorDB**:
- FalkorDB server running (see [FalkorDB documentation](https://falkordb.com/docs/) for setup)
## Setup Instructions
@@ -34,17 +38,23 @@ export OPENAI_API_KEY=your_openai_api_key
export NEO4J_URI=bolt://localhost:7687
export NEO4J_USER=neo4j
export NEO4J_PASSWORD=password
# Optional FalkorDB connection parameters (defaults shown)
export FALKORDB_URI=falkor://localhost:6379
```
3. Run the example:
```bash
python quickstart.py
python quickstart_neo4j.py
# For FalkorDB
python quickstart_falkordb.py
```
## What This Example Demonstrates
- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j
- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j or FalkorDB
- **Adding Episodes**: Adding text content that will be analyzed and converted into knowledge graph nodes and edges
- **Edge Search Functionality**: Performing hybrid searches that combine semantic similarity and BM25 retrieval to find relationships (edges)
- **Graph-Aware Search**: Using the source node UUID from the top search result to rerank additional search results based on graph distance


@@ -0,0 +1,240 @@
"""
Copyright 2025, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import json
import logging
import os
from datetime import datetime, timezone
from logging import INFO
from dotenv import load_dotenv
from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
#################################################
# CONFIGURATION
#################################################
# Set up logging and environment variables for
# connecting to FalkorDB database
#################################################
# Configure logging
logging.basicConfig(
level=INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)
load_dotenv()
# FalkorDB connection parameters
# Make sure FalkorDB on premises is running, see https://docs.falkordb.com/
falkor_uri = os.environ.get('FALKORDB_URI', 'falkor://localhost:6379')
if not falkor_uri:
raise ValueError('FALKORDB_URI must be set')
async def main():
#################################################
# INITIALIZATION
#################################################
# Connect to FalkorDB and set up Graphiti indices
# This is required before using other Graphiti
# functionality
#################################################
# Initialize Graphiti with FalkorDB connection
graphiti = Graphiti(falkor_uri)
try:
# Initialize the graph database with graphiti's indices. This only needs to be done once.
await graphiti.build_indices_and_constraints()
#################################################
# ADDING EPISODES
#################################################
# Episodes are the primary units of information
# in Graphiti. They can be text or structured JSON
# and are automatically processed to extract entities
# and relationships.
#################################################
# Example: Add Episodes
# Episodes list containing both text and JSON episodes
episodes = [
{
'content': 'Kamala Harris is the Attorney General of California. She was previously '
'the district attorney for San Francisco.',
'type': EpisodeType.text,
'description': 'podcast transcript',
},
{
'content': 'As AG, Harris was in office from January 3, 2011 to January 3, 2017',
'type': EpisodeType.text,
'description': 'podcast transcript',
},
{
'content': {
'name': 'Gavin Newsom',
'position': 'Governor',
'state': 'California',
'previous_role': 'Lieutenant Governor',
'previous_location': 'San Francisco',
},
'type': EpisodeType.json,
'description': 'podcast metadata',
},
{
'content': {
'name': 'Gavin Newsom',
'position': 'Governor',
'term_start': 'January 7, 2019',
'term_end': 'Present',
},
'type': EpisodeType.json,
'description': 'podcast metadata',
},
]
# Add episodes to the graph
for i, episode in enumerate(episodes):
await graphiti.add_episode(
name=f'Freakonomics Radio {i}',
episode_body=episode['content']
if isinstance(episode['content'], str)
else json.dumps(episode['content']),
source=episode['type'],
source_description=episode['description'],
reference_time=datetime.now(timezone.utc),
)
print(f'Added episode: Freakonomics Radio {i} ({episode["type"].value})')
#################################################
# BASIC SEARCH
#################################################
# The simplest way to retrieve relationships (edges)
# from Graphiti is using the search method, which
# performs a hybrid search combining semantic
# similarity and BM25 text retrieval.
#################################################
# Perform a hybrid search combining semantic similarity and BM25 retrieval
print("\nSearching for: 'Who was the California Attorney General?'")
results = await graphiti.search('Who was the California Attorney General?')
# Print search results
print('\nSearch Results:')
for result in results:
print(f'UUID: {result.uuid}')
print(f'Fact: {result.fact}')
if hasattr(result, 'valid_at') and result.valid_at:
print(f'Valid from: {result.valid_at}')
if hasattr(result, 'invalid_at') and result.invalid_at:
print(f'Valid until: {result.invalid_at}')
print('---')
#################################################
# CENTER NODE SEARCH
#################################################
# For more contextually relevant results, you can
# use a center node to rerank search results based
# on their graph distance to a specific node
#################################################
# Use the top search result's UUID as the center node for reranking
if results and len(results) > 0:
# Get the source node UUID from the top result
center_node_uuid = results[0].source_node_uuid
print('\nReranking search results based on graph distance:')
print(f'Using center node UUID: {center_node_uuid}')
reranked_results = await graphiti.search(
'Who was the California Attorney General?', center_node_uuid=center_node_uuid
)
# Print reranked search results
print('\nReranked Search Results:')
for result in reranked_results:
print(f'UUID: {result.uuid}')
print(f'Fact: {result.fact}')
if hasattr(result, 'valid_at') and result.valid_at:
print(f'Valid from: {result.valid_at}')
if hasattr(result, 'invalid_at') and result.invalid_at:
print(f'Valid until: {result.invalid_at}')
print('---')
else:
print('No results found in the initial search to use as center node.')
#################################################
# NODE SEARCH USING SEARCH RECIPES
#################################################
# Graphiti provides predefined search recipes
# optimized for different search scenarios.
# Here we use NODE_HYBRID_SEARCH_RRF for retrieving
# nodes directly instead of edges.
#################################################
# Example: Perform a node search using _search method with standard recipes
print(
'\nPerforming node search using _search method with standard recipe NODE_HYBRID_SEARCH_RRF:'
)
# Use a predefined search configuration recipe and modify its limit
node_search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
node_search_config.limit = 5 # Limit to 5 results
# Execute the node search
node_search_results = await graphiti._search(
query='California Governor',
config=node_search_config,
)
# Print node search results
print('\nNode Search Results:')
for node in node_search_results.nodes:
print(f'Node UUID: {node.uuid}')
print(f'Node Name: {node.name}')
node_summary = node.summary[:100] + '...' if len(node.summary) > 100 else node.summary
print(f'Content Summary: {node_summary}')
print(f'Node Labels: {", ".join(node.labels)}')
print(f'Created At: {node.created_at}')
if hasattr(node, 'attributes') and node.attributes:
print('Attributes:')
for key, value in node.attributes.items():
print(f' {key}: {value}')
print('---')
finally:
#################################################
# CLEANUP
#################################################
# Always close the connection to FalkorDB when
# finished to properly release resources
#################################################
# Close the connection
await graphiti.close()
print('\nConnection closed')
if __name__ == '__main__':
asyncio.run(main())


@@ -0,0 +1,17 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
__all__ = ['GraphDriver', 'Neo4jDriver', 'FalkorDriver']


@@ -0,0 +1,81 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from typing import Any
from graphiti_core.helpers import DEFAULT_DATABASE
logger = logging.getLogger(__name__)
class GraphDriverSession(ABC):
@abstractmethod
async def run(self, query: str, **kwargs: Any) -> Any:
raise NotImplementedError()
class GraphDriver(ABC):
provider: str
@abstractmethod
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
raise NotImplementedError()
@abstractmethod
def session(self, database: str) -> GraphDriverSession:
raise NotImplementedError()
@abstractmethod
def close(self) -> None:
raise NotImplementedError()
@abstractmethod
def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
raise NotImplementedError()
# class GraphDriver:
# _driver: GraphClient
#
# def __init__(
# self,
# uri: str,
# user: str,
# password: str,
# ):
# if uri.startswith('falkor'):
# # FalkorDB
# self._driver = FalkorClient(uri, user, password)
# self.provider = 'falkordb'
# else:
# # Neo4j
# self._driver = Neo4jClient(uri, user, password)
# self.provider = 'neo4j'
#
# def execute_query(self, cypher_query_, **kwargs: Any) -> Coroutine:
# return self._driver.execute_query(cypher_query_, **kwargs)
#
# async def close(self):
# return await self._driver.close()
#
# def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
# return self._driver.delete_all_indexes(database_)
#
# def session(self, database: str) -> GraphClientSession:
# return self._driver.session(database)
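The commented-out block above sketches dispatching on the URI scheme. A self-contained version of just that selection logic (the function is illustrative; the provider strings match the drivers' `provider` attributes):

```python
# Sketch of the URI-scheme dispatch hinted at in the commented-out code:
# falkor:// (or falkors://) URIs select FalkorDB, everything else
# (bolt://, neo4j://, ...) selects Neo4j.
def pick_provider(uri: str) -> str:
    if uri.startswith('falkor'):
        return 'falkordb'
    return 'neo4j'
```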


@@ -0,0 +1,132 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from collections.abc import Coroutine
from datetime import datetime
from typing import Any
from falkordb import Graph as FalkorGraph
from falkordb.asyncio import FalkorDB
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
from graphiti_core.helpers import DEFAULT_DATABASE
logger = logging.getLogger(__name__)
class FalkorClientSession(GraphDriverSession):
def __init__(self, graph: FalkorGraph):
self.graph = graph
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
# No cleanup needed for Falkor, but method must exist
pass
async def close(self):
# No explicit close needed for FalkorDB, but method must exist
pass
async def execute_write(self, func, *args, **kwargs):
# Directly await the provided async function with `self` as the transaction/session
return await func(self, *args, **kwargs)
async def run(self, cypher_query_: str | list, **kwargs: Any) -> Any:
# FalkorDB does not support argument for Label Set, so it's converted into an array of queries
if isinstance(cypher_query_, list):
for cypher, params in cypher_query_:
params = convert_datetimes_to_strings(params)
await self.graph.query(str(cypher), params)
else:
params = dict(kwargs)
params = convert_datetimes_to_strings(params)
await self.graph.query(str(cypher_query_), params)
# Assuming `graph.query` is async (ideal); otherwise, wrap in executor
return None
class FalkorDriver(GraphDriver):
provider: str = 'falkordb'
def __init__(
self,
uri: str,
user: str,
password: str,
):
super().__init__()
if user and password:
uri_parts = uri.split('://', 1)
uri = f'{uri_parts[0]}://{user}:{password}@{uri_parts[1]}'
self.client = FalkorDB.from_url(
url=uri,
)
def _get_graph(self, graph_name: str) -> FalkorGraph:
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "DEFAULT_DATABASE"
if graph_name is None:
graph_name = 'DEFAULT_DATABASE'
return self.client.select_graph(graph_name)
async def execute_query(self, cypher_query_, **kwargs: Any):
graph_name = kwargs.pop('database_', DEFAULT_DATABASE)
graph = self._get_graph(graph_name)
# Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly)
params = convert_datetimes_to_strings(dict(kwargs))
try:
result = await graph.query(cypher_query_, params)
except Exception as e:
if 'already indexed' in str(e):
# check if index already exists
logger.info(f'Index already exists: {e}')
return None
logger.error(f'Error executing FalkorDB query: {e}')
raise
# Convert the result header to a list of strings
header = [h[1].decode('utf-8') for h in result.header]
return result.result_set, header, None
def session(self, database: str) -> GraphDriverSession:
return FalkorClientSession(self._get_graph(database))
async def close(self) -> None:
await self.client.connection.close()
def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
return self.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name',
database_=database_,
)
def convert_datetimes_to_strings(obj):
if isinstance(obj, dict):
return {k: convert_datetimes_to_strings(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [convert_datetimes_to_strings(item) for item in obj]
elif isinstance(obj, tuple):
return tuple(convert_datetimes_to_strings(item) for item in obj)
elif isinstance(obj, datetime):
return obj.isoformat()
else:
return obj
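Since FalkorDB parameters cannot carry datetime objects, `convert_datetimes_to_strings` recursively rewrites them as ISO-8601 strings while leaving everything else untouched. The same logic, reproduced here so the example is self-contained:

```python
from datetime import datetime, timezone

# Same recursive conversion as convert_datetimes_to_strings above:
# datetimes become ISO strings, containers are walked, scalars pass through.
def to_iso(obj):
    if isinstance(obj, dict):
        return {k: to_iso(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [to_iso(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(to_iso(item) for item in obj)
    if isinstance(obj, datetime):
        return obj.isoformat()
    return obj

params = {'uuid': 'abc', 'created_at': datetime(2025, 1, 1, tzinfo=timezone.utc)}
print(to_iso(params))  # {'uuid': 'abc', 'created_at': '2025-01-01T00:00:00+00:00'}
```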


@@ -0,0 +1,60 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from collections.abc import Coroutine
from typing import Any, LiteralString
from neo4j import AsyncGraphDatabase
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
from graphiti_core.helpers import DEFAULT_DATABASE
logger = logging.getLogger(__name__)
class Neo4jDriver(GraphDriver):
provider: str = 'neo4j'
def __init__(
self,
uri: str,
user: str,
password: str,
):
super().__init__()
self.client = AsyncGraphDatabase.driver(
uri=uri,
auth=(user, password),
)
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> Coroutine:
params = kwargs.pop('params', None)
result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
return result
def session(self, database: str) -> GraphDriverSession:
return self.client.session(database=database) # type: ignore
async def close(self) -> None:
return await self.client.close()
def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
return self.client.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name',
database_=database_,
)


@@ -21,10 +21,10 @@ from time import time
from typing import Any
from uuid import uuid4
from neo4j import AsyncDriver
from pydantic import BaseModel, Field
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
@@ -62,9 +62,9 @@ class Edge(BaseModel, ABC):
created_at: datetime
@abstractmethod
async def save(self, driver: AsyncDriver): ...
async def save(self, driver: GraphDriver): ...
async def delete(self, driver: AsyncDriver):
async def delete(self, driver: GraphDriver):
result = await driver.execute_query(
"""
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
@@ -87,11 +87,11 @@ class Edge(BaseModel, ABC):
return False
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
class EpisodicEdge(Edge):
async def save(self, driver: AsyncDriver):
async def save(self, driver: GraphDriver):
result = await driver.execute_query(
EPISODIC_EDGE_SAVE,
episode_uuid=self.source_node_uuid,
@@ -102,12 +102,12 @@ class EpisodicEdge(Edge):
database_=DEFAULT_DATABASE,
)
logger.debug(f'Saved edge to neo4j: {self.uuid}')
logger.debug(f'Saved edge to Graph: {self.uuid}')
return result
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
@@ -130,7 +130,7 @@ class EpisodicEdge(Edge):
return edges[0]
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
@@ -156,7 +156,7 @@ class EpisodicEdge(Edge):
@classmethod
async def get_by_group_ids(
cls,
driver: AsyncDriver,
driver: GraphDriver,
group_ids: list[str],
limit: int | None = None,
uuid_cursor: str | None = None,
@@ -226,7 +226,7 @@ class EntityEdge(Edge):
return self.fact_embedding
async def load_fact_embedding(self, driver: AsyncDriver):
async def load_fact_embedding(self, driver: GraphDriver):
query: LiteralString = """
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN e.fact_embedding AS fact_embedding
@@ -240,7 +240,7 @@ class EntityEdge(Edge):
self.fact_embedding = records[0]['fact_embedding']
async def save(self, driver: AsyncDriver):
async def save(self, driver: GraphDriver):
edge_data: dict[str, Any] = {
'source_uuid': self.source_node_uuid,
'target_uuid': self.target_node_uuid,
@@ -264,12 +264,12 @@ class EntityEdge(Edge):
database_=DEFAULT_DATABASE,
)
logger.debug(f'Saved edge to neo4j: {self.uuid}')
logger.debug(f'Saved edge to Graph: {self.uuid}')
return result
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
@@ -287,7 +287,7 @@ class EntityEdge(Edge):
return edges[0]
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
if len(uuids) == 0:
return []
@@ -309,7 +309,7 @@ class EntityEdge(Edge):
@classmethod
async def get_by_group_ids(
cls,
driver: AsyncDriver,
driver: GraphDriver,
group_ids: list[str],
limit: int | None = None,
uuid_cursor: str | None = None,
@@ -342,11 +342,11 @@ class EntityEdge(Edge):
return edges
@classmethod
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
query: LiteralString = (
"""
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
"""
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
"""
+ ENTITY_EDGE_RETURN
)
records, _, _ = await driver.execute_query(
@@ -359,7 +359,7 @@ class EntityEdge(Edge):
class CommunityEdge(Edge):
async def save(self, driver: AsyncDriver):
async def save(self, driver: GraphDriver):
result = await driver.execute_query(
COMMUNITY_EDGE_SAVE,
community_uuid=self.source_node_uuid,
@@ -370,12 +370,12 @@ class CommunityEdge(Edge):
database_=DEFAULT_DATABASE,
)
logger.debug(f'Saved edge to neo4j: {self.uuid}')
logger.debug(f'Saved edge to Graph: {self.uuid}')
return result
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
@@ -396,7 +396,7 @@ class CommunityEdge(Edge):
return edges[0]
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
@@ -420,7 +420,7 @@ class CommunityEdge(Edge):
@classmethod
async def get_by_group_ids(
cls,
driver: AsyncDriver,
driver: GraphDriver,
group_ids: list[str],
limit: int | None = None,
uuid_cursor: str | None = None,
@@ -463,7 +463,7 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
group_id=record['group_id'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
created_at=record['created_at'].to_native(),
created_at=parse_db_date(record['created_at']),
)
@@ -476,7 +476,7 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
name=record['name'],
group_id=record['group_id'],
episodes=record['episodes'],
created_at=record['created_at'].to_native(),
created_at=parse_db_date(record['created_at']),
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
@@ -504,7 +504,7 @@ def get_community_edge_from_record(record: Any):
group_id=record['group_id'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
created_at=record['created_at'].to_native(),
created_at=parse_db_date(record['created_at']),
)


@@ -0,0 +1,147 @@
"""
Database query utilities for different graph database backends.
This module provides database-agnostic query generation for Neo4j and FalkorDB,
supporting index creation, fulltext search, and bulk operations.
"""
from typing_extensions import LiteralString
from graphiti_core.models.edges.edge_db_queries import (
ENTITY_EDGE_SAVE_BULK,
)
from graphiti_core.models.nodes.node_db_queries import (
ENTITY_NODE_SAVE_BULK,
)
# Mapping from Neo4j fulltext index names to FalkorDB node labels
NEO4J_TO_FALKORDB_MAPPING = {
'node_name_and_summary': 'Entity',
'community_name': 'Community',
'episode_content': 'Episodic',
'edge_name_and_fact': 'RELATES_TO',
}
def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]:
if db_type == 'falkordb':
return [
# Entity node
'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)',
# Episodic node
'CREATE INDEX FOR (n:Episodic) ON (n.uuid, n.group_id, n.created_at, n.valid_at)',
# Community node
'CREATE INDEX FOR (n:Community) ON (n.uuid)',
# RELATES_TO edge
'CREATE INDEX FOR ()-[e:RELATES_TO]-() ON (e.uuid, e.group_id, e.name, e.created_at, e.expired_at, e.valid_at, e.invalid_at)',
# MENTIONS edge
'CREATE INDEX FOR ()-[e:MENTIONS]-() ON (e.uuid, e.group_id)',
# HAS_MEMBER edge
'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
]
else:
return [
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
]
def get_fulltext_indices(db_type: str = 'neo4j') -> list[LiteralString]:
if db_type == 'falkordb':
return [
"""CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""",
"""CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""",
"""CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""",
"""CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
]
else:
return [
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
"""CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
"""CREATE FULLTEXT INDEX community_name IF NOT EXISTS
FOR (n:Community) ON EACH [n.name, n.group_id]""",
"""CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
]
def get_nodes_query(db_type: str = 'neo4j', name: str = None, query: str = None) -> str:
if db_type == 'falkordb':
label = NEO4J_TO_FALKORDB_MAPPING[name]
return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
else:
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
def get_vector_cosine_func_query(vec1, vec2, db_type: str = 'neo4j') -> str:
if db_type == 'falkordb':
# FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity
return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
else:
return f'vector.similarity.cosine({vec1}, {vec2})'
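The rescaling above can be checked numerically. Assuming `vec.cosineDistance` returns the standard cosine distance 1 − cos θ (in [0, 2]), then (2 − d)/2 = (1 + cos θ)/2, which lands in [0, 1] like Neo4j's normalized cosine similarity:

```python
import math

# Numeric check of the (2 - d)/2 rescaling, assuming cosineDistance = 1 - cos(theta).
def cosine_sim(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def falkor_style_score(a, b):
    distance = 1.0 - cosine_sim(a, b)  # what cosineDistance would return
    return (2.0 - distance) / 2.0      # the rescaling used in the query above
```

Identical vectors score 1.0, orthogonal vectors 0.5, opposite vectors 0.0, matching the normalized range.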
def get_relationships_query(db_type: str = 'neo4j', name: str = None, query: str = None) -> str:
if db_type == 'falkordb':
label = NEO4J_TO_FALKORDB_MAPPING[name]
return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
else:
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str:
if db_type == 'falkordb':
queries = []
for node in nodes:
for label in node['labels']:
queries.append(
(
f"""
UNWIND $nodes AS node
MERGE (n:Entity {{uuid: node.uuid}})
SET n:{label}
SET n = node
WITH n, node
SET n.name_embedding = vecf32(node.name_embedding)
RETURN n.uuid AS uuid
""",
{'nodes': [node]},
)
)
return queries
else:
return ENTITY_NODE_SAVE_BULK
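Unlike Neo4j, FalkorDB cannot take node labels as query parameters, which is why the branch above expands the bulk save into one `(query, params)` pair per node and label, splicing the label into the query text. The expansion in isolation (a trimmed-down, hypothetical sketch, not the exact query text):

```python
def build_falkordb_node_queries(nodes: list[dict]) -> list[tuple[str, dict]]:
    # Labels must be interpolated into the Cypher text, so every node/label
    # pair becomes its own single-node UNWIND query.
    queries = []
    for node in nodes:
        for label in node['labels']:
            query = (
                'UNWIND $nodes AS node '
                'MERGE (n:Entity {uuid: node.uuid}) '
                f'SET n:{label} SET n = node '
                'RETURN n.uuid AS uuid'
            )
            queries.append((query, {'nodes': [node]}))
    return queries
```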
def get_entity_edge_save_bulk_query(db_type: str = 'neo4j') -> str:
if db_type == 'falkordb':
return """
UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
WITH r, edge
RETURN edge.uuid AS uuid"""
else:
return ENTITY_EDGE_SAVE_BULK


@ -19,12 +19,13 @@ from datetime import datetime
from time import time
from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase
from pydantic import BaseModel
from typing_extensions import LiteralString
from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.neo4j_driver import Neo4jDriver
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.graphiti_types import GraphitiClients
@ -94,12 +95,13 @@ class Graphiti:
def __init__(
self,
uri: str,
user: str,
password: str,
user: str | None = None,
password: str | None = None,
llm_client: LLMClient | None = None,
embedder: EmbedderClient | None = None,
cross_encoder: CrossEncoderClient | None = None,
store_raw_episode_content: bool = True,
graph_driver: GraphDriver | None = None,
):
"""
Initialize a Graphiti instance.
@ -137,7 +139,9 @@ class Graphiti:
Make sure to set the OPENAI_API_KEY environment variable before initializing
Graphiti if you're using the default OpenAIClient.
"""
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
self.driver = graph_driver if graph_driver else Neo4jDriver(uri, user, password)
self.database = DEFAULT_DATABASE
self.store_raw_episode_content = store_raw_episode_content
if llm_client:


@ -14,16 +14,16 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
from neo4j import AsyncDriver
from pydantic import BaseModel, ConfigDict
from graphiti_core.cross_encoder import CrossEncoderClient
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.embedder import EmbedderClient
from graphiti_core.llm_client import LLMClient
class GraphitiClients(BaseModel):
driver: AsyncDriver
driver: GraphDriver
llm_client: LLMClient
embedder: EmbedderClient
cross_encoder: CrossEncoderClient


@ -38,8 +38,14 @@ RUNTIME_QUERY: LiteralString = (
)
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
return neo_date.to_native() if neo_date else None
def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None:
    if not neo_date:
        return None
    if isinstance(neo_date, neo4j_time.DateTime):
        return neo_date.to_native()
    return datetime.fromisoformat(neo_date)
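`parse_db_date` now has to absorb both drivers' temporal types: the Neo4j driver hands back `neo4j.time.DateTime` objects, while FalkorDB returns ISO-8601 strings. A driver-free sketch of the same branching (duck-typing `.to_native()` in place of the real `neo4j_time.DateTime` isinstance check):

```python
from datetime import datetime

def parse_db_date(value):
    # None (or empty) stays None; Neo4j temporal objects expose .to_native();
    # FalkorDB hands back ISO-8601 strings for datetime.fromisoformat.
    if not value:
        return None
    if hasattr(value, 'to_native'):
        return value.to_native()
    return datetime.fromisoformat(value)
```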
def lucene_sanitize(query: str) -> str:


@ -1,3 +1,19 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from .client import LLMClient
from .config import LLMConfig
from .errors import RateLimitError


@ -22,13 +22,13 @@ from time import time
from typing import Any
from uuid import uuid4
from neo4j import AsyncDriver
from pydantic import BaseModel, Field
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import NodeNotFoundError
from graphiti_core.helpers import DEFAULT_DATABASE
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
from graphiti_core.models.nodes.node_db_queries import (
COMMUNITY_NODE_SAVE,
ENTITY_NODE_SAVE,
@ -94,9 +94,9 @@ class Node(BaseModel, ABC):
created_at: datetime = Field(default_factory=lambda: utc_now())
@abstractmethod
async def save(self, driver: AsyncDriver): ...
async def save(self, driver: GraphDriver): ...
async def delete(self, driver: AsyncDriver):
async def delete(self, driver: GraphDriver):
result = await driver.execute_query(
"""
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
@ -119,7 +119,7 @@ class Node(BaseModel, ABC):
return False
@classmethod
async def delete_by_group_id(cls, driver: AsyncDriver, group_id: str):
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str):
await driver.execute_query(
"""
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
@ -132,10 +132,10 @@ class Node(BaseModel, ABC):
return 'SUCCESS'
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): ...
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ...
class EpisodicNode(Node):
@ -150,7 +150,7 @@ class EpisodicNode(Node):
default_factory=list,
)
async def save(self, driver: AsyncDriver):
async def save(self, driver: GraphDriver):
result = await driver.execute_query(
EPISODIC_NODE_SAVE,
uuid=self.uuid,
@ -165,12 +165,12 @@ class EpisodicNode(Node):
database_=DEFAULT_DATABASE,
)
logger.debug(f'Saved Node to neo4j: {self.uuid}')
logger.debug(f'Saved Node to Graph: {self.uuid}')
return result
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (e:Episodic {uuid: $uuid})
@ -197,7 +197,7 @@ class EpisodicNode(Node):
return episodes[0]
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (e:Episodic) WHERE e.uuid IN $uuids
@ -224,7 +224,7 @@ class EpisodicNode(Node):
@classmethod
async def get_by_group_ids(
cls,
driver: AsyncDriver,
driver: GraphDriver,
group_ids: list[str],
limit: int | None = None,
uuid_cursor: str | None = None,
@ -263,7 +263,7 @@ class EpisodicNode(Node):
return episodes
@classmethod
async def get_by_entity_node_uuid(cls, driver: AsyncDriver, entity_node_uuid: str):
async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
@ -304,7 +304,7 @@ class EntityNode(Node):
return self.name_embedding
async def load_name_embedding(self, driver: AsyncDriver):
async def load_name_embedding(self, driver: GraphDriver):
query: LiteralString = """
MATCH (n:Entity {uuid: $uuid})
RETURN n.name_embedding AS name_embedding
@ -318,7 +318,7 @@ class EntityNode(Node):
self.name_embedding = records[0]['name_embedding']
async def save(self, driver: AsyncDriver):
async def save(self, driver: GraphDriver):
entity_data: dict[str, Any] = {
'uuid': self.uuid,
'name': self.name,
@ -337,16 +337,16 @@ class EntityNode(Node):
database_=DEFAULT_DATABASE,
)
logger.debug(f'Saved Node to neo4j: {self.uuid}')
logger.debug(f'Saved Node to Graph: {self.uuid}')
return result
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
query = (
"""
MATCH (n:Entity {uuid: $uuid})
"""
MATCH (n:Entity {uuid: $uuid})
"""
+ ENTITY_NODE_RETURN
)
records, _, _ = await driver.execute_query(
@ -364,7 +364,7 @@ class EntityNode(Node):
return nodes[0]
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity) WHERE n.uuid IN $uuids
@ -382,7 +382,7 @@ class EntityNode(Node):
@classmethod
async def get_by_group_ids(
cls,
driver: AsyncDriver,
driver: GraphDriver,
group_ids: list[str],
limit: int | None = None,
uuid_cursor: str | None = None,
@ -416,7 +416,7 @@ class CommunityNode(Node):
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
summary: str = Field(description='region summary of member nodes', default_factory=str)
async def save(self, driver: AsyncDriver):
async def save(self, driver: GraphDriver):
result = await driver.execute_query(
COMMUNITY_NODE_SAVE,
uuid=self.uuid,
@ -428,7 +428,7 @@ class CommunityNode(Node):
database_=DEFAULT_DATABASE,
)
logger.debug(f'Saved Node to neo4j: {self.uuid}')
logger.debug(f'Saved Node to Graph: {self.uuid}')
return result
@ -441,7 +441,7 @@ class CommunityNode(Node):
return self.name_embedding
async def load_name_embedding(self, driver: AsyncDriver):
async def load_name_embedding(self, driver: GraphDriver):
query: LiteralString = """
MATCH (c:Community {uuid: $uuid})
RETURN c.name_embedding AS name_embedding
@ -456,7 +456,7 @@ class CommunityNode(Node):
self.name_embedding = records[0]['name_embedding']
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community {uuid: $uuid})
@ -480,7 +480,7 @@ class CommunityNode(Node):
return nodes[0]
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community) WHERE n.uuid IN $uuids
@ -503,7 +503,7 @@ class CommunityNode(Node):
@classmethod
async def get_by_group_ids(
cls,
driver: AsyncDriver,
driver: GraphDriver,
group_ids: list[str],
limit: int | None = None,
uuid_cursor: str | None = None,
@ -542,8 +542,8 @@ class CommunityNode(Node):
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
return EpisodicNode(
content=record['content'],
created_at=record['created_at'].to_native().timestamp(),
valid_at=(record['valid_at'].to_native()),
created_at=parse_db_date(record['created_at']).timestamp(),
valid_at=(parse_db_date(record['valid_at'])),
uuid=record['uuid'],
group_id=record['group_id'],
source=EpisodeType.from_str(record['source']),
@ -559,7 +559,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
name=record['name'],
group_id=record['group_id'],
labels=record['labels'],
created_at=record['created_at'].to_native(),
created_at=parse_db_date(record['created_at']),
summary=record['summary'],
attributes=record['attributes'],
)
@ -580,7 +580,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
name=record['name'],
group_id=record['group_id'],
name_embedding=record['name_embedding'],
created_at=record['created_at'].to_native(),
created_at=parse_db_date(record['created_at']),
summary=record['summary'],
)


@ -18,9 +18,8 @@ import logging
from collections import defaultdict
from time import time
from neo4j import AsyncDriver
from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.edges import EntityEdge
from graphiti_core.errors import SearchRerankerError
from graphiti_core.graphiti_types import GraphitiClients
@ -94,7 +93,7 @@ async def search(
)
# if group_ids is empty, set it to None
group_ids = group_ids if group_ids else None
group_ids = group_ids if group_ids and group_ids != [''] else None
edges, nodes, episodes, communities = await semaphore_gather(
edge_search(
driver,
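The tightened guard above treats `['']` the same as `None` or `[]`: an accidental empty-string group id would otherwise filter out every result. The check as a stand-alone helper:

```python
def normalize_group_ids(group_ids):
    # None, [], and [''] all mean "search every group"; anything else filters.
    return group_ids if group_ids and group_ids != [''] else None
```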
@ -160,7 +159,7 @@ async def search(
async def edge_search(
driver: AsyncDriver,
driver: GraphDriver,
cross_encoder: CrossEncoderClient,
query: str,
query_vector: list[float],
@ -174,7 +173,6 @@ async def edge_search(
) -> list[EntityEdge]:
if config is None:
return []
search_results: list[list[EntityEdge]] = list(
await semaphore_gather(
*[
@ -261,7 +259,7 @@ async def edge_search(
async def node_search(
driver: AsyncDriver,
driver: GraphDriver,
cross_encoder: CrossEncoderClient,
query: str,
query_vector: list[float],
@ -275,7 +273,6 @@ async def node_search(
) -> list[EntityNode]:
if config is None:
return []
search_results: list[list[EntityNode]] = list(
await semaphore_gather(
*[
@ -344,7 +341,7 @@ async def node_search(
async def episode_search(
driver: AsyncDriver,
driver: GraphDriver,
cross_encoder: CrossEncoderClient,
query: str,
_query_vector: list[float],
@ -356,7 +353,6 @@ async def episode_search(
) -> list[EpisodicNode]:
if config is None:
return []
search_results: list[list[EpisodicNode]] = list(
await semaphore_gather(
*[
@ -392,7 +388,7 @@ async def episode_search(
async def community_search(
driver: AsyncDriver,
driver: GraphDriver,
cross_encoder: CrossEncoderClient,
query: str,
query_vector: list[float],


@ -20,11 +20,16 @@ from time import time
from typing import Any
import numpy as np
from neo4j import AsyncDriver, Query
from numpy._typing import NDArray
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
from graphiti_core.graph_queries import (
get_nodes_query,
get_relationships_query,
get_vector_cosine_func_query,
)
from graphiti_core.helpers import (
DEFAULT_DATABASE,
RUNTIME_QUERY,
@ -58,7 +63,7 @@ MAX_QUERY_LENGTH = 32
def fulltext_query(query: str, group_ids: list[str] | None = None):
group_ids_filter_list = (
[f'group_id:"{lucene_sanitize(g)}"' for g in group_ids] if group_ids is not None else []
[f"group_id-'{lucene_sanitize(g)}'" for g in group_ids] if group_ids is not None else []
)
group_ids_filter = ''
for f in group_ids_filter_list:
@ -77,7 +82,7 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
async def get_episodes_by_mentions(
driver: AsyncDriver,
driver: GraphDriver,
nodes: list[EntityNode],
edges: list[EntityEdge],
limit: int = RELEVANT_SCHEMA_LIMIT,
@ -92,11 +97,11 @@ async def get_episodes_by_mentions(
async def get_mentioned_nodes(
driver: AsyncDriver, episodes: list[EpisodicNode]
driver: GraphDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]:
episode_uuids = [episode.uuid for episode in episodes]
records, _, _ = await driver.execute_query(
"""
query = """
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
RETURN DISTINCT
n.uuid As uuid,
@ -106,7 +111,10 @@ async def get_mentioned_nodes(
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
""",
"""
records, _, _ = await driver.execute_query(
query,
uuids=episode_uuids,
database_=DEFAULT_DATABASE,
routing_='r',
@ -118,11 +126,11 @@ async def get_mentioned_nodes(
async def get_communities_by_nodes(
driver: AsyncDriver, nodes: list[EntityNode]
driver: GraphDriver, nodes: list[EntityNode]
) -> list[CommunityNode]:
node_uuids = [node.uuid for node in nodes]
records, _, _ = await driver.execute_query(
"""
query = """
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
RETURN DISTINCT
c.uuid As uuid,
@ -130,7 +138,10 @@ async def get_communities_by_nodes(
c.name AS name,
c.created_at AS created_at,
c.summary AS summary
""",
"""
records, _, _ = await driver.execute_query(
query,
uuids=node_uuids,
database_=DEFAULT_DATABASE,
routing_='r',
@ -142,7 +153,7 @@ async def get_communities_by_nodes(
async def edge_fulltext_search(
driver: AsyncDriver,
driver: GraphDriver,
query: str,
search_filter: SearchFilters,
group_ids: list[str] | None = None,
@ -155,34 +166,35 @@ async def edge_fulltext_search(
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
cypher_query = Query(
"""
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query, {limit: $limit})
YIELD relationship AS rel, score
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
WHERE r.group_id IN $group_ids"""
query = (
get_relationships_query(driver.provider, 'edge_name_and_fact', '$query')
+ """
YIELD relationship AS rel, score
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
WHERE r.group_id IN $group_ids """
+ filter_query
+ """\nWITH r, score, startNode(r) AS n, endNode(r) AS m
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at,
properties(r) AS attributes
ORDER BY score DESC LIMIT $limit
"""
+ """
WITH r, score, startNode(r) AS n, endNode(r) AS m
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at,
properties(r) AS attributes
ORDER BY score DESC LIMIT $limit
"""
)
records, _, _ = await driver.execute_query(
cypher_query,
filter_params,
query,
params=filter_params,
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
@ -196,7 +208,7 @@ async def edge_fulltext_search(
async def edge_similarity_search(
driver: AsyncDriver,
driver: GraphDriver,
search_vector: list[float],
source_node_uuid: str | None,
target_node_uuid: str | None,
@ -224,36 +236,38 @@ async def edge_similarity_search(
if target_node_uuid is not None:
group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
query: LiteralString = (
query = (
RUNTIME_QUERY
+ """
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
"""
"""
+ group_filter_query
+ filter_query
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
startNode(r).uuid AS source_node_uuid,
endNode(r).uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at,
properties(r) AS attributes
ORDER BY score DESC
LIMIT $limit
+ """
WITH DISTINCT r, """
+ get_vector_cosine_func_query('r.fact_embedding', '$search_vector', driver.provider)
+ """ AS score
WHERE score > $min_score
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
startNode(r).uuid AS source_node_uuid,
endNode(r).uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at,
properties(r) AS attributes
ORDER BY score DESC
LIMIT $limit
"""
)
records, _, _ = await driver.execute_query(
records, header, _ = await driver.execute_query(
query,
query_params,
params=query_params,
search_vector=search_vector,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
@ -264,13 +278,16 @@ async def edge_similarity_search(
routing_='r',
)
if driver.provider == 'falkordb':
records = [dict(zip(header, row, strict=True)) for row in records]
edges = [get_entity_edge_from_record(record) for record in records]
return edges
async def edge_bfs_search(
driver: AsyncDriver,
driver: GraphDriver,
bfs_origin_node_uuids: list[str] | None,
bfs_max_depth: int,
search_filter: SearchFilters,
@ -282,14 +299,14 @@ async def edge_bfs_search(
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
query = Query(
query = (
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
UNWIND relationships(path) AS rel
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
WHERE r.uuid = rel.uuid
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
UNWIND relationships(path) AS rel
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
WHERE r.uuid = rel.uuid
"""
+ filter_query
+ """
RETURN DISTINCT
@ -311,7 +328,7 @@ async def edge_bfs_search(
records, _, _ = await driver.execute_query(
query,
filter_params,
params=filter_params,
bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth,
limit=limit,
@ -325,7 +342,7 @@ async def edge_bfs_search(
async def node_fulltext_search(
driver: AsyncDriver,
driver: GraphDriver,
query: str,
search_filter: SearchFilters,
group_ids: list[str] | None = None,
@ -335,38 +352,41 @@ async def node_fulltext_search(
fuzzy_query = fulltext_query(query, group_ids)
if fuzzy_query == '':
return []
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
query = (
get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
+ """
YIELD node AS n, score
WITH n, score
LIMIT $limit
WHERE n:Entity
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
YIELD node AS n, score
WHERE n:Entity
"""
+ filter_query
+ ENTITY_NODE_RETURN
+ """
ORDER BY score DESC
"""
)
records, _, _ = await driver.execute_query(
records, header, _ = await driver.execute_query(
query,
filter_params,
params=filter_params,
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
if driver.provider == 'falkordb':
records = [dict(zip(header, row, strict=True)) for row in records]
nodes = [get_entity_node_from_record(record) for record in records]
return nodes
async def node_similarity_search(
driver: AsyncDriver,
driver: GraphDriver,
search_vector: list[float],
search_filter: SearchFilters,
group_ids: list[str] | None = None,
@ -384,22 +404,28 @@ async def node_similarity_search(
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
query_params.update(filter_params)
records, _, _ = await driver.execute_query(
query = (
RUNTIME_QUERY
+ """
MATCH (n:Entity)
"""
MATCH (n:Entity)
"""
+ group_filter_query
+ filter_query
+ """
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
WHERE score > $min_score"""
WITH n, """
+ get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
+ """ AS score
WHERE score > $min_score"""
+ ENTITY_NODE_RETURN
+ """
ORDER BY score DESC
LIMIT $limit
""",
query_params,
"""
)
records, header, _ = await driver.execute_query(
query,
params=query_params,
search_vector=search_vector,
group_ids=group_ids,
limit=limit,
@ -407,13 +433,15 @@ async def node_similarity_search(
database_=DEFAULT_DATABASE,
routing_='r',
)
if driver.provider == 'falkordb':
records = [dict(zip(header, row, strict=True)) for row in records]
nodes = [get_entity_node_from_record(record) for record in records]
return nodes
async def node_bfs_search(
driver: AsyncDriver,
driver: GraphDriver,
bfs_origin_node_uuids: list[str] | None,
search_filter: SearchFilters,
bfs_max_depth: int,
@ -425,18 +453,21 @@ async def node_bfs_search(
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
records, _, _ = await driver.execute_query(
query = (
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
WHERE n.group_id = origin.group_id
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
WHERE n.group_id = origin.group_id
"""
+ filter_query
+ ENTITY_NODE_RETURN
+ """
LIMIT $limit
""",
filter_params,
"""
)
records, _, _ = await driver.execute_query(
query,
params=filter_params,
bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth,
limit=limit,
@ -449,7 +480,7 @@ async def node_bfs_search(
async def episode_fulltext_search(
driver: AsyncDriver,
driver: GraphDriver,
query: str,
_search_filter: SearchFilters,
group_ids: list[str] | None = None,
@ -460,9 +491,9 @@ async def episode_fulltext_search(
if fuzzy_query == '':
return []
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("episode_content", $query, {limit: $limit})
query = (
get_nodes_query(driver.provider, 'episode_content', '$query')
+ """
YIELD node AS episode, score
MATCH (e:Episodic)
WHERE e.uuid = episode.uuid
@ -478,7 +509,11 @@ async def episode_fulltext_search(
e.entity_edges AS entity_edges
ORDER BY score DESC
LIMIT $limit
""",
"""
)
records, _, _ = await driver.execute_query(
query,
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
@ -491,7 +526,7 @@ async def episode_fulltext_search(
async def community_fulltext_search(
driver: AsyncDriver,
driver: GraphDriver,
query: str,
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
@ -501,9 +536,9 @@ async def community_fulltext_search(
if fuzzy_query == '':
return []
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("community_name", $query, {limit: $limit})
query = (
get_nodes_query(driver.provider, 'community_name', '$query')
+ """
YIELD node AS comm, score
RETURN
comm.uuid AS uuid,
@ -513,7 +548,11 @@ async def community_fulltext_search(
comm.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
"""
)
records, _, _ = await driver.execute_query(
query,
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
@ -526,7 +565,7 @@ async def community_fulltext_search(
async def community_similarity_search(
driver: AsyncDriver,
driver: GraphDriver,
search_vector: list[float],
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
@ -540,14 +579,16 @@ async def community_similarity_search(
group_filter_query += 'WHERE comm.group_id IN $group_ids'
query_params['group_ids'] = group_ids
records, _, _ = await driver.execute_query(
query = (
RUNTIME_QUERY
+ """
MATCH (comm:Community)
"""
+ group_filter_query
+ """
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
WITH comm, """
+ get_vector_cosine_func_query('comm.name_embedding', '$search_vector', driver.provider)
+ """ AS score
WHERE score > $min_score
RETURN
comm.uuid As uuid,
@ -557,7 +598,11 @@ async def community_similarity_search(
comm.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
"""
)
records, _, _ = await driver.execute_query(
query,
search_vector=search_vector,
group_ids=group_ids,
limit=limit,
@ -573,7 +618,7 @@ async def community_similarity_search(
async def hybrid_node_search(
queries: list[str],
embeddings: list[list[float]],
driver: AsyncDriver,
driver: GraphDriver,
search_filter: SearchFilters,
group_ids: list[str] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
@ -590,7 +635,7 @@ async def hybrid_node_search(
A list of text queries to search for.
embeddings : list[list[float]]
A list of embedding vectors corresponding to the queries. If empty, only fulltext search is performed.
driver : AsyncDriver
driver : GraphDriver
The graph driver instance used for database operations.
group_ids : list[str] | None, optional
The list of group ids to retrieve nodes from.
@ -645,7 +690,7 @@ async def hybrid_node_search(
async def get_relevant_nodes(
driver: AsyncDriver,
driver: GraphDriver,
nodes: list[EntityNode],
search_filter: SearchFilters,
min_score: float = DEFAULT_MIN_SCORE,
@ -664,29 +709,33 @@ async def get_relevant_nodes(
query = (
RUNTIME_QUERY
+ """UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
+ """
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
+ filter_query
+ """
WITH node, n, vector.similarity.cosine(n.name_embedding, node.name_embedding) AS score
WITH node, n, """
+ get_vector_cosine_func_query('n.name_embedding', 'node.name_embedding', driver.provider)
+ """ AS score
WHERE score > $min_score
WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
CALL db.index.fulltext.queryNodes("node_name_and_summary", node.fulltext_query, {limit: $limit})
"""
+ get_nodes_query(driver.provider, 'node_name_and_summary', 'node.fulltext_query')
+ """
YIELD node AS m
WHERE m.group_id = $group_id
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
WITH node,
top_vector_nodes,
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
UNWIND combined_nodes AS combined_node
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
RETURN
node.uuid AS search_node_uuid,
[x IN deduped_nodes | {
@ -714,7 +763,7 @@ async def get_relevant_nodes(
results, _, _ = await driver.execute_query(
query,
query_params,
params=query_params,
nodes=query_nodes,
group_id=group_id,
limit=limit,
@ -736,7 +785,7 @@ async def get_relevant_nodes(
async def get_relevant_edges(
driver: AsyncDriver,
driver: GraphDriver,
edges: list[EntityEdge],
search_filter: SearchFilters,
min_score: float = DEFAULT_MIN_SCORE,
@ -752,43 +801,47 @@ async def get_relevant_edges(
query = (
RUNTIME_QUERY
+ """UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ """
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
WHERE score > $min_score
WITH edge, e, score
ORDER BY score DESC
RETURN edge.uuid AS search_edge_uuid,
collect({
uuid: e.uuid,
source_node_uuid: startNode(e).uuid,
target_node_uuid: endNode(e).uuid,
created_at: e.created_at,
name: e.name,
group_id: e.group_id,
fact: e.fact,
fact_embedding: e.fact_embedding,
episodes: e.episodes,
expired_at: e.expired_at,
valid_at: e.valid_at,
invalid_at: e.invalid_at,
attributes: properties(e)
})[..$limit] AS matches
WITH e, edge, """
+ get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
+ """ AS score
WHERE score > $min_score
WITH edge, e, score
ORDER BY score DESC
RETURN edge.uuid AS search_edge_uuid,
collect({
uuid: e.uuid,
source_node_uuid: startNode(e).uuid,
target_node_uuid: endNode(e).uuid,
created_at: e.created_at,
name: e.name,
group_id: e.group_id,
fact: e.fact,
fact_embedding: e.fact_embedding,
episodes: e.episodes,
expired_at: e.expired_at,
valid_at: e.valid_at,
invalid_at: e.invalid_at,
attributes: properties(e)
})[..$limit] AS matches
"""
)
results, _, _ = await driver.execute_query(
query,
query_params,
params=query_params,
edges=[edge.model_dump() for edge in edges],
limit=limit,
min_score=min_score,
database_=DEFAULT_DATABASE,
routing_='r',
)
relevant_edges_dict: dict[str, list[EntityEdge]] = {
result['search_edge_uuid']: [
get_entity_edge_from_record(record) for record in result['matches']
@ -802,7 +855,7 @@ async def get_relevant_edges(
async def get_edge_invalidation_candidates(
driver: AsyncDriver,
driver: GraphDriver,
edges: list[EntityEdge],
search_filter: SearchFilters,
min_score: float = DEFAULT_MIN_SCORE,
@ -818,38 +871,41 @@ async def get_edge_invalidation_candidates(
query = (
RUNTIME_QUERY
+ """UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
+ """
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
+ filter_query
+ """
WITH edge, e, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
WHERE score > $min_score
WITH edge, e, score
ORDER BY score DESC
RETURN edge.uuid AS search_edge_uuid,
collect({
uuid: e.uuid,
source_node_uuid: startNode(e).uuid,
target_node_uuid: endNode(e).uuid,
created_at: e.created_at,
name: e.name,
group_id: e.group_id,
fact: e.fact,
fact_embedding: e.fact_embedding,
episodes: e.episodes,
expired_at: e.expired_at,
valid_at: e.valid_at,
invalid_at: e.invalid_at,
attributes: properties(e)
})[..$limit] AS matches
WITH edge, e, """
+ get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
+ """ AS score
WHERE score > $min_score
WITH edge, e, score
ORDER BY score DESC
RETURN edge.uuid AS search_edge_uuid,
collect({
uuid: e.uuid,
source_node_uuid: startNode(e).uuid,
target_node_uuid: endNode(e).uuid,
created_at: e.created_at,
name: e.name,
group_id: e.group_id,
fact: e.fact,
fact_embedding: e.fact_embedding,
episodes: e.episodes,
expired_at: e.expired_at,
valid_at: e.valid_at,
invalid_at: e.invalid_at,
attributes: properties(e)
})[..$limit] AS matches
"""
)
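The `get_vector_cosine_func_query` helper used in the hunk above centralizes the provider-specific similarity expression. A hypothetical sketch of such a dispatch, assuming FalkorDB exposes a cosine *distance* that gets rescaled to a similarity (the FalkorDB function names and rescaling are illustrative, not the confirmed API):

```python
def get_vector_cosine_func_query(v1: str, v2: str, provider: str) -> str:
    """Return a provider-specific Cypher expression for cosine similarity."""
    if provider == 'falkordb':
        # Assumed: FalkorDB exposes a cosine distance in [0, 2], rescaled
        # here to a similarity in [0, 1].
        return f'(2 - vec.cosineDistance({v1}, vecf32({v2}))) / 2'
    # Neo4j ships vector.similarity.cosine natively.
    return f'vector.similarity.cosine({v1}, {v2})'
```

This keeps the surrounding query text identical across providers, with only the scoring expression swapped in.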
results, _, _ = await driver.execute_query(
query,
query_params,
params=query_params,
edges=[edge.model_dump() for edge in edges],
limit=limit,
min_score=min_score,
@ -884,7 +940,7 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[st
async def node_distance_reranker(
driver: AsyncDriver,
driver: GraphDriver,
node_uuids: list[str],
center_node_uuid: str,
min_score: float = 0,
@ -894,21 +950,22 @@ async def node_distance_reranker(
scores: dict[str, float] = {center_node_uuid: 0.0}
# Find the shortest path to center node
query = Query("""
query = """
UNWIND $node_uuids AS node_uuid
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: node_uuid})
RETURN length(p) AS score, node_uuid AS uuid
""")
path_results, _, _ = await driver.execute_query(
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
RETURN 1 AS score, node_uuid AS uuid
"""
results, header, _ = await driver.execute_query(
query,
node_uuids=filtered_uuids,
center_uuid=center_node_uuid,
database_=DEFAULT_DATABASE,
routing_='r',
)
if driver.provider == 'falkordb':
results = [dict(zip(header, row, strict=True)) for row in results]
for result in path_results:
for result in results:
uuid = result['uuid']
score = result['score']
scores[uuid] = score
@ -929,19 +986,18 @@ async def node_distance_reranker(
async def episode_mentions_reranker(
driver: AsyncDriver, node_uuids: list[list[str]], min_score: float = 0
driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
) -> list[str]:
# use rrf as a preliminary ranker
sorted_uuids = rrf(node_uuids)
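Both rerankers seed their ordering with `rrf` as above. A minimal sketch of reciprocal rank fusion, assuming the standard formulation (the defaults here mirror the signature shown in this diff but are otherwise illustrative):

```python
from collections import defaultdict

def rrf(results: list[list[str]], rank_const: int = 1, min_score: float = 0.0) -> list[str]:
    # Each ranked list contributes 1 / (rank + rank_const) to a uuid's score,
    # so uuids that appear near the top of several lists float upward.
    scores: dict[str, float] = defaultdict(float)
    for ranking in results:
        for rank, uuid in enumerate(ranking):
            scores[uuid] += 1.0 / (rank + rank_const)
    ranked = sorted(scores, key=lambda u: scores[u], reverse=True)
    return [u for u in ranked if scores[u] >= min_score]
```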
scores: dict[str, float] = {}
    # Count how many episodes mention each node
query = Query("""
query = """
UNWIND $node_uuids AS node_uuid
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
RETURN count(*) AS score, n.uuid AS uuid
""")
"""
results, _, _ = await driver.execute_query(
query,
node_uuids=sorted_uuids,
@ -998,7 +1054,7 @@ def maximal_marginal_relevance(
async def get_embeddings_for_nodes(
driver: AsyncDriver, nodes: list[EntityNode]
driver: GraphDriver, nodes: list[EntityNode]
) -> dict[str, list[float]]:
query: LiteralString = """MATCH (n:Entity)
WHERE n.uuid IN $node_uuids
@ -1022,7 +1078,7 @@ async def get_embeddings_for_nodes(
async def get_embeddings_for_communities(
driver: AsyncDriver, communities: list[CommunityNode]
driver: GraphDriver, communities: list[CommunityNode]
) -> dict[str, list[float]]:
query: LiteralString = """MATCH (c:Community)
WHERE c.uuid IN $community_uuids
@ -1049,7 +1105,7 @@ async def get_embeddings_for_communities(
async def get_embeddings_for_edges(
driver: AsyncDriver, edges: list[EntityEdge]
driver: GraphDriver, edges: list[EntityEdge]
) -> dict[str, list[float]]:
query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
WHERE e.uuid IN $edge_uuids

View file

@ -20,22 +20,24 @@ from collections import defaultdict
from datetime import datetime
from math import ceil
from neo4j import AsyncDriver, AsyncManagedTransaction
from numpy import dot, sqrt
from pydantic import BaseModel
from typing_extensions import Any
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
from graphiti_core.embedder import EmbedderClient
from graphiti_core.graph_queries import (
get_entity_edge_save_bulk_query,
get_entity_node_save_bulk_query,
)
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.models.edges.edge_db_queries import (
ENTITY_EDGE_SAVE_BULK,
EPISODIC_EDGE_SAVE_BULK,
)
from graphiti_core.models.nodes.node_db_queries import (
ENTITY_NODE_SAVE_BULK,
EPISODIC_NODE_SAVE_BULK,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
@ -73,7 +75,7 @@ class RawEpisode(BaseModel):
async def retrieve_previous_episodes_bulk(
driver: AsyncDriver, episodes: list[EpisodicNode]
driver: GraphDriver, episodes: list[EpisodicNode]
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
previous_episodes_list = await semaphore_gather(
*[
@ -91,14 +93,15 @@ async def retrieve_previous_episodes_bulk(
async def add_nodes_and_edges_bulk(
driver: AsyncDriver,
driver: GraphDriver,
episodic_nodes: list[EpisodicNode],
episodic_edges: list[EpisodicEdge],
entity_nodes: list[EntityNode],
entity_edges: list[EntityEdge],
embedder: EmbedderClient,
):
async with driver.session(database=DEFAULT_DATABASE) as session:
session = driver.session(database=DEFAULT_DATABASE)
try:
await session.execute_write(
add_nodes_and_edges_bulk_tx,
episodic_nodes,
@ -106,16 +109,20 @@ async def add_nodes_and_edges_bulk(
entity_nodes,
entity_edges,
embedder,
driver=driver,
)
finally:
await session.close()
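The `async with` block is replaced above by an explicit try/finally so that driver sessions which are not async context managers (as a FalkorDB-backed session may not be) are still closed reliably. A minimal sketch of the pattern with a stand-in session class:

```python
import asyncio

class FakeSession:
    """Stand-in for a driver session that lacks __aenter__/__aexit__."""

    def __init__(self):
        self.closed = False

    async def execute_write(self, fn, *args, **kwargs):
        return await fn(self, *args, **kwargs)

    async def close(self):
        self.closed = True

async def write_with_cleanup(session):
    # try/finally gives the same close-on-exit guarantee as `async with`
    # without requiring the session to implement the context protocol.
    try:
        async def tx_fn(tx):
            return 'written'
        return await session.execute_write(tx_fn)
    finally:
        await session.close()
```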
async def add_nodes_and_edges_bulk_tx(
tx: AsyncManagedTransaction,
tx: GraphDriverSession,
episodic_nodes: list[EpisodicNode],
episodic_edges: list[EpisodicEdge],
entity_nodes: list[EntityNode],
entity_edges: list[EntityEdge],
embedder: EmbedderClient,
driver: GraphDriver,
):
episodes = [dict(episode) for episode in episodic_nodes]
for episode in episodes:
@ -160,11 +167,13 @@ async def add_nodes_and_edges_bulk_tx(
edges.append(edge_data)
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
entity_node_save_bulk = get_entity_node_save_bulk_query(nodes, driver.provider)
await tx.run(entity_node_save_bulk, nodes=nodes)
await tx.run(
EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
)
await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=edges)
entity_edge_save_bulk = get_entity_edge_save_bulk_query(driver.provider)
await tx.run(entity_edge_save_bulk, entity_edges=edges)
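`get_entity_node_save_bulk_query` and `get_entity_edge_save_bulk_query` replace the previously hard-coded bulk Cypher with per-provider generators. A hypothetical shape for such a dispatcher (the Cypher text and the FalkorDB-specific tweak shown here are illustrative assumptions only):

```python
# Illustrative Neo4j-style bulk save query; not the repo's actual Cypher.
ENTITY_EDGE_SAVE_BULK_NEO4J = """
UNWIND $entity_edges AS edge
MATCH (s:Entity {uuid: edge.source_node_uuid}), (t:Entity {uuid: edge.target_node_uuid})
MERGE (s)-[e:RELATES_TO {uuid: edge.uuid}]->(t)
SET e = edge
"""

def get_entity_edge_save_bulk_query(provider: str) -> str:
    # Assumed: FalkorDB needs a variant of the bulk UNWIND query,
    # e.g. different property-setting syntax.
    if provider == 'falkordb':
        return ENTITY_EDGE_SAVE_BULK_NEO4J.replace('SET e = edge', 'SET e += edge')
    return ENTITY_EDGE_SAVE_BULK_NEO4J
```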
async def extract_nodes_and_edges_bulk(
@ -211,7 +220,7 @@ async def extract_nodes_and_edges_bulk(
async def dedupe_nodes_bulk(
driver: AsyncDriver,
driver: GraphDriver,
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
@ -247,7 +256,7 @@ async def dedupe_nodes_bulk(
async def dedupe_edges_bulk(
driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
driver: GraphDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
) -> list[EntityEdge]:
# First compress edges
compressed_edges = await compress_edges(llm_client, extracted_edges)

View file

@ -2,9 +2,9 @@ import asyncio
import logging
from collections import defaultdict
from neo4j import AsyncDriver
from pydantic import BaseModel
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.edges import CommunityEdge
from graphiti_core.embedder import EmbedderClient
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
@ -26,7 +26,7 @@ class Neighbor(BaseModel):
async def get_community_clusters(
driver: AsyncDriver, group_ids: list[str] | None
driver: GraphDriver, group_ids: list[str] | None
) -> list[list[EntityNode]]:
community_clusters: list[list[EntityNode]] = []
@ -95,7 +95,6 @@ def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
community_candidates: dict[int, int] = defaultdict(int)
for neighbor in neighbors:
community_candidates[community_map[neighbor.node_uuid]] += neighbor.edge_count
community_lst = [
(count, community) for community, count in community_candidates.items()
]
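The tally above is the core of label propagation: each node adopts the neighboring community with the greatest total edge weight. One update step in isolation, with made-up inputs (`propagation_step` is a hypothetical name, not a function from this diff):

```python
from collections import defaultdict

def propagation_step(community_map: dict[str, int], neighbors: list[tuple[str, int]]) -> int:
    # neighbors: (node_uuid, edge_count) pairs in one node's neighborhood.
    community_candidates: dict[int, int] = defaultdict(int)
    for node_uuid, edge_count in neighbors:
        community_candidates[community_map[node_uuid]] += edge_count
    # Pick the community with the highest weighted vote.
    community_lst = [(count, community) for community, count in community_candidates.items()]
    community_lst.sort(reverse=True)
    return community_lst[0][1]
```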
@ -194,7 +193,7 @@ async def build_community(
async def build_communities(
driver: AsyncDriver, llm_client: LLMClient, group_ids: list[str] | None
driver: GraphDriver, llm_client: LLMClient, group_ids: list[str] | None
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
community_clusters = await get_community_clusters(driver, group_ids)
@ -219,7 +218,7 @@ async def build_communities(
return community_nodes, community_edges
async def remove_communities(driver: AsyncDriver):
async def remove_communities(driver: GraphDriver):
await driver.execute_query(
"""
MATCH (c:Community)
@ -230,10 +229,10 @@ async def remove_communities(driver: AsyncDriver):
async def determine_entity_community(
driver: AsyncDriver, entity: EntityNode
driver: GraphDriver, entity: EntityNode
) -> tuple[CommunityNode | None, bool]:
# Check if the node is already part of a community
records, _, _ = await driver.execute_query(
records, _, _ = await driver.execute_query(
"""
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
RETURN
@ -251,7 +250,7 @@ async def determine_entity_community(
return get_community_node_from_record(records[0]), False
# If the node has no community, add it to the most common (mode) community among surrounding entities
records, _, _ = await driver.execute_query(
records, _, _ = await driver.execute_query(
"""
MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
RETURN
@ -291,7 +290,7 @@ async def determine_entity_community(
async def update_community(
driver: AsyncDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode
driver: GraphDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode
):
community, is_new = await determine_entity_community(driver, entity)

View file

@ -260,7 +260,6 @@ async def resolve_extracted_edges(
driver = clients.driver
llm_client = clients.llm_client
embedder = clients.embedder
await create_entity_edge_embeddings(embedder, extracted_edges)
search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(

View file

@ -17,9 +17,10 @@ limitations under the License.
import logging
from datetime import datetime, timezone
from neo4j import AsyncDriver
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
from graphiti_core.nodes import EpisodeType, EpisodicNode
@ -28,7 +29,7 @@ EPISODE_WINDOW_LEN = 3
logger = logging.getLogger(__name__)
async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bool = False):
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
if delete_existing:
records, _, _ = await driver.execute_query(
"""
@ -47,39 +48,9 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
for name in index_names
]
)
range_indices: list[LiteralString] = get_range_indices(driver.provider)
range_indices: list[LiteralString] = [
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
]
fulltext_indices: list[LiteralString] = [
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
"""CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
"""CREATE FULLTEXT INDEX community_name IF NOT EXISTS
FOR (n:Community) ON EACH [n.name, n.group_id]""",
"""CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
]
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
index_queries: list[LiteralString] = range_indices + fulltext_indices
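`get_range_indices` and `get_fulltext_indices` move the index DDL into per-provider generators, since Neo4j and FalkorDB spell index creation differently. A hypothetical shape for the range-index half (the FalkorDB syntax shown is an assumption, not confirmed DDL):

```python
def get_range_indices(provider: str) -> list[str]:
    if provider == 'falkordb':
        # Assumed FalkorDB-style DDL; no IF NOT EXISTS guard.
        return ['CREATE INDEX FOR (n:Entity) ON (n.uuid)']
    # Neo4j-style named index with an existence guard.
    return ['CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)']
```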
@ -94,7 +65,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
)
async def clear_data(driver: AsyncDriver, group_ids: list[str] | None = None):
async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
async with driver.session(database=DEFAULT_DATABASE) as session:
async def delete_all(tx):
@ -113,7 +84,7 @@ async def clear_data(driver: AsyncDriver, group_ids: list[str] | None = None):
async def retrieve_episodes(
driver: AsyncDriver,
driver: GraphDriver,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
group_ids: list[str] | None = None,
@ -123,7 +94,7 @@ async def retrieve_episodes(
Retrieve the last n episodic nodes from the graph.
Args:
driver (AsyncDriver): The Neo4j driver instance.
driver (GraphDriver): The graph driver instance.
reference_time (datetime): The reference time to filter episodes. Only episodes with a valid_at timestamp
less than or equal to this reference_time will be retrieved. This allows for
querying the graph's state at a specific point in time.
@ -140,8 +111,8 @@ async def retrieve_episodes(
query: LiteralString = (
"""
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
"""
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
"""
+ group_id_filter
+ source_filter
+ """
@ -157,8 +128,7 @@ async def retrieve_episodes(
LIMIT $num_episodes
"""
)
result = await driver.execute_query(
result, _, _ = await driver.execute_query(
query,
reference_time=reference_time,
source=source.name if source is not None else None,
@ -166,6 +136,7 @@ async def retrieve_episodes(
group_ids=group_ids,
database_=DEFAULT_DATABASE,
)
episodes = [
EpisodicNode(
content=record['content'],
@ -179,6 +150,6 @@ async def retrieve_episodes(
name=record['name'],
source_description=record['source_description'],
)
for record in result.records
for record in result
]
return list(reversed(episodes)) # Return in chronological order

View file

@ -326,7 +326,6 @@ async def extract_attributes_from_nodes(
) -> list[EntityNode]:
llm_client = clients.llm_client
embedder = clients.embedder
updated_nodes: list[EntityNode] = await semaphore_gather(
*[
extract_attributes_from_node(

poetry.lock generated
View file

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
[[package]]
name = "aiohappyeyeballs"
@ -332,12 +332,12 @@ version = "5.0.1"
description = "Timeout context manager for asyncio programs"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
markers = "python_version < \"3.11\""
groups = ["main", "dev"]
files = [
{file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"},
{file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"},
]
markers = {main = "python_full_version < \"3.11.3\"", dev = "python_version == \"3.10\""}
[[package]]
name = "attrs"
@ -759,7 +759,7 @@ description = "Backport of PEP 654 (exception groups)"
optional = false
python-versions = ">=3.7"
groups = ["main", "dev"]
markers = "python_version < \"3.11\""
markers = "python_version == \"3.10\""
files = [
{file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"},
{file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
@ -798,6 +798,20 @@ files = [
[package.extras]
tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich ; python_version >= \"3.11\""]
[[package]]
name = "falkordb"
version = "1.1.2"
description = "Python client for interacting with FalkorDB database"
optional = false
python-versions = "<4.0,>=3.8"
groups = ["main"]
files = [
{file = "falkordb-1.1.2.tar.gz", hash = "sha256:db76c97efe14a56c3d65c61b966a42b874e1c78a8fb6808de3f61f4314b04023"},
]
[package.dependencies]
redis = ">=5.0.1,<6.0.0"
[[package]]
name = "fastjsonschema"
version = "2.21.1"
@ -2665,7 +2679,6 @@ description = "Fast, correct Python JSON library supporting dataclasses, datetim
optional = false
python-versions = ">=3.9"
groups = ["dev"]
markers = "platform_python_implementation != \"PyPy\""
files = [
{file = "orjson-3.10.16-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4cb473b8e79154fa778fb56d2d73763d977be3dcc140587e07dbc545bbfc38f8"},
{file = "orjson-3.10.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:622a8e85eeec1948690409a19ca1c7d9fd8ff116f4861d261e6ae2094fe59a00"},
@ -3691,6 +3704,25 @@ files = [
[package.dependencies]
cffi = {version = "*", markers = "implementation_name == \"pypy\""}
[[package]]
name = "redis"
version = "5.2.1"
description = "Python client for Redis database and key-value store"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "redis-5.2.1-py3-none-any.whl", hash = "sha256:ee7e1056b9aea0f04c6c2ed59452947f34c4940ee025f5dd83e6a6418b6989e4"},
{file = "redis-5.2.1.tar.gz", hash = "sha256:16f2e22dff21d5125e8481515e386711a34cbec50f0e44413dd7d9c060a54e0f"},
]
[package.dependencies]
async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""}
[package.extras]
hiredis = ["hiredis (>=3.0.0)"]
ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"]
[[package]]
name = "referencing"
version = "0.36.2"
@ -4498,7 +4530,7 @@ description = "A lil' TOML parser"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
markers = "python_version < \"3.11\""
markers = "python_version == \"3.10\""
files = [
{file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
{file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
@ -5356,4 +5388,4 @@ groq = ["groq"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<4"
content-hash = "814d067fd2959bfe2db58a22637d86580b66d96f34c433852c67d02089d750ab"
content-hash = "2e02a10a6493f7564b86d5d0d09b4cf718004808e115af39550b9ee87c296fb4"

View file

@ -19,6 +19,7 @@ dependencies = [
"tenacity>=9.0.0",
"numpy>=1.0.0",
"python-dotenv>=1.0.1",
"falkordb (>=1.1.2,<2.0.0)",
]
[project.urls]