diff --git a/.env.example b/.env.example
index 2724ecb8..53db5942 100644
--- a/.env.example
+++ b/.env.example
@@ -1,8 +1,17 @@
 OPENAI_API_KEY=
+
+# Neo4j database connection
 NEO4J_URI=
 NEO4J_PORT=
 NEO4J_USER=
 NEO4J_PASSWORD=
+
+# FalkorDB database connection
+FALKORDB_URI=
+FALKORDB_PORT=
+FALKORDB_USER=
+FALKORDB_PASSWORD=
+
 DEFAULT_DATABASE=
 USE_PARALLEL_RUNTIME=
 SEMAPHORE_LIMIT=
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index caf46881..b5b7156e 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -64,10 +64,11 @@ Once you've found an issue tagged with "good first issue" or "help wanted," or p
    export TEST_OPENAI_API_KEY=...
    export TEST_OPENAI_MODEL=...
    export TEST_ANTHROPIC_API_KEY=...
-
-   export NEO4J_URI=neo4j://...
-   export NEO4J_USER=...
-   export NEO4J_PASSWORD=...
+
+   # For Neo4j
+   export TEST_URI=neo4j://...
+   export TEST_USER=...
+   export TEST_PASSWORD=...
    ```
 
 ## Making Changes
diff --git a/README.md b/README.md
index d471fb38..3e2f110c 100644
--- a/README.md
+++ b/README.md
@@ -105,7 +105,7 @@ Graphiti is specifically designed to address the challenges of dynamic and frequ
 Requirements:
 
 - Python 3.10 or higher
-- Neo4j 5.26 or higher (serves as the embeddings storage backend)
+- Neo4j 5.26 / FalkorDB 1.1.2 or higher (serves as the embeddings storage backend)
 - OpenAI API key (for LLM inference and embedding)
 
 > [!IMPORTANT]
diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py
index 146d20e6..660c7f22 100644
--- a/examples/podcast/podcast_runner.py
+++ b/examples/podcast/podcast_runner.py
@@ -76,9 +76,7 @@ async def main():
     group_id = str(uuid4())
 
     for i, message in enumerate(messages[3:14]):
-        episodes = await client.retrieve_episodes(
-            message.actual_timestamp, 3, group_ids=['podcast']
-        )
+        episodes = await client.retrieve_episodes(message.actual_timestamp, 3, group_ids=[group_id])
         episode_uuids = [episode.uuid for episode in episodes]
 
         await client.add_episode(
diff --git a/examples/quickstart/README.md b/examples/quickstart/README.md
index d192f949..988211ca 100644
--- a/examples/quickstart/README.md
+++ b/examples/quickstart/README.md
@@ -2,7 +2,7 @@
 
 This example demonstrates the basic functionality of Graphiti, including:
 
-1. Connecting to a Neo4j database
+1. Connecting to a Neo4j or FalkorDB database
 2. Initializing Graphiti indices and constraints
 3. Adding episodes to the graph
 4. Searching the graph with semantic and keyword matching
@@ -11,10 +11,14 @@ This example demonstrates the basic functionality of Graphiti, including:
 
 ## Prerequisites
 
-- Neo4j Desktop installed and running
-- A local DBMS created and started in Neo4j Desktop
-- Python 3.9+
-- OpenAI API key (set as `OPENAI_API_KEY` environment variable)
+- Python 3.9+
+- OpenAI API key (set as `OPENAI_API_KEY` environment variable)
+- **For Neo4j**:
+  - Neo4j Desktop installed and running
+  - A local DBMS created and started in Neo4j Desktop
+- **For FalkorDB**:
+  - FalkorDB server running (see [FalkorDB documentation](https://falkordb.com/docs/) for setup)
+
 
 ## Setup Instructions
@@ -34,17 +38,23 @@
 export OPENAI_API_KEY=your_openai_api_key
 export NEO4J_URI=bolt://localhost:7687
 export NEO4J_USER=neo4j
 export NEO4J_PASSWORD=password
+
+# Optional FalkorDB connection parameters (defaults shown)
+export FALKORDB_URI=falkor://localhost:6379
 ```
 
 3. Run the example:
 ```bash
-python quickstart.py
+python quickstart_neo4j.py
+
+# For FalkorDB
+python quickstart_falkordb.py
 ```
 
 ## What This Example Demonstrates
 
-- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j
+- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j or FalkorDB
 - **Adding Episodes**: Adding text content that will be analyzed and converted into knowledge graph nodes and edges
 - **Edge Search Functionality**: Performing hybrid searches that combine semantic similarity and BM25 retrieval to find relationships (edges)
 - **Graph-Aware Search**: Using the source node UUID from the top search result to rerank additional search results based on graph distance
diff --git a/examples/quickstart/quickstart_falkordb.py b/examples/quickstart/quickstart_falkordb.py
new file mode 100644
index 00000000..82f6a94e
--- /dev/null
+++ b/examples/quickstart/quickstart_falkordb.py
@@ -0,0 +1,244 @@
+"""
+Copyright 2025, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import asyncio
+import json
+import logging
+import os
+from datetime import datetime, timezone
+from logging import INFO
+
+from dotenv import load_dotenv
+
+from graphiti_core import Graphiti
+from graphiti_core.driver.falkordb_driver import FalkorDriver
+from graphiti_core.nodes import EpisodeType
+from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
+
+#################################################
+# CONFIGURATION
+#################################################
+# Set up logging and environment variables for
+# connecting to the FalkorDB database
+#################################################
+
+# Configure logging
+logging.basicConfig(
+    level=INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S',
+)
+logger = logging.getLogger(__name__)
+
+load_dotenv()
+
+# FalkorDB connection parameters
+# Make sure a FalkorDB instance is running; see https://docs.falkordb.com/
+falkor_uri = os.environ.get('FALKORDB_URI', 'falkor://localhost:6379')
+falkor_user = os.environ.get('FALKORDB_USER', '')
+falkor_password = os.environ.get('FALKORDB_PASSWORD', '')
+
+if not falkor_uri:
+    raise ValueError('FALKORDB_URI must be set')
+
+
+async def main():
+    #################################################
+    # INITIALIZATION
+    #################################################
+    # Connect to FalkorDB and set up Graphiti indices
+    # This is required before using other Graphiti
+    # functionality
+    #################################################
+
+    # Initialize Graphiti with an explicit FalkorDB driver (overrides the default Neo4j driver)
+    falkor_driver = FalkorDriver(falkor_uri, falkor_user, falkor_password)
+    graphiti = Graphiti(falkor_uri, graph_driver=falkor_driver)
+
+    try:
+        # Initialize the graph database with graphiti's indices. This only needs to be done once.
+        await graphiti.build_indices_and_constraints()
+
+        #################################################
+        # ADDING EPISODES
+        #################################################
+        # Episodes are the primary units of information
+        # in Graphiti. They can be text or structured JSON
+        # and are automatically processed to extract entities
+        # and relationships.
+        #################################################
+
+        # Example: Add Episodes
+        # Episodes list containing both text and JSON episodes
+        episodes = [
+            {
+                'content': 'Kamala Harris is the Attorney General of California. She was previously '
+                'the district attorney for San Francisco.',
+                'type': EpisodeType.text,
+                'description': 'podcast transcript',
+            },
+            {
+                'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017',
+                'type': EpisodeType.text,
+                'description': 'podcast transcript',
+            },
+            {
+                'content': {
+                    'name': 'Gavin Newsom',
+                    'position': 'Governor',
+                    'state': 'California',
+                    'previous_role': 'Lieutenant Governor',
+                    'previous_location': 'San Francisco',
+                },
+                'type': EpisodeType.json,
+                'description': 'podcast metadata',
+            },
+            {
+                'content': {
+                    'name': 'Gavin Newsom',
+                    'position': 'Governor',
+                    'term_start': 'January 7, 2019',
+                    'term_end': 'Present',
+                },
+                'type': EpisodeType.json,
+                'description': 'podcast metadata',
+            },
+        ]
+
+        # Add episodes to the graph
+        for i, episode in enumerate(episodes):
+            await graphiti.add_episode(
+                name=f'Freakonomics Radio {i}',
+                episode_body=episode['content']
+                if isinstance(episode['content'], str)
+                else json.dumps(episode['content']),
+                source=episode['type'],
+                source_description=episode['description'],
+                reference_time=datetime.now(timezone.utc),
+            )
+            print(f'Added episode: Freakonomics Radio {i} ({episode["type"].value})')
+
+        #################################################
+        # BASIC SEARCH
+        #################################################
+        # The simplest way to retrieve relationships (edges)
+        # from Graphiti is using the search method, which
+        # performs a hybrid search combining semantic
+        # similarity and BM25 text retrieval.
+        #################################################
+
+        # Perform a hybrid search combining semantic similarity and BM25 retrieval
+        print("\nSearching for: 'Who was the California Attorney General?'")
+        results = await graphiti.search('Who was the California Attorney General?')
+
+        # Print search results
+        print('\nSearch Results:')
+        for result in results:
+            print(f'UUID: {result.uuid}')
+            print(f'Fact: {result.fact}')
+            if hasattr(result, 'valid_at') and result.valid_at:
+                print(f'Valid from: {result.valid_at}')
+            if hasattr(result, 'invalid_at') and result.invalid_at:
+                print(f'Valid until: {result.invalid_at}')
+            print('---')
+
+        #################################################
+        # CENTER NODE SEARCH
+        #################################################
+        # For more contextually relevant results, you can
+        # use a center node to rerank search results based
+        # on their graph distance to a specific node
+        #################################################
+
+        # Use the top search result's UUID as the center node for reranking
+        if results and len(results) > 0:
+            # Get the source node UUID from the top result
+            center_node_uuid = results[0].source_node_uuid
+
+            print('\nReranking search results based on graph distance:')
+            print(f'Using center node UUID: {center_node_uuid}')
+
+            reranked_results = await graphiti.search(
+                'Who was the California Attorney General?', center_node_uuid=center_node_uuid
+            )
+
+            # Print reranked search results
+            print('\nReranked Search Results:')
+            for result in reranked_results:
+                print(f'UUID: {result.uuid}')
+                print(f'Fact: {result.fact}')
+                if hasattr(result, 'valid_at') and result.valid_at:
+                    print(f'Valid from: {result.valid_at}')
+                if hasattr(result, 'invalid_at') and result.invalid_at:
+                    print(f'Valid until: {result.invalid_at}')
+                print('---')
+        else:
+            print('No results found in the initial search to use as center node.')
+
+        #################################################
+        # NODE SEARCH USING SEARCH RECIPES
+        #################################################
+        # Graphiti provides predefined search recipes
+        # optimized for different search scenarios.
+        # Here we use NODE_HYBRID_SEARCH_RRF for retrieving
+        # nodes directly instead of edges.
+        #################################################
+
+        # Example: Perform a node search using _search method with standard recipes
+        print(
+            '\nPerforming node search using _search method with standard recipe NODE_HYBRID_SEARCH_RRF:'
+        )
+
+        # Use a predefined search configuration recipe and modify its limit
+        node_search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
+        node_search_config.limit = 5  # Limit to 5 results
+
+        # Execute the node search
+        node_search_results = await graphiti._search(
+            query='California Governor',
+            config=node_search_config,
+        )
+
+        # Print node search results
+        print('\nNode Search Results:')
+        for node in node_search_results.nodes:
+            print(f'Node UUID: {node.uuid}')
+            print(f'Node Name: {node.name}')
+            node_summary = node.summary[:100] + '...' if len(node.summary) > 100 else node.summary
+            print(f'Content Summary: {node_summary}')
+            print(f'Node Labels: {", ".join(node.labels)}')
+            print(f'Created At: {node.created_at}')
+            if hasattr(node, 'attributes') and node.attributes:
+                print('Attributes:')
+                for key, value in node.attributes.items():
+                    print(f'  {key}: {value}')
+            print('---')
+
+    finally:
+        #################################################
+        # CLEANUP
+        #################################################
+        # Always close the connection to FalkorDB when
+        # finished to properly release resources
+        #################################################
+
+        # Close the connection
+        await graphiti.close()
+        print('\nConnection closed')
+
+
+if __name__ == '__main__':
+    asyncio.run(main())
diff --git a/examples/quickstart/quickstart.py b/examples/quickstart/quickstart_neo4j.py
similarity index 100%
rename from examples/quickstart/quickstart.py
rename to examples/quickstart/quickstart_neo4j.py
diff --git a/graphiti_core/driver/__init__.py b/graphiti_core/driver/__init__.py
new file mode 100644
index 00000000..05f19311
--- /dev/null
+++ b/graphiti_core/driver/__init__.py
@@ -0,0 +1,26 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from graphiti_core.driver.driver import GraphDriver
+from graphiti_core.driver.neo4j_driver import Neo4jDriver
+
+try:
+    from graphiti_core.driver.falkordb_driver import FalkorDriver
+except ImportError:
+    # falkordb is an optional dependency
+    FalkorDriver = None
+
+__all__ = ['GraphDriver', 'Neo4jDriver', 'FalkorDriver']
diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py
new file mode 100644
index 00000000..17e3329a
--- /dev/null
+++ b/graphiti_core/driver/driver.py
@@ -0,0 +1,50 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import logging
+from abc import ABC, abstractmethod
+from collections.abc import Coroutine
+from typing import Any
+
+from graphiti_core.helpers import DEFAULT_DATABASE
+
+logger = logging.getLogger(__name__)
+
+
+class GraphDriverSession(ABC):
+    @abstractmethod
+    async def run(self, query: str, **kwargs: Any) -> Any:
+        raise NotImplementedError()
+
+
+class GraphDriver(ABC):
+    provider: str
+
+    @abstractmethod
+    def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def session(self, database: str) -> GraphDriverSession:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def close(self) -> None:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
+        raise NotImplementedError()
diff --git a/graphiti_core/driver/falkordb_driver.py b/graphiti_core/driver/falkordb_driver.py
new file mode 100644
index 00000000..6dbd1a0a
--- /dev/null
+++ b/graphiti_core/driver/falkordb_driver.py
@@ -0,0 +1,132 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+""" + +import logging +from collections.abc import Coroutine +from datetime import datetime +from typing import Any + +from falkordb import Graph as FalkorGraph +from falkordb.asyncio import FalkorDB + +from graphiti_core.driver.driver import GraphDriver, GraphDriverSession +from graphiti_core.helpers import DEFAULT_DATABASE + +logger = logging.getLogger(__name__) + + +class FalkorClientSession(GraphDriverSession): + def __init__(self, graph: FalkorGraph): + self.graph = graph + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + # No cleanup needed for Falkor, but method must exist + pass + + async def close(self): + # No explicit close needed for FalkorDB, but method must exist + pass + + async def execute_write(self, func, *args, **kwargs): + # Directly await the provided async function with `self` as the transaction/session + return await func(self, *args, **kwargs) + + async def run(self, cypher_query_: str | list, **kwargs: Any) -> Any: + # FalkorDB does not support argument for Label Set, so it's converted into an array of queries + if isinstance(cypher_query_, list): + for cypher, params in cypher_query_: + params = convert_datetimes_to_strings(params) + await self.graph.query(str(cypher), params) + else: + params = dict(kwargs) + params = convert_datetimes_to_strings(params) + await self.graph.query(str(cypher_query_), params) + # Assuming `graph.query` is async (ideal); otherwise, wrap in executor + return None + + +class FalkorDriver(GraphDriver): + provider: str = 'falkordb' + + def __init__( + self, + uri: str, + user: str, + password: str, + ): + super().__init__() + if user and password: + uri_parts = uri.split('://', 1) + uri = f'{uri_parts[0]}://{user}:{password}@{uri_parts[1]}' + + self.client = FalkorDB.from_url( + url=uri, + ) + + def _get_graph(self, graph_name: str) -> FalkorGraph: + # FalkorDB requires a non-None database name for multi-tenant graphs; the default is "DEFAULT_DATABASE" + if graph_name is None: + graph_name = 'DEFAULT_DATABASE' + return self.client.select_graph(graph_name) + + async def execute_query(self, cypher_query_, **kwargs: Any): + graph_name = kwargs.pop('database_', DEFAULT_DATABASE) + graph = self._get_graph(graph_name) + + # Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly) + params = convert_datetimes_to_strings(dict(kwargs)) + + try: + result = await graph.query(cypher_query_, params) + except Exception as e: + if 'already indexed' in str(e): + # check if index already exists + logger.info(f'Index already exists: {e}') + return None + logger.error(f'Error executing FalkorDB query: {e}') + raise + + # Convert the result header to a list of strings + header = [h[1].decode('utf-8') for h in result.header] + return result.result_set, header, None + + def session(self, database: str) -> GraphDriverSession: + return FalkorClientSession(self._get_graph(database)) + + async def close(self) -> None: + await self.client.connection.close() + + def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine: + return self.execute_query( + 'CALL db.indexes() YIELD name DROP INDEX name', + database_=database_, + ) + + +def convert_datetimes_to_strings(obj): + if isinstance(obj, dict): + return {k: convert_datetimes_to_strings(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [convert_datetimes_to_strings(item) for item in obj] + elif isinstance(obj, tuple): + return tuple(convert_datetimes_to_strings(item) for item in obj) + elif 
isinstance(obj, datetime): + return obj.isoformat() + else: + return obj diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py new file mode 100644 index 00000000..cf31420a --- /dev/null +++ b/graphiti_core/driver/neo4j_driver.py @@ -0,0 +1,60 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import logging +from collections.abc import Coroutine +from typing import Any, LiteralString + +from neo4j import AsyncGraphDatabase + +from graphiti_core.driver.driver import GraphDriver, GraphDriverSession +from graphiti_core.helpers import DEFAULT_DATABASE + +logger = logging.getLogger(__name__) + + +class Neo4jDriver(GraphDriver): + provider: str = 'neo4j' + + def __init__( + self, + uri: str, + user: str, + password: str, + ): + super().__init__() + self.client = AsyncGraphDatabase.driver( + uri=uri, + auth=(user, password), + ) + + async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> Coroutine: + params = kwargs.pop('params', None) + result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs) + + return result + + def session(self, database: str) -> GraphDriverSession: + return self.client.session(database=database) # type: ignore + + async def close(self) -> None: + return await self.client.close() + + def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine: + return self.client.execute_query( + 'CALL db.indexes() YIELD name DROP INDEX name', + database_=database_, + ) diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index 700775f3..f2491e99 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -21,10 +21,10 @@ from time import time from typing import Any from uuid import uuid4 -from neo4j import AsyncDriver from pydantic import BaseModel, Field from typing_extensions import LiteralString +from graphiti_core.driver.driver import GraphDriver from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date @@ -62,9 +62,9 @@ class Edge(BaseModel, ABC): created_at: datetime @abstractmethod - async def save(self, driver: AsyncDriver): ... + async def save(self, driver: GraphDriver): ... - async def delete(self, driver: AsyncDriver): + async def delete(self, driver: GraphDriver): result = await driver.execute_query( """ MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m) @@ -87,11 +87,11 @@ class Edge(BaseModel, ABC): return False @classmethod - async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ... + async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ... 
class EpisodicEdge(Edge): - async def save(self, driver: AsyncDriver): + async def save(self, driver: GraphDriver): result = await driver.execute_query( EPISODIC_EDGE_SAVE, episode_uuid=self.source_node_uuid, @@ -102,12 +102,12 @@ class EpisodicEdge(Edge): database_=DEFAULT_DATABASE, ) - logger.debug(f'Saved edge to neo4j: {self.uuid}') + logger.debug(f'Saved edge to Graph: {self.uuid}') return result @classmethod - async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): + async def get_by_uuid(cls, driver: GraphDriver, uuid: str): records, _, _ = await driver.execute_query( """ MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity) @@ -130,7 +130,7 @@ class EpisodicEdge(Edge): return edges[0] @classmethod - async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): + async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): records, _, _ = await driver.execute_query( """ MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity) @@ -156,7 +156,7 @@ class EpisodicEdge(Edge): @classmethod async def get_by_group_ids( cls, - driver: AsyncDriver, + driver: GraphDriver, group_ids: list[str], limit: int | None = None, uuid_cursor: str | None = None, @@ -226,7 +226,7 @@ class EntityEdge(Edge): return self.fact_embedding - async def load_fact_embedding(self, driver: AsyncDriver): + async def load_fact_embedding(self, driver: GraphDriver): query: LiteralString = """ MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity) RETURN e.fact_embedding AS fact_embedding @@ -240,7 +240,7 @@ class EntityEdge(Edge): self.fact_embedding = records[0]['fact_embedding'] - async def save(self, driver: AsyncDriver): + async def save(self, driver: GraphDriver): edge_data: dict[str, Any] = { 'source_uuid': self.source_node_uuid, 'target_uuid': self.target_node_uuid, @@ -264,12 +264,12 @@ class EntityEdge(Edge): database_=DEFAULT_DATABASE, ) - logger.debug(f'Saved edge to neo4j: {self.uuid}') + logger.debug(f'Saved edge to Graph: {self.uuid}') return result @classmethod - async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): + async def get_by_uuid(cls, driver: GraphDriver, uuid: str): records, _, _ = await driver.execute_query( """ MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity) @@ -287,7 +287,7 @@ class EntityEdge(Edge): return edges[0] @classmethod - async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): + async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): if len(uuids) == 0: return [] @@ -309,7 +309,7 @@ class EntityEdge(Edge): @classmethod async def get_by_group_ids( cls, - driver: AsyncDriver, + driver: GraphDriver, group_ids: list[str], limit: int | None = None, uuid_cursor: str | None = None, @@ -342,11 +342,11 @@ class EntityEdge(Edge): return edges @classmethod - async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str): + async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str): query: LiteralString = ( """ - MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity) - """ + MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity) + """ + ENTITY_EDGE_RETURN ) records, _, _ = await driver.execute_query( @@ -359,7 +359,7 @@ class EntityEdge(Edge): class CommunityEdge(Edge): - async def save(self, driver: AsyncDriver): + async def save(self, driver: GraphDriver): result = await driver.execute_query( COMMUNITY_EDGE_SAVE, community_uuid=self.source_node_uuid, @@ -370,12 +370,12 @@ class CommunityEdge(Edge): database_=DEFAULT_DATABASE, ) - logger.debug(f'Saved edge to neo4j: {self.uuid}') + logger.debug(f'Saved edge to 
Graph: {self.uuid}') return result @classmethod - async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): + async def get_by_uuid(cls, driver: GraphDriver, uuid: str): records, _, _ = await driver.execute_query( """ MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community) @@ -396,7 +396,7 @@ class CommunityEdge(Edge): return edges[0] @classmethod - async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): + async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): records, _, _ = await driver.execute_query( """ MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community) @@ -420,7 +420,7 @@ class CommunityEdge(Edge): @classmethod async def get_by_group_ids( cls, - driver: AsyncDriver, + driver: GraphDriver, group_ids: list[str], limit: int | None = None, uuid_cursor: str | None = None, @@ -463,7 +463,7 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge: group_id=record['group_id'], source_node_uuid=record['source_node_uuid'], target_node_uuid=record['target_node_uuid'], - created_at=record['created_at'].to_native(), + created_at=parse_db_date(record['created_at']), ) @@ -476,7 +476,7 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge: name=record['name'], group_id=record['group_id'], episodes=record['episodes'], - created_at=record['created_at'].to_native(), + created_at=parse_db_date(record['created_at']), expired_at=parse_db_date(record['expired_at']), valid_at=parse_db_date(record['valid_at']), invalid_at=parse_db_date(record['invalid_at']), @@ -504,7 +504,7 @@ def get_community_edge_from_record(record: Any): group_id=record['group_id'], source_node_uuid=record['source_node_uuid'], target_node_uuid=record['target_node_uuid'], - created_at=record['created_at'].to_native(), + created_at=parse_db_date(record['created_at']), ) diff --git a/graphiti_core/graph_queries.py b/graphiti_core/graph_queries.py new file mode 100644 index 00000000..885ef669 --- /dev/null +++ b/graphiti_core/graph_queries.py @@ -0,0 +1,147 @@ +""" +Database query utilities for different graph database backends. + +This module provides database-agnostic query generation for Neo4j and FalkorDB, +supporting index creation, fulltext search, and bulk operations. 
+""" + +from typing_extensions import LiteralString + +from graphiti_core.models.edges.edge_db_queries import ( + ENTITY_EDGE_SAVE_BULK, +) +from graphiti_core.models.nodes.node_db_queries import ( + ENTITY_NODE_SAVE_BULK, +) + +# Mapping from Neo4j fulltext index names to FalkorDB node labels +NEO4J_TO_FALKORDB_MAPPING = { + 'node_name_and_summary': 'Entity', + 'community_name': 'Community', + 'episode_content': 'Episodic', + 'edge_name_and_fact': 'RELATES_TO', +} + + +def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]: + if db_type == 'falkordb': + return [ + # Entity node + 'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)', + # Episodic node + 'CREATE INDEX FOR (n:Episodic) ON (n.uuid, n.group_id, n.created_at, n.valid_at)', + # Community node + 'CREATE INDEX FOR (n:Community) ON (n.uuid)', + # RELATES_TO edge + 'CREATE INDEX FOR ()-[e:RELATES_TO]-() ON (e.uuid, e.group_id, e.name, e.created_at, e.expired_at, e.valid_at, e.invalid_at)', + # MENTIONS edge + 'CREATE INDEX FOR ()-[e:MENTIONS]-() ON (e.uuid, e.group_id)', + # HAS_MEMBER edge + 'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)', + ] + else: + return [ + 'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)', + 'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)', + 'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)', + 'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)', + 'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)', + 'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)', + 'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)', + 'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)', + 'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)', + 'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)', + 'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)', + 'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)', + 'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)', + 'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)', + 'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)', + 'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)', + 'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)', + 'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)', + 'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)', + ] + + +def get_fulltext_indices(db_type: str = 'neo4j') -> list[LiteralString]: + if db_type == 'falkordb': + return [ + """CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""", + """CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""", + """CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""", + """CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""", + ] + else: + return [ + """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS + FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""", + """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS + FOR (n:Entity) ON EACH [n.name, n.summary, 
n.group_id]""", + """CREATE FULLTEXT INDEX community_name IF NOT EXISTS + FOR (n:Community) ON EACH [n.name, n.group_id]""", + """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS + FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""", + ] + + +def get_nodes_query(db_type: str = 'neo4j', name: str = None, query: str = None) -> str: + if db_type == 'falkordb': + label = NEO4J_TO_FALKORDB_MAPPING[name] + return f"CALL db.idx.fulltext.queryNodes('{label}', {query})" + else: + return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})' + + +def get_vector_cosine_func_query(vec1, vec2, db_type: str = 'neo4j') -> str: + if db_type == 'falkordb': + # FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity + return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2' + else: + return f'vector.similarity.cosine({vec1}, {vec2})' + + +def get_relationships_query(db_type: str = 'neo4j', name: str = None, query: str = None) -> str: + if db_type == 'falkordb': + label = NEO4J_TO_FALKORDB_MAPPING[name] + return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)" + else: + return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})' + + +def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str: + if db_type == 'falkordb': + queries = [] + for node in nodes: + for label in node['labels']: + queries.append( + ( + f""" + UNWIND $nodes AS node + MERGE (n:Entity {{uuid: node.uuid}}) + SET n:{label} + SET n = node + WITH n, node + SET n.name_embedding = vecf32(node.name_embedding) + RETURN n.uuid AS uuid + """, + {'nodes': [node]}, + ) + ) + return queries + else: + return ENTITY_NODE_SAVE_BULK + + +def get_entity_edge_save_bulk_query(db_type: str = 'neo4j') -> str: + if db_type == 'falkordb': + return """ + UNWIND $entity_edges AS edge + MATCH (source:Entity {uuid: edge.source_node_uuid}) + MATCH (target:Entity {uuid: edge.target_node_uuid}) + MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target) + SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes, + created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)} + WITH r, edge + RETURN edge.uuid AS uuid""" + else: + return ENTITY_EDGE_SAVE_BULK diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 930f66d4..5410283e 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -19,12 +19,13 @@ from datetime import datetime from time import time from dotenv import load_dotenv -from neo4j import AsyncGraphDatabase from pydantic import BaseModel from typing_extensions import LiteralString from graphiti_core.cross_encoder.client import CrossEncoderClient from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient +from graphiti_core.driver.driver import GraphDriver +from graphiti_core.driver.neo4j_driver import Neo4jDriver from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder from graphiti_core.graphiti_types import GraphitiClients @@ -94,12 +95,13 @@ class Graphiti: def __init__( self, uri: str, - user: str, - password: str, + user: str = None, + password: str = None, llm_client: LLMClient | None = None, embedder: EmbedderClient | None = None, cross_encoder: CrossEncoderClient | None = None, store_raw_episode_content: bool = True, + 
graph_driver: GraphDriver = None, ): """ Initialize a Graphiti instance. @@ -137,7 +139,9 @@ class Graphiti: Make sure to set the OPENAI_API_KEY environment variable before initializing Graphiti if you're using the default OpenAIClient. """ - self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password)) + + self.driver = graph_driver if graph_driver else Neo4jDriver(uri, user, password) + self.database = DEFAULT_DATABASE self.store_raw_episode_content = store_raw_episode_content if llm_client: diff --git a/graphiti_core/graphiti_types.py b/graphiti_core/graphiti_types.py index c765ee63..decdf027 100644 --- a/graphiti_core/graphiti_types.py +++ b/graphiti_core/graphiti_types.py @@ -14,16 +14,16 @@ See the License for the specific language governing permissions and limitations under the License. """ -from neo4j import AsyncDriver from pydantic import BaseModel, ConfigDict from graphiti_core.cross_encoder import CrossEncoderClient +from graphiti_core.driver.driver import GraphDriver from graphiti_core.embedder import EmbedderClient from graphiti_core.llm_client import LLMClient class GraphitiClients(BaseModel): - driver: AsyncDriver + driver: GraphDriver llm_client: LLMClient embedder: EmbedderClient cross_encoder: CrossEncoderClient diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py index 21c388c2..c7d8cd40 100644 --- a/graphiti_core/helpers.py +++ b/graphiti_core/helpers.py @@ -38,8 +38,14 @@ RUNTIME_QUERY: LiteralString = ( ) -def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None: - return neo_date.to_native() if neo_date else None +def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None: + return ( + neo_date.to_native() + if isinstance(neo_date, neo4j_time.DateTime) + else datetime.fromisoformat(neo_date) + if neo_date + else None + ) def lucene_sanitize(query: str) -> str: diff --git a/graphiti_core/llm_client/__init__.py b/graphiti_core/llm_client/__init__.py index 1472aa6b..376bf33a 100644 --- a/graphiti_core/llm_client/__init__.py +++ b/graphiti_core/llm_client/__init__.py @@ -1,3 +1,19 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + from .client import LLMClient from .config import LLMConfig from .errors import RateLimitError diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 58ec41ce..945a07b2 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -22,13 +22,13 @@ from time import time from typing import Any from uuid import uuid4 -from neo4j import AsyncDriver from pydantic import BaseModel, Field from typing_extensions import LiteralString +from graphiti_core.driver.driver import GraphDriver from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import NodeNotFoundError -from graphiti_core.helpers import DEFAULT_DATABASE +from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date from graphiti_core.models.nodes.node_db_queries import ( COMMUNITY_NODE_SAVE, ENTITY_NODE_SAVE, @@ -94,9 +94,9 @@ class Node(BaseModel, ABC): created_at: datetime = Field(default_factory=lambda: utc_now()) @abstractmethod - async def save(self, driver: AsyncDriver): ... + async def save(self, driver: GraphDriver): ... - async def delete(self, driver: AsyncDriver): + async def delete(self, driver: GraphDriver): result = await driver.execute_query( """ MATCH (n:Entity|Episodic|Community {uuid: $uuid}) @@ -119,7 +119,7 @@ class Node(BaseModel, ABC): return False @classmethod - async def delete_by_group_id(cls, driver: AsyncDriver, group_id: str): + async def delete_by_group_id(cls, driver: GraphDriver, group_id: str): await driver.execute_query( """ MATCH (n:Entity|Episodic|Community {group_id: $group_id}) @@ -132,10 +132,10 @@ class Node(BaseModel, ABC): return 'SUCCESS' @classmethod - async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ... + async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ... @classmethod - async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): ... + async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ... 
class EpisodicNode(Node): @@ -150,7 +150,7 @@ class EpisodicNode(Node): default_factory=list, ) - async def save(self, driver: AsyncDriver): + async def save(self, driver: GraphDriver): result = await driver.execute_query( EPISODIC_NODE_SAVE, uuid=self.uuid, @@ -165,12 +165,12 @@ class EpisodicNode(Node): database_=DEFAULT_DATABASE, ) - logger.debug(f'Saved Node to neo4j: {self.uuid}') + logger.debug(f'Saved Node to Graph: {self.uuid}') return result @classmethod - async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): + async def get_by_uuid(cls, driver: GraphDriver, uuid: str): records, _, _ = await driver.execute_query( """ MATCH (e:Episodic {uuid: $uuid}) @@ -197,7 +197,7 @@ class EpisodicNode(Node): return episodes[0] @classmethod - async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): + async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): records, _, _ = await driver.execute_query( """ MATCH (e:Episodic) WHERE e.uuid IN $uuids @@ -224,7 +224,7 @@ class EpisodicNode(Node): @classmethod async def get_by_group_ids( cls, - driver: AsyncDriver, + driver: GraphDriver, group_ids: list[str], limit: int | None = None, uuid_cursor: str | None = None, @@ -263,7 +263,7 @@ class EpisodicNode(Node): return episodes @classmethod - async def get_by_entity_node_uuid(cls, driver: AsyncDriver, entity_node_uuid: str): + async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str): records, _, _ = await driver.execute_query( """ MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid}) @@ -304,7 +304,7 @@ class EntityNode(Node): return self.name_embedding - async def load_name_embedding(self, driver: AsyncDriver): + async def load_name_embedding(self, driver: GraphDriver): query: LiteralString = """ MATCH (n:Entity {uuid: $uuid}) RETURN n.name_embedding AS name_embedding @@ -318,7 +318,7 @@ class EntityNode(Node): self.name_embedding = records[0]['name_embedding'] - async def save(self, driver: AsyncDriver): + async def save(self, driver: GraphDriver): entity_data: dict[str, Any] = { 'uuid': self.uuid, 'name': self.name, @@ -337,16 +337,16 @@ class EntityNode(Node): database_=DEFAULT_DATABASE, ) - logger.debug(f'Saved Node to neo4j: {self.uuid}') + logger.debug(f'Saved Node to Graph: {self.uuid}') return result @classmethod - async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): + async def get_by_uuid(cls, driver: GraphDriver, uuid: str): query = ( """ - MATCH (n:Entity {uuid: $uuid}) - """ + MATCH (n:Entity {uuid: $uuid}) + """ + ENTITY_NODE_RETURN ) records, _, _ = await driver.execute_query( @@ -364,7 +364,7 @@ class EntityNode(Node): return nodes[0] @classmethod - async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): + async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): records, _, _ = await driver.execute_query( """ MATCH (n:Entity) WHERE n.uuid IN $uuids @@ -382,7 +382,7 @@ class EntityNode(Node): @classmethod async def get_by_group_ids( cls, - driver: AsyncDriver, + driver: GraphDriver, group_ids: list[str], limit: int | None = None, uuid_cursor: str | None = None, @@ -416,7 +416,7 @@ class CommunityNode(Node): name_embedding: list[float] | None = Field(default=None, description='embedding of the name') summary: str = Field(description='region summary of member nodes', default_factory=str) - async def save(self, driver: AsyncDriver): + async def save(self, driver: GraphDriver): result = await driver.execute_query( COMMUNITY_NODE_SAVE, uuid=self.uuid, @@ -428,7 +428,7 @@ class 
CommunityNode(Node): database_=DEFAULT_DATABASE, ) - logger.debug(f'Saved Node to neo4j: {self.uuid}') + logger.debug(f'Saved Node to Graph: {self.uuid}') return result @@ -441,7 +441,7 @@ class CommunityNode(Node): return self.name_embedding - async def load_name_embedding(self, driver: AsyncDriver): + async def load_name_embedding(self, driver: GraphDriver): query: LiteralString = """ MATCH (c:Community {uuid: $uuid}) RETURN c.name_embedding AS name_embedding @@ -456,7 +456,7 @@ class CommunityNode(Node): self.name_embedding = records[0]['name_embedding'] @classmethod - async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): + async def get_by_uuid(cls, driver: GraphDriver, uuid: str): records, _, _ = await driver.execute_query( """ MATCH (n:Community {uuid: $uuid}) @@ -480,7 +480,7 @@ class CommunityNode(Node): return nodes[0] @classmethod - async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): + async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): records, _, _ = await driver.execute_query( """ MATCH (n:Community) WHERE n.uuid IN $uuids @@ -503,7 +503,7 @@ class CommunityNode(Node): @classmethod async def get_by_group_ids( cls, - driver: AsyncDriver, + driver: GraphDriver, group_ids: list[str], limit: int | None = None, uuid_cursor: str | None = None, @@ -542,8 +542,8 @@ class CommunityNode(Node): def get_episodic_node_from_record(record: Any) -> EpisodicNode: return EpisodicNode( content=record['content'], - created_at=record['created_at'].to_native().timestamp(), - valid_at=(record['valid_at'].to_native()), + created_at=parse_db_date(record['created_at']).timestamp(), + valid_at=(parse_db_date(record['valid_at'])), uuid=record['uuid'], group_id=record['group_id'], source=EpisodeType.from_str(record['source']), @@ -559,7 +559,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode: name=record['name'], group_id=record['group_id'], labels=record['labels'], - created_at=record['created_at'].to_native(), + created_at=parse_db_date(record['created_at']), summary=record['summary'], attributes=record['attributes'], ) @@ -580,7 +580,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode: name=record['name'], group_id=record['group_id'], name_embedding=record['name_embedding'], - created_at=record['created_at'].to_native(), + created_at=parse_db_date(record['created_at']), summary=record['summary'], ) diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index 1ed26420..f68df779 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -18,9 +18,8 @@ import logging from collections import defaultdict from time import time -from neo4j import AsyncDriver - from graphiti_core.cross_encoder.client import CrossEncoderClient +from graphiti_core.driver.driver import GraphDriver from graphiti_core.edges import EntityEdge from graphiti_core.errors import SearchRerankerError from graphiti_core.graphiti_types import GraphitiClients @@ -94,7 +93,7 @@ async def search( ) # if group_ids is empty, set it to None - group_ids = group_ids if group_ids else None + group_ids = group_ids if group_ids and group_ids != [''] else None edges, nodes, episodes, communities = await semaphore_gather( edge_search( driver, @@ -160,7 +159,7 @@ async def search( async def edge_search( - driver: AsyncDriver, + driver: GraphDriver, cross_encoder: CrossEncoderClient, query: str, query_vector: list[float], @@ -174,7 +173,6 @@ async def edge_search( ) -> list[EntityEdge]: if config is None: return [] - search_results: 
list[list[EntityEdge]] = list( await semaphore_gather( *[ @@ -261,7 +259,7 @@ async def edge_search( async def node_search( - driver: AsyncDriver, + driver: GraphDriver, cross_encoder: CrossEncoderClient, query: str, query_vector: list[float], @@ -275,7 +273,6 @@ async def node_search( ) -> list[EntityNode]: if config is None: return [] - search_results: list[list[EntityNode]] = list( await semaphore_gather( *[ @@ -344,7 +341,7 @@ async def node_search( async def episode_search( - driver: AsyncDriver, + driver: GraphDriver, cross_encoder: CrossEncoderClient, query: str, _query_vector: list[float], @@ -356,7 +353,6 @@ async def episode_search( ) -> list[EpisodicNode]: if config is None: return [] - search_results: list[list[EpisodicNode]] = list( await semaphore_gather( *[ @@ -392,7 +388,7 @@ async def episode_search( async def community_search( - driver: AsyncDriver, + driver: GraphDriver, cross_encoder: CrossEncoderClient, query: str, query_vector: list[float], diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 707704aa..6ef2cd32 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -20,11 +20,16 @@ from time import time from typing import Any import numpy as np -from neo4j import AsyncDriver, Query from numpy._typing import NDArray from typing_extensions import LiteralString +from graphiti_core.driver.driver import GraphDriver from graphiti_core.edges import EntityEdge, get_entity_edge_from_record +from graphiti_core.graph_queries import ( + get_nodes_query, + get_relationships_query, + get_vector_cosine_func_query, +) from graphiti_core.helpers import ( DEFAULT_DATABASE, RUNTIME_QUERY, @@ -58,7 +63,7 @@ MAX_QUERY_LENGTH = 32 def fulltext_query(query: str, group_ids: list[str] | None = None): group_ids_filter_list = ( - [f'group_id:"{lucene_sanitize(g)}"' for g in group_ids] if group_ids is not None else [] + [f"group_id-'{lucene_sanitize(g)}'" for g in group_ids] if group_ids is not None else [] ) group_ids_filter = '' for f in group_ids_filter_list: @@ -77,7 +82,7 @@ def fulltext_query(query: str, group_ids: list[str] | None = None): async def get_episodes_by_mentions( - driver: AsyncDriver, + driver: GraphDriver, nodes: list[EntityNode], edges: list[EntityEdge], limit: int = RELEVANT_SCHEMA_LIMIT, @@ -92,11 +97,11 @@ async def get_episodes_by_mentions( async def get_mentioned_nodes( - driver: AsyncDriver, episodes: list[EpisodicNode] + driver: GraphDriver, episodes: list[EpisodicNode] ) -> list[EntityNode]: episode_uuids = [episode.uuid for episode in episodes] - records, _, _ = await driver.execute_query( - """ + + query = """ MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids RETURN DISTINCT n.uuid As uuid, @@ -106,7 +111,10 @@ async def get_mentioned_nodes( n.summary AS summary, labels(n) AS labels, properties(n) AS attributes - """, + """ + + records, _, _ = await driver.execute_query( + query, uuids=episode_uuids, database_=DEFAULT_DATABASE, routing_='r', @@ -118,11 +126,11 @@ async def get_mentioned_nodes( async def get_communities_by_nodes( - driver: AsyncDriver, nodes: list[EntityNode] + driver: GraphDriver, nodes: list[EntityNode] ) -> list[CommunityNode]: node_uuids = [node.uuid for node in nodes] - records, _, _ = await driver.execute_query( - """ + + query = """ MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids RETURN DISTINCT c.uuid As uuid, @@ -130,7 +138,10 @@ async def get_communities_by_nodes( c.name AS name, c.created_at AS created_at, 
            c.summary AS summary
-        """,
+        """
+
+    records, _, _ = await driver.execute_query(
+        query,
         uuids=node_uuids,
         database_=DEFAULT_DATABASE,
         routing_='r',
@@ -142,7 +153,7 @@ async def get_communities_by_nodes(


 async def edge_fulltext_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     query: str,
     search_filter: SearchFilters,
     group_ids: list[str] | None = None,
@@ -155,34 +166,35 @@ async def edge_fulltext_search(
     filter_query, filter_params = edge_search_filter_query_constructor(search_filter)

-    cypher_query = Query(
-        """
-        CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query, {limit: $limit})
-        YIELD relationship AS rel, score
-        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
-        WHERE r.group_id IN $group_ids"""
+    query = (
+        get_relationships_query(driver.provider, 'edge_name_and_fact', '$query')
+        + """
+        YIELD relationship AS rel, score
+        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
+        WHERE r.group_id IN $group_ids
+        """
         + filter_query
-        + """\nWITH r, score, startNode(r) AS n, endNode(r) AS m
-        RETURN
-            r.uuid AS uuid,
-            r.group_id AS group_id,
-            n.uuid AS source_node_uuid,
-            m.uuid AS target_node_uuid,
-            r.created_at AS created_at,
-            r.name AS name,
-            r.fact AS fact,
-            r.episodes AS episodes,
-            r.expired_at AS expired_at,
-            r.valid_at AS valid_at,
-            r.invalid_at AS invalid_at,
-            properties(r) AS attributes
-        ORDER BY score DESC LIMIT $limit
-        """
+        + """
+        WITH r, score, startNode(r) AS n, endNode(r) AS m
+        RETURN
+            r.uuid AS uuid,
+            r.group_id AS group_id,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at,
+            properties(r) AS attributes
+        ORDER BY score DESC LIMIT $limit
+        """
     )

     records, _, _ = await driver.execute_query(
-        cypher_query,
-        filter_params,
+        query,
+        params=filter_params,
         query=fuzzy_query,
         group_ids=group_ids,
         limit=limit,
@@ -196,7 +208,7 @@


 async def edge_similarity_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     search_vector: list[float],
     source_node_uuid: str | None,
     target_node_uuid: str | None,
@@ -224,36 +236,38 @@ async def edge_similarity_search(
     if target_node_uuid is not None:
         group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'

-    query: LiteralString = (
+    query = (
         RUNTIME_QUERY
         + """
         MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
-        """
+        """
         + group_filter_query
         + filter_query
-        + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
-        WHERE score > $min_score
-        RETURN
-            r.uuid AS uuid,
-            r.group_id AS group_id,
-            startNode(r).uuid AS source_node_uuid,
-            endNode(r).uuid AS target_node_uuid,
-            r.created_at AS created_at,
-            r.name AS name,
-            r.fact AS fact,
-            r.episodes AS episodes,
-            r.expired_at AS expired_at,
-            r.valid_at AS valid_at,
-            r.invalid_at AS invalid_at,
-            properties(r) AS attributes
-        ORDER BY score DESC
-        LIMIT $limit
+        + """
+        WITH DISTINCT r, """
+        + get_vector_cosine_func_query('r.fact_embedding', '$search_vector', driver.provider)
+        + """ AS score
+        WHERE score > $min_score
+        RETURN
+            r.uuid AS uuid,
+            r.group_id AS group_id,
+            startNode(r).uuid AS source_node_uuid,
+            endNode(r).uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at,
+            properties(r) AS attributes
+        ORDER BY score DESC
+        LIMIT $limit
         """
     )
-
-    records, _, _ = await driver.execute_query(
+    records, header, _ = await driver.execute_query(
         query,
-        query_params,
+        params=query_params,
         search_vector=search_vector,
         source_uuid=source_node_uuid,
         target_uuid=target_node_uuid,
@@ -264,13 +278,16 @@ async def edge_similarity_search(
         routing_='r',
     )

+    if driver.provider == 'falkordb':
+        records = [dict(zip(header, row, strict=True)) for row in records]
+
     edges = [get_entity_edge_from_record(record) for record in records]

     return edges


 async def edge_bfs_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     bfs_origin_node_uuids: list[str] | None,
     bfs_max_depth: int,
     search_filter: SearchFilters,
@@ -282,14 +299,14 @@ async def edge_bfs_search(

     filter_query, filter_params = edge_search_filter_query_constructor(search_filter)

-    query = Query(
+    query = (
         """
-        UNWIND $bfs_origin_node_uuids AS origin_uuid
-        MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
-        UNWIND relationships(path) AS rel
-        MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
-        WHERE r.uuid = rel.uuid
-        """
+        UNWIND $bfs_origin_node_uuids AS origin_uuid
+        MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
+        UNWIND relationships(path) AS rel
+        MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
+        WHERE r.uuid = rel.uuid
+        """
         + filter_query
         + """
         RETURN DISTINCT
@@ -311,7 +328,7 @@ async def edge_bfs_search(

     records, _, _ = await driver.execute_query(
         query,
-        filter_params,
+        params=filter_params,
         bfs_origin_node_uuids=bfs_origin_node_uuids,
         depth=bfs_max_depth,
         limit=limit,
@@ -325,7 +342,7 @@


 async def node_fulltext_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     query: str,
     search_filter: SearchFilters,
     group_ids: list[str] | None = None,
@@ -335,38 +352,41 @@ async def node_fulltext_search(
     fuzzy_query = fulltext_query(query, group_ids)
     if fuzzy_query == '':
         return []
-
     filter_query, filter_params = node_search_filter_query_constructor(search_filter)

     query = (
+        get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
+        + """
+        YIELD node AS n, score
+        WITH n, score
+        LIMIT $limit
+        WHERE n:Entity
         """
-        CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
-        YIELD node AS n, score
-        WHERE n:Entity
-        """
         + filter_query
         + ENTITY_NODE_RETURN
         + """
         ORDER BY score DESC
         """
     )
-
-    records, _, _ = await driver.execute_query(
+    records, header, _ = await driver.execute_query(
         query,
-        filter_params,
+        params=filter_params,
         query=fuzzy_query,
         group_ids=group_ids,
         limit=limit,
         database_=DEFAULT_DATABASE,
         routing_='r',
     )
+    if driver.provider == 'falkordb':
+        records = [dict(zip(header, row, strict=True)) for row in records]
+
     nodes = [get_entity_node_from_record(record) for record in records]

     return nodes


 async def node_similarity_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     search_vector: list[float],
     search_filter: SearchFilters,
     group_ids: list[str] | None = None,
@@ -384,22 +404,28 @@ async def node_similarity_search(
     filter_query, filter_params = node_search_filter_query_constructor(search_filter)
     query_params.update(filter_params)

-    records, _, _ = await driver.execute_query(
+    query = (
         RUNTIME_QUERY
         + """
-        MATCH (n:Entity)
-        """
+        MATCH (n:Entity)
+        """
         + group_filter_query
         + filter_query
         + """
-        WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
-        WHERE score > $min_score"""
+        WITH n, """
+        + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
+        + """ AS score
+        WHERE score > $min_score"""
         + ENTITY_NODE_RETURN
         + """
         ORDER BY score DESC
         LIMIT $limit
-        """,
-        query_params,
+        """
+    )
+
+    records, header, _ = await driver.execute_query(
+        query,
+        params=query_params,
         search_vector=search_vector,
         group_ids=group_ids,
         limit=limit,
@@ -407,13 +433,15 @@ async def node_similarity_search(
         database_=DEFAULT_DATABASE,
         routing_='r',
     )
+    if driver.provider == 'falkordb':
+        records = [dict(zip(header, row, strict=True)) for row in records]

     nodes = [get_entity_node_from_record(record) for record in records]

     return nodes


 async def node_bfs_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     bfs_origin_node_uuids: list[str] | None,
     search_filter: SearchFilters,
     bfs_max_depth: int,
@@ -425,18 +453,21 @@ async def node_bfs_search(

     filter_query, filter_params = node_search_filter_query_constructor(search_filter)

-    records, _, _ = await driver.execute_query(
+    query = (
         """
-        UNWIND $bfs_origin_node_uuids AS origin_uuid
-        MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
-        WHERE n.group_id = origin.group_id
-        """
+        UNWIND $bfs_origin_node_uuids AS origin_uuid
+        MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
+        WHERE n.group_id = origin.group_id
+        """
         + filter_query
         + ENTITY_NODE_RETURN
         + """
         LIMIT $limit
-        """,
-        filter_params,
+        """
+    )
+    records, _, _ = await driver.execute_query(
+        query,
+        params=filter_params,
         bfs_origin_node_uuids=bfs_origin_node_uuids,
         depth=bfs_max_depth,
         limit=limit,
@@ -449,7 +480,7 @@


 async def episode_fulltext_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     query: str,
     _search_filter: SearchFilters,
     group_ids: list[str] | None = None,
@@ -460,9 +491,9 @@ async def episode_fulltext_search(
     if fuzzy_query == '':
         return []

-    records, _, _ = await driver.execute_query(
-        """
-        CALL db.index.fulltext.queryNodes("episode_content", $query, {limit: $limit})
+    query = (
+        get_nodes_query(driver.provider, 'episode_content', '$query')
+        + """
         YIELD node AS episode, score
         MATCH (e:Episodic)
         WHERE e.uuid = episode.uuid
@@ -478,7 +509,11 @@ async def episode_fulltext_search(
             e.entity_edges AS entity_edges
         ORDER BY score DESC
         LIMIT $limit
-        """,
+        """
+    )
+
+    records, _, _ = await driver.execute_query(
+        query,
         query=fuzzy_query,
         group_ids=group_ids,
         limit=limit,
@@ -491,7 +526,7 @@


 async def community_fulltext_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     query: str,
     group_ids: list[str] | None = None,
     limit=RELEVANT_SCHEMA_LIMIT,
@@ -501,9 +536,9 @@ async def community_fulltext_search(
     if fuzzy_query == '':
         return []

-    records, _, _ = await driver.execute_query(
-        """
-        CALL db.index.fulltext.queryNodes("community_name", $query, {limit: $limit})
+    query = (
+        get_nodes_query(driver.provider, 'community_name', '$query')
+        + """
         YIELD node AS comm, score
         RETURN
             comm.uuid AS uuid,
@@ -513,7 +548,11 @@ async def community_fulltext_search(
             comm.summary AS summary
         ORDER BY score DESC
         LIMIT $limit
-        """,
+        """
+    )
+
+    records, _, _ = await driver.execute_query(
+        query,
         query=fuzzy_query,
         group_ids=group_ids,
         limit=limit,
@@ -526,7 +565,7 @@


 async def community_similarity_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     search_vector: list[float],
     group_ids: list[str] | None = None,
     limit=RELEVANT_SCHEMA_LIMIT,
@@ -540,14 +579,16 @@ async def community_similarity_search(
         group_filter_query += 'WHERE comm.group_id IN $group_ids'
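A note on the helpers referenced throughout these search hunks: the rewrites route backend-specific Cypher through `get_nodes_query`, `get_relationships_query`, and `get_vector_cosine_func_query` instead of hard-coding Neo4j syntax. `graphiti_core/graph_queries.py` itself is not shown in this diff, so the following is only a sketch of the dispatch pattern the call sites imply; the FalkorDB-side syntax is an assumption:

```python
# Sketch only: graph_queries.py is not part of this diff, so these bodies are
# assumptions inferred from the call sites; only the signatures are attested.


def get_nodes_query(provider: str, index_name: str, query: str) -> str:
    # Fulltext node lookup differs per backend: FalkorDB exposes db.idx.fulltext,
    # while Neo4j exposes db.index.fulltext with an options map.
    if provider == 'falkordb':
        return f"CALL db.idx.fulltext.queryNodes('{index_name}', {query})"
    return f'CALL db.index.fulltext.queryNodes("{index_name}", {query}, {{limit: $limit}})'


def get_vector_cosine_func_query(vec1: str, vec2: str, provider: str) -> str:
    # FalkorDB offers a cosine *distance*, so it is rescaled here to a [0, 1]
    # similarity to keep the `score > $min_score` filters working unchanged.
    if provider == 'falkordb':
        return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2}))) / 2'
    return f'vector.similarity.cosine({vec1}, {vec2})'
```

`get_relationships_query` would follow the same shape for relationship fulltext indices. Because each helper returns a query *fragment*, the surrounding functions can keep a single string-concatenation code path per search type.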
         query_params['group_ids'] = group_ids

-    records, _, _ = await driver.execute_query(
+    query = (
         RUNTIME_QUERY
         + """
         MATCH (comm:Community)
         """
         + group_filter_query
         + """
-        WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
+        WITH comm, """
+        + get_vector_cosine_func_query('comm.name_embedding', '$search_vector', driver.provider)
+        + """ AS score
         WHERE score > $min_score
         RETURN
             comm.uuid AS uuid,
@@ -557,7 +598,11 @@ async def community_similarity_search(
             comm.summary AS summary
         ORDER BY score DESC
         LIMIT $limit
-        """,
+        """
+    )
+
+    records, _, _ = await driver.execute_query(
+        query,
         search_vector=search_vector,
         group_ids=group_ids,
         limit=limit,
@@ -573,7 +618,7 @@ async def community_similarity_search(


 async def hybrid_node_search(
     queries: list[str],
     embeddings: list[list[float]],
-    driver: AsyncDriver,
+    driver: GraphDriver,
     search_filter: SearchFilters,
     group_ids: list[str] | None = None,
     limit: int = RELEVANT_SCHEMA_LIMIT,
@@ -590,7 +635,7 @@ async def hybrid_node_search(
         A list of text queries to search for.
     embeddings : list[list[float]]
         A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed.
-    driver : AsyncDriver
-        The Neo4j driver instance for database operations.
+    driver : GraphDriver
+        The graph driver instance (Neo4j or FalkorDB) for database operations.
     group_ids : list[str] | None, optional
         The list of group ids to retrieve nodes from.
@@ -645,7 +690,7 @@ async def hybrid_node_search(


 async def get_relevant_nodes(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     nodes: list[EntityNode],
     search_filter: SearchFilters,
     min_score: float = DEFAULT_MIN_SCORE,
@@ -664,29 +709,33 @@ async def get_relevant_nodes(
     query = (
         RUNTIME_QUERY
-        + """UNWIND $nodes AS node
-        MATCH (n:Entity {group_id: $group_id})
-        """
+        + """
+        UNWIND $nodes AS node
+        MATCH (n:Entity {group_id: $group_id})
+        """
         + filter_query
         + """
-        WITH node, n, vector.similarity.cosine(n.name_embedding, node.name_embedding) AS score
+        WITH node, n, """
+        + get_vector_cosine_func_query('n.name_embedding', 'node.name_embedding', driver.provider)
+        + """ AS score
         WHERE score > $min_score
         WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
-
-        CALL db.index.fulltext.queryNodes("node_name_and_summary", node.fulltext_query, {limit: $limit})
+        """
+        + get_nodes_query(driver.provider, 'node_name_and_summary', 'node.fulltext_query')
+        + """
         YIELD node AS m
         WHERE m.group_id = $group_id
         WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
-
+
         WITH node, top_vector_nodes,
             [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
-
+
         WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
-
+
         UNWIND combined_nodes AS combined_node
         WITH node, collect(DISTINCT combined_node) AS deduped_nodes
-
+
         RETURN
             node.uuid AS search_node_uuid,
             [x IN deduped_nodes | {
@@ -714,7 +763,7 @@ async def get_relevant_nodes(
     results, _, _ = await driver.execute_query(
         query,
-        query_params,
+        params=query_params,
         nodes=query_nodes,
         group_id=group_id,
         limit=limit,
@@ -736,7 +785,7 @@ async def get_relevant_nodes(


 async def get_relevant_edges(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     edges: list[EntityEdge],
     search_filter: SearchFilters,
     min_score: float = DEFAULT_MIN_SCORE,
@@ -752,43 +801,47 @@ async def get_relevant_edges(
     query = (
         RUNTIME_QUERY
-        + """UNWIND $edges AS edge
-    MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
-    """
+        + """
+        UNWIND $edges AS edge
+        MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
+        """
         + filter_query
         + """
-        WITH e, edge, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
-        WHERE score > $min_score
-        WITH edge, e, score
-        ORDER BY score DESC
-        RETURN edge.uuid AS search_edge_uuid,
-            collect({
-                uuid: e.uuid,
-                source_node_uuid: startNode(e).uuid,
-                target_node_uuid: endNode(e).uuid,
-                created_at: e.created_at,
-                name: e.name,
-                group_id: e.group_id,
-                fact: e.fact,
-                fact_embedding: e.fact_embedding,
-                episodes: e.episodes,
-                expired_at: e.expired_at,
-                valid_at: e.valid_at,
-                invalid_at: e.invalid_at,
-                attributes: properties(e)
-            })[..$limit] AS matches
+        WITH e, edge, """
+        + get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
+        + """ AS score
+        WHERE score > $min_score
+        WITH edge, e, score
+        ORDER BY score DESC
+        RETURN edge.uuid AS search_edge_uuid,
+            collect({
+                uuid: e.uuid,
+                source_node_uuid: startNode(e).uuid,
+                target_node_uuid: endNode(e).uuid,
+                created_at: e.created_at,
+                name: e.name,
+                group_id: e.group_id,
+                fact: e.fact,
+                fact_embedding: e.fact_embedding,
+                episodes: e.episodes,
+                expired_at: e.expired_at,
+                valid_at: e.valid_at,
+                invalid_at: e.invalid_at,
+                attributes: properties(e)
+            })[..$limit] AS matches
         """
     )

     results, _, _ = await driver.execute_query(
         query,
-        query_params,
+        params=query_params,
         edges=[edge.model_dump() for edge in edges],
         limit=limit,
         min_score=min_score,
         database_=DEFAULT_DATABASE,
         routing_='r',
     )
+
     relevant_edges_dict: dict[str, list[EntityEdge]] = {
         result['search_edge_uuid']: [
             get_entity_edge_from_record(record) for record in result['matches']
@@ -802,7 +855,7 @@


 async def get_edge_invalidation_candidates(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     edges: list[EntityEdge],
     search_filter: SearchFilters,
     min_score: float = DEFAULT_MIN_SCORE,
@@ -818,38 +871,41 @@

     query = (
         RUNTIME_QUERY
-        + """UNWIND $edges AS edge
-    MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
-    WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
-    """
+        + """
+        UNWIND $edges AS edge
+        MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
+        WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
+        """
         + filter_query
         + """
-        WITH edge, e, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
-        WHERE score > $min_score
-        WITH edge, e, score
-        ORDER BY score DESC
-        RETURN edge.uuid AS search_edge_uuid,
-            collect({
-                uuid: e.uuid,
-                source_node_uuid: startNode(e).uuid,
-                target_node_uuid: endNode(e).uuid,
-                created_at: e.created_at,
-                name: e.name,
-                group_id: e.group_id,
-                fact: e.fact,
-                fact_embedding: e.fact_embedding,
-                episodes: e.episodes,
-                expired_at: e.expired_at,
-                valid_at: e.valid_at,
-                invalid_at: e.invalid_at,
-                attributes: properties(e)
-            })[..$limit] AS matches
+        WITH edge, e, """
+        + get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
+        + """ AS score
+        WHERE score > $min_score
+        WITH edge, e, score
+        ORDER BY score DESC
+        RETURN edge.uuid AS search_edge_uuid,
+            collect({
+                uuid: e.uuid,
+                source_node_uuid: startNode(e).uuid,
+                target_node_uuid: endNode(e).uuid,
+                created_at: e.created_at,
+                name: e.name,
+                group_id: e.group_id,
+                fact: e.fact,
+                fact_embedding: e.fact_embedding,
+                episodes: e.episodes,
+                expired_at: e.expired_at,
+                valid_at: e.valid_at,
+                invalid_at: e.invalid_at,
+                attributes: properties(e)
+            })[..$limit] AS matches
         """
     )

     results, _, _ = await driver.execute_query(
         query,
-        query_params,
+        params=query_params,
         edges=[edge.model_dump() for edge in edges],
         limit=limit,
         min_score=min_score,
@@ -884,7 +940,7 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[st


 async def node_distance_reranker(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     node_uuids: list[str],
     center_node_uuid: str,
     min_score: float = 0,
@@ -894,21 +950,22 @@ async def node_distance_reranker(
     scores: dict[str, float] = {center_node_uuid: 0.0}

-    # Find the shortest path to center node
-    query = Query("""
+    # Score entities that are directly connected to the center node
+    query = """
         UNWIND $node_uuids AS node_uuid
-        MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: node_uuid})
-        RETURN length(p) AS score, node_uuid AS uuid
-        """)
-
-    path_results, _, _ = await driver.execute_query(
+        MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
+        RETURN 1 AS score, node_uuid AS uuid
+        """
+    results, header, _ = await driver.execute_query(
         query,
         node_uuids=filtered_uuids,
         center_uuid=center_node_uuid,
         database_=DEFAULT_DATABASE,
         routing_='r',
     )
+    if driver.provider == 'falkordb':
+        results = [dict(zip(header, row, strict=True)) for row in results]

-    for result in path_results:
+    for result in results:
         uuid = result['uuid']
         score = result['score']
         scores[uuid] = score
@@ -929,19 +986,18 @@ async def episode_mentions_reranker(


 async def episode_mentions_reranker(
-    driver: AsyncDriver, node_uuids: list[list[str]], min_score: float = 0
+    driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
 ) -> list[str]:
     # use rrf as a preliminary ranker
     sorted_uuids = rrf(node_uuids)
     scores: dict[str, float] = {}

-    # Find the shortest path to center node
-    query = Query("""
+    # Count the episodes that mention each node
+    query = """
         UNWIND $node_uuids AS node_uuid
         MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
         RETURN count(*) AS score, n.uuid AS uuid
-        """)
-
+        """
     results, _, _ = await driver.execute_query(
         query,
         node_uuids=sorted_uuids,
@@ -998,7 +1054,7 @@ def maximal_marginal_relevance(


 async def get_embeddings_for_nodes(
-    driver: AsyncDriver, nodes: list[EntityNode]
+    driver: GraphDriver, nodes: list[EntityNode]
 ) -> dict[str, list[float]]:
     query: LiteralString = """MATCH (n:Entity)
         WHERE n.uuid IN $node_uuids
@@ -1022,7 +1078,7 @@ async def get_embeddings_for_nodes(


 async def get_embeddings_for_communities(
-    driver: AsyncDriver, communities: list[CommunityNode]
+    driver: GraphDriver, communities: list[CommunityNode]
 ) -> dict[str, list[float]]:
     query: LiteralString = """MATCH (c:Community)
         WHERE c.uuid IN $community_uuids
@@ -1049,7 +1105,7 @@ async def get_embeddings_for_communities(


 async def get_embeddings_for_edges(
-    driver: AsyncDriver, edges: list[EntityEdge]
+    driver: GraphDriver, edges: list[EntityEdge]
 ) -> dict[str, list[float]]:
     query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
         WHERE e.uuid IN $edge_uuids
diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py
index a4fe0651..b5b1c598 100644
--- a/graphiti_core/utils/bulk_utils.py
+++ b/graphiti_core/utils/bulk_utils.py
@@ -20,22 +20,24 @@
 from collections import defaultdict
 from datetime import datetime
 from math import ceil

-from neo4j import AsyncDriver, AsyncManagedTransaction
 from numpy import dot, sqrt
 from pydantic import BaseModel
 from typing_extensions import Any

+from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
 from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
 from graphiti_core.embedder import EmbedderClient
+from graphiti_core.graph_queries import (
+    get_entity_edge_save_bulk_query,
+    get_entity_node_save_bulk_query,
+)
 from graphiti_core.graphiti_types import GraphitiClients
 from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.models.edges.edge_db_queries import (
-    ENTITY_EDGE_SAVE_BULK,
     EPISODIC_EDGE_SAVE_BULK,
 )
 from graphiti_core.models.nodes.node_db_queries import (
-    ENTITY_NODE_SAVE_BULK,
     EPISODIC_NODE_SAVE_BULK,
 )
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
@@ -73,7 +75,7 @@ class RawEpisode(BaseModel):


 async def retrieve_previous_episodes_bulk(
-    driver: AsyncDriver, episodes: list[EpisodicNode]
+    driver: GraphDriver, episodes: list[EpisodicNode]
 ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
     previous_episodes_list = await semaphore_gather(
         *[
@@ -91,14 +93,15 @@ async def retrieve_previous_episodes_bulk(


 async def add_nodes_and_edges_bulk(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     episodic_nodes: list[EpisodicNode],
     episodic_edges: list[EpisodicEdge],
     entity_nodes: list[EntityNode],
     entity_edges: list[EntityEdge],
     embedder: EmbedderClient,
 ):
-    async with driver.session(database=DEFAULT_DATABASE) as session:
+    session = driver.session(database=DEFAULT_DATABASE)
+    try:
         await session.execute_write(
             add_nodes_and_edges_bulk_tx,
             episodic_nodes,
@@ -106,16 +109,20 @@ async def add_nodes_and_edges_bulk(
             entity_nodes,
             entity_edges,
             embedder,
+            driver=driver,
         )
+    finally:
+        await session.close()


 async def add_nodes_and_edges_bulk_tx(
-    tx: AsyncManagedTransaction,
+    tx: GraphDriverSession,
     episodic_nodes: list[EpisodicNode],
     episodic_edges: list[EpisodicEdge],
     entity_nodes: list[EntityNode],
     entity_edges: list[EntityEdge],
     embedder: EmbedderClient,
+    driver: GraphDriver,
 ):
     episodes = [dict(episode) for episode in episodic_nodes]
     for episode in episodes:
@@ -160,11 +167,13 @@ async def add_nodes_and_edges_bulk_tx(
         edges.append(edge_data)

     await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
-    await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
+    entity_node_save_bulk = get_entity_node_save_bulk_query(nodes, driver.provider)
+    await tx.run(entity_node_save_bulk, nodes=nodes)
     await tx.run(
         EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
     )
-    await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=edges)
+    entity_edge_save_bulk = get_entity_edge_save_bulk_query(driver.provider)
+    await tx.run(entity_edge_save_bulk, entity_edges=edges)


 async def extract_nodes_and_edges_bulk(
@@ -211,7 +220,7 @@ async def extract_nodes_and_edges_bulk(


 async def dedupe_nodes_bulk(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     llm_client: LLMClient,
     extracted_nodes: list[EntityNode],
 ) -> tuple[list[EntityNode], dict[str, str]]:
@@ -247,7 +256,7 @@ async def dedupe_nodes_bulk(


 async def dedupe_edges_bulk(
-    driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
+    driver: GraphDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
 ) -> list[EntityEdge]:
     # First compress edges
     compressed_edges = await compress_edges(llm_client, extracted_edges)
diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py
index f35e4c39..cca3fe00 100644
--- a/graphiti_core/utils/maintenance/community_operations.py
+++ b/graphiti_core/utils/maintenance/community_operations.py
@@ -2,9 +2,9 @@
 import asyncio
 import logging
 from collections import defaultdict

-from neo4j import AsyncDriver
 from pydantic import BaseModel

+from graphiti_core.driver.driver import GraphDriver
 from graphiti_core.edges import CommunityEdge
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
@@ -26,7 +26,7 @@ class Neighbor(BaseModel):


 async def get_community_clusters(
-    driver: AsyncDriver, group_ids: list[str] | None
+    driver: GraphDriver, group_ids: list[str] | None
 ) -> list[list[EntityNode]]:
     community_clusters: list[list[EntityNode]] = []
@@ -95,7 +95,6 @@ def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
             community_candidates: dict[int, int] = defaultdict(int)
             for neighbor in neighbors:
                 community_candidates[community_map[neighbor.node_uuid]] += neighbor.edge_count
-
             community_lst = [
                 (count, community) for community, count in community_candidates.items()
             ]
@@ -194,7 +193,7 @@ async def build_community(


 async def build_communities(
-    driver: AsyncDriver, llm_client: LLMClient, group_ids: list[str] | None
+    driver: GraphDriver, llm_client: LLMClient, group_ids: list[str] | None
 ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
     community_clusters = await get_community_clusters(driver, group_ids)
@@ -219,7 +218,7 @@ async def build_communities(
     return community_nodes, community_edges


-async def remove_communities(driver: AsyncDriver):
+async def remove_communities(driver: GraphDriver):
     await driver.execute_query(
         """
         MATCH (c:Community)
@@ -230,10 +229,10 @@ async def remove_communities(driver: AsyncDriver):


 async def determine_entity_community(
-    driver: AsyncDriver, entity: EntityNode
+    driver: GraphDriver, entity: EntityNode
 ) -> tuple[CommunityNode | None, bool]:
     # Check if the node is already part of a community
     records, _, _ = await driver.execute_query(
         """
         MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
         RETURN
@@ -251,7 +250,7 @@ async def determine_entity_community(
         return get_community_node_from_record(records[0]), False

     # If the node has no community, add it to the mode community of surrounding entities
     records, _, _ = await driver.execute_query(
         """
         MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
         RETURN
@@ -291,7 +290,7 @@ async def determine_entity_community(


 async def update_community(
-    driver: AsyncDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode
+    driver: GraphDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode
 ):
     community, is_new = await determine_entity_community(driver, entity)
diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py
index 9c90a8e9..171c757b 100644
--- a/graphiti_core/utils/maintenance/edge_operations.py
+++ b/graphiti_core/utils/maintenance/edge_operations.py
@@ -260,7 +260,6 @@ async def resolve_extracted_edges(
     driver = clients.driver
     llm_client = clients.llm_client
     embedder = clients.embedder
-
     await create_entity_edge_embeddings(embedder, extracted_edges)

     search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py
index 60f402a7..de269197 100644
--- a/graphiti_core/utils/maintenance/graph_data_operations.py
+++ b/graphiti_core/utils/maintenance/graph_data_operations.py
@@ -17,9 +17,10 @@ limitations under the License.

 import logging
 from datetime import datetime, timezone

-from neo4j import AsyncDriver
 from typing_extensions import LiteralString

+from graphiti_core.driver.driver import GraphDriver
+from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
 from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
 from graphiti_core.nodes import EpisodeType, EpisodicNode
@@ -28,7 +29,7 @@
 EPISODE_WINDOW_LEN = 3

 logger = logging.getLogger(__name__)


-async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bool = False):
+async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
     if delete_existing:
         records, _, _ = await driver.execute_query(
             """
@@ -47,39 +48,9 @@
                 for name in index_names
             ]
         )
+    range_indices: list[LiteralString] = get_range_indices(driver.provider)

-    range_indices: list[LiteralString] = [
-        'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
-        'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
-        'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
-        'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
-        'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
-        'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
-        'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
-        'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
-        'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
-        'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
-        'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
-        'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
-        'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
-        'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
-        'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
-        'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
-        'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
-        'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
-        'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
-    ]
-
-    fulltext_indices: list[LiteralString] = [
-        """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
-        FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
-        """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
-        FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
-        """CREATE FULLTEXT INDEX community_name IF NOT EXISTS
-        FOR (n:Community) ON EACH [n.name, n.group_id]""",
-        """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
-        FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
-    ]
+    fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)

     index_queries: list[LiteralString] = range_indices + fulltext_indices
@@ -94,7 +65,7 @@
     )


-async def clear_data(driver: AsyncDriver, group_ids: list[str] | None = None):
+async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
     async with driver.session(database=DEFAULT_DATABASE) as session:

         async def delete_all(tx):
@@ -113,7 +84,7 @@ async def clear_data(driver: AsyncDriver, group_ids: list[str] | None = None):


 async def retrieve_episodes(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     reference_time: datetime,
     last_n: int = EPISODE_WINDOW_LEN,
     group_ids: list[str] | None = None,
@@ -123,7 +94,7 @@ async def retrieve_episodes(
     Retrieve the last n episodic nodes from the graph.

     Args:
-        driver (AsyncDriver): The Neo4j driver instance.
+        driver (GraphDriver): The graph driver instance (Neo4j or FalkorDB).
         reference_time (datetime): The reference time to filter episodes. Only episodes with a valid_at timestamp
                                    less than or equal to this reference_time will be retrieved. This allows for
                                    querying the graph's state at a specific point in time.
@@ -140,8 +111,8 @@ async def retrieve_episodes(

     query: LiteralString = (
         """
-        MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
-        """
+        MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
+        """
         + group_id_filter
         + source_filter
         + """
@@ -157,8 +128,7 @@ async def retrieve_episodes(
         LIMIT $num_episodes
         """
     )
-
-    result = await driver.execute_query(
+    result, _, _ = await driver.execute_query(
         query,
         reference_time=reference_time,
         source=source.name if source is not None else None,
@@ -166,6 +136,7 @@ async def retrieve_episodes(
         group_ids=group_ids,
         database_=DEFAULT_DATABASE,
     )
+
     episodes = [
         EpisodicNode(
             content=record['content'],
@@ -179,6 +150,6 @@ async def retrieve_episodes(
             name=record['name'],
             source_description=record['source_description'],
         )
-        for record in result.records
+        for record in result
     ]
     return list(reversed(episodes))  # Return in chronological order
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index ac572765..55e5d515 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -326,7 +326,6 @@ async def extract_attributes_from_nodes(
 ) -> list[EntityNode]:
     llm_client = clients.llm_client
     embedder = clients.embedder
-
     updated_nodes: list[EntityNode] = await semaphore_gather(
         *[
             extract_attributes_from_node(
diff --git a/poetry.lock b/poetry.lock
index f231bc75..61fe2a4c 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
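Before the lockfile and project metadata below, one pattern from `bulk_utils.py` is worth calling out: the session is now opened and closed by hand rather than with `async with`, presumably because a backend-agnostic `GraphDriverSession` may not implement the async context-manager protocol. A minimal sketch of that pattern, with a hypothetical `work` callback standing in for `add_nodes_and_edges_bulk_tx`:

```python
from collections.abc import Callable

from graphiti_core.driver.driver import GraphDriver
from graphiti_core.helpers import DEFAULT_DATABASE


async def write_in_session(driver: GraphDriver, work: Callable) -> None:
    # Hypothetical wrapper illustrating the diff's pattern: open the session
    # explicitly and close it in `finally` instead of using `async with`, so
    # the same code path works whether or not the backend session object
    # supports the async context-manager protocol.
    session = driver.session(database=DEFAULT_DATABASE)
    try:
        await session.execute_write(work)
    finally:
        await session.close()
```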
 [[package]]
 name = "aiohappyeyeballs"
@@ -332,12 +332,12 @@
 version = "5.0.1"
 description = "Timeout context manager for asyncio programs"
 optional = false
 python-versions = ">=3.8"
-groups = ["dev"]
-markers = "python_version < \"3.11\""
+groups = ["main", "dev"]
 files = [
     {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"},
     {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"},
 ]
+markers = {main = "python_full_version < \"3.11.3\"", dev = "python_version == \"3.10\""}

 [[package]]
 name = "attrs"
@@ -759,7 +759,7 @@
 description = "Backport of PEP 654 (exception groups)"
 optional = false
 python-versions = ">=3.7"
 groups = ["main", "dev"]
-markers = "python_version < \"3.11\""
+markers = "python_version == \"3.10\""
 files = [
     {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"},
     {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
@@ -798,6 +798,20 @@ files = [

 [package.extras]
 tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich ; python_version >= \"3.11\""]

+[[package]]
+name = "falkordb"
+version = "1.1.2"
+description = "Python client for interacting with FalkorDB database"
+optional = false
+python-versions = "<4.0,>=3.8"
+groups = ["main"]
+files = [
+    {file = "falkordb-1.1.2.tar.gz", hash = "sha256:db76c97efe14a56c3d65c61b966a42b874e1c78a8fb6808de3f61f4314b04023"},
+]
+
+[package.dependencies]
+redis = ">=5.0.1,<6.0.0"
+
 [[package]]
 name = "fastjsonschema"
 version = "2.21.1"
@@ -2665,7 +2679,6 @@
 description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy"
 optional = false
 python-versions = ">=3.9"
 groups = ["dev"]
-markers = "platform_python_implementation != \"PyPy\""
 files = [
     {file = "orjson-3.10.16-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4cb473b8e79154fa778fb56d2d73763d977be3dcc140587e07dbc545bbfc38f8"},
     {file = "orjson-3.10.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:622a8e85eeec1948690409a19ca1c7d9fd8ff116f4861d261e6ae2094fe59a00"},
@@ -3691,6 +3704,25 @@ files = [

 [package.dependencies]
 cffi = {version = "*", markers = "implementation_name == \"pypy\""}

+[[package]]
+name = "redis"
+version = "5.2.1"
+description = "Python client for Redis database and key-value store"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+    {file = "redis-5.2.1-py3-none-any.whl", hash = "sha256:ee7e1056b9aea0f04c6c2ed59452947f34c4940ee025f5dd83e6a6418b6989e4"},
+    {file = "redis-5.2.1.tar.gz", hash = "sha256:16f2e22dff21d5125e8481515e386711a34cbec50f0e44413dd7d9c060a54e0f"},
+]
+
+[package.dependencies]
+async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""}
+
+[package.extras]
+hiredis = ["hiredis (>=3.0.0)"]
+ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"]
+
 [[package]]
 name = "referencing"
 version = "0.36.2"
@@ -4498,7 +4530,7 @@
 description = "A lil' TOML parser"
 optional = false
 python-versions = ">=3.8"
 groups = ["dev"]
-markers = "python_version < \"3.11\""
+markers = "python_version == \"3.10\""
 files = [
     {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
     {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
@@ -5356,4 +5388,4 @@ groq = ["groq"]

 [metadata]
 lock-version = "2.1"
 python-versions = ">=3.10,<4"
-content-hash = "814d067fd2959bfe2db58a22637d86580b66d96f34c433852c67d02089d750ab"
+content-hash = "2e02a10a6493f7564b86d5d0d09b4cf718004808e115af39550b9ee87c296fb4"
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 71cb278f..58a54519 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,6 +19,7 @@ dependencies = [
     "tenacity>=9.0.0",
     "numpy>=1.0.0",
     "python-dotenv>=1.0.1",
+    "falkordb (>=1.1.2,<2.0.0)",
 ]

 [project.urls]