Add support for falkordb (#575)
* [wip] add support for falkordb * updates * fix-async * progress * fix-issues * rm-date-handler * red-code * rm-uns-try * fix-exm * rm-un-lines * fix-comments * fix-se-utils * fix-falkor-readme * fix-falkor-cosine-score * update-falkor-ver * fix-vec-sim * min-updates * make format * update graph driver abstraction * poetry lock * updates * linter * Update graphiti_core/search/search_utils.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --------- Co-authored-by: Dudi Zimberknopf <zimber.dudi@gmail.com> Co-authored-by: Gal Shubeli <galshubeli93@gmail.com> Co-authored-by: Gal Shubeli <124919062+galshubeli@users.noreply.github.com> Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent
3d7e1a4b79
commit
14146dc46f
27 changed files with 1131 additions and 348 deletions
|
|
@ -1,8 +1,17 @@
|
||||||
OPENAI_API_KEY=
|
OPENAI_API_KEY=
|
||||||
|
|
||||||
|
# Neo4j database connection
|
||||||
NEO4J_URI=
|
NEO4J_URI=
|
||||||
NEO4J_PORT=
|
NEO4J_PORT=
|
||||||
NEO4J_USER=
|
NEO4J_USER=
|
||||||
NEO4J_PASSWORD=
|
NEO4J_PASSWORD=
|
||||||
|
|
||||||
|
# FalkorDB database connection
|
||||||
|
FALKORDB_URI=
|
||||||
|
FALKORDB_PORT=
|
||||||
|
FALKORDB_USER=
|
||||||
|
FALKORDB_PASSWORD=
|
||||||
|
|
||||||
DEFAULT_DATABASE=
|
DEFAULT_DATABASE=
|
||||||
USE_PARALLEL_RUNTIME=
|
USE_PARALLEL_RUNTIME=
|
||||||
SEMAPHORE_LIMIT=
|
SEMAPHORE_LIMIT=
|
||||||
|
|
|
||||||
|
|
@ -64,10 +64,11 @@ Once you've found an issue tagged with "good first issue" or "help wanted," or p
|
||||||
export TEST_OPENAI_API_KEY=...
|
export TEST_OPENAI_API_KEY=...
|
||||||
export TEST_OPENAI_MODEL=...
|
export TEST_OPENAI_MODEL=...
|
||||||
export TEST_ANTHROPIC_API_KEY=...
|
export TEST_ANTHROPIC_API_KEY=...
|
||||||
|
|
||||||
export NEO4J_URI=neo4j://...
|
# For Neo4j
|
||||||
export NEO4J_USER=...
|
export TEST_URI=neo4j://...
|
||||||
export NEO4J_PASSWORD=...
|
export TEST_USER=...
|
||||||
|
export TEST_PASSWORD=...
|
||||||
```
|
```
|
||||||
|
|
||||||
## Making Changes
|
## Making Changes
|
||||||
|
|
|
||||||
|
|
@ -105,7 +105,7 @@ Graphiti is specifically designed to address the challenges of dynamic and frequ
|
||||||
Requirements:
|
Requirements:
|
||||||
|
|
||||||
- Python 3.10 or higher
|
- Python 3.10 or higher
|
||||||
- Neo4j 5.26 or higher (serves as the embeddings storage backend)
|
- Neo4j 5.26 / FalkorDB 1.1.2 or higher (serves as the embeddings storage backend)
|
||||||
- OpenAI API key (for LLM inference and embedding)
|
- OpenAI API key (for LLM inference and embedding)
|
||||||
|
|
||||||
> [!IMPORTANT]
|
> [!IMPORTANT]
|
||||||
|
|
|
||||||
|
|
@ -76,9 +76,7 @@ async def main():
|
||||||
group_id = str(uuid4())
|
group_id = str(uuid4())
|
||||||
|
|
||||||
for i, message in enumerate(messages[3:14]):
|
for i, message in enumerate(messages[3:14]):
|
||||||
episodes = await client.retrieve_episodes(
|
episodes = await client.retrieve_episodes(message.actual_timestamp, 3, group_ids=[group_id])
|
||||||
message.actual_timestamp, 3, group_ids=['podcast']
|
|
||||||
)
|
|
||||||
episode_uuids = [episode.uuid for episode in episodes]
|
episode_uuids = [episode.uuid for episode in episodes]
|
||||||
|
|
||||||
await client.add_episode(
|
await client.add_episode(
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
This example demonstrates the basic functionality of Graphiti, including:
|
This example demonstrates the basic functionality of Graphiti, including:
|
||||||
|
|
||||||
1. Connecting to a Neo4j database
|
1. Connecting to a Neo4j or FalkorDB database
|
||||||
2. Initializing Graphiti indices and constraints
|
2. Initializing Graphiti indices and constraints
|
||||||
3. Adding episodes to the graph
|
3. Adding episodes to the graph
|
||||||
4. Searching the graph with semantic and keyword matching
|
4. Searching the graph with semantic and keyword matching
|
||||||
|
|
@ -11,10 +11,14 @@ This example demonstrates the basic functionality of Graphiti, including:
|
||||||
|
|
||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
- Neo4j Desktop installed and running
|
- Python 3.9+
|
||||||
- A local DBMS created and started in Neo4j Desktop
|
- OpenAI API key (set as `OPENAI_API_KEY` environment variable)
|
||||||
- Python 3.9+
|
- **For Neo4j**:
|
||||||
- OpenAI API key (set as `OPENAI_API_KEY` environment variable)
|
- Neo4j Desktop installed and running
|
||||||
|
- A local DBMS created and started in Neo4j Desktop
|
||||||
|
- **For FalkorDB**:
|
||||||
|
- FalkorDB server running (see [FalkorDB documentation](https://falkordb.com/docs/) for setup)
|
||||||
|
|
||||||
|
|
||||||
## Setup Instructions
|
## Setup Instructions
|
||||||
|
|
||||||
|
|
@ -34,17 +38,23 @@ export OPENAI_API_KEY=your_openai_api_key
|
||||||
export NEO4J_URI=bolt://localhost:7687
|
export NEO4J_URI=bolt://localhost:7687
|
||||||
export NEO4J_USER=neo4j
|
export NEO4J_USER=neo4j
|
||||||
export NEO4J_PASSWORD=password
|
export NEO4J_PASSWORD=password
|
||||||
|
|
||||||
|
# Optional FalkorDB connection parameters (defaults shown)
|
||||||
|
export FALKORDB_URI=falkor://localhost:6379
|
||||||
```
|
```
|
||||||
|
|
||||||
3. Run the example:
|
3. Run the example:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python quickstart.py
|
python quickstart_neo4j.py
|
||||||
|
|
||||||
|
# For FalkorDB
|
||||||
|
python quickstart_falkordb.py
|
||||||
```
|
```
|
||||||
|
|
||||||
## What This Example Demonstrates
|
## What This Example Demonstrates
|
||||||
|
|
||||||
- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j
|
- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j or FalkorDB
|
||||||
- **Adding Episodes**: Adding text content that will be analyzed and converted into knowledge graph nodes and edges
|
- **Adding Episodes**: Adding text content that will be analyzed and converted into knowledge graph nodes and edges
|
||||||
- **Edge Search Functionality**: Performing hybrid searches that combine semantic similarity and BM25 retrieval to find relationships (edges)
|
- **Edge Search Functionality**: Performing hybrid searches that combine semantic similarity and BM25 retrieval to find relationships (edges)
|
||||||
- **Graph-Aware Search**: Using the source node UUID from the top search result to rerank additional search results based on graph distance
|
- **Graph-Aware Search**: Using the source node UUID from the top search result to rerank additional search results based on graph distance
|
||||||
|
|
|
||||||
240
examples/quickstart/quickstart_falkordb.py
Normal file
240
examples/quickstart/quickstart_falkordb.py
Normal file
|
|
@ -0,0 +1,240 @@
|
||||||
|
"""
|
||||||
|
Copyright 2025, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from logging import INFO
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from graphiti_core import Graphiti
|
||||||
|
from graphiti_core.nodes import EpisodeType
|
||||||
|
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
|
||||||
|
|
||||||
|
#################################################
|
||||||
|
# CONFIGURATION
|
||||||
|
#################################################
|
||||||
|
# Set up logging and environment variables for
|
||||||
|
# connecting to FalkorDB database
|
||||||
|
#################################################
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
datefmt='%Y-%m-%d %H:%M:%S',
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# FalkorDB connection parameters
|
||||||
|
# Make sure FalkorDB on premises is running, see https://docs.falkordb.com/
|
||||||
|
falkor_uri = os.environ.get('FALKORDB_URI', 'falkor://localhost:6379')
|
||||||
|
|
||||||
|
if not falkor_uri:
|
||||||
|
raise ValueError('FALKORDB_URI must be set')
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
#################################################
|
||||||
|
# INITIALIZATION
|
||||||
|
#################################################
|
||||||
|
# Connect to FalkorDB and set up Graphiti indices
|
||||||
|
# This is required before using other Graphiti
|
||||||
|
# functionality
|
||||||
|
#################################################
|
||||||
|
|
||||||
|
# Initialize Graphiti with FalkorDB connection
|
||||||
|
graphiti = Graphiti(falkor_uri)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize the graph database with graphiti's indices. This only needs to be done once.
|
||||||
|
await graphiti.build_indices_and_constraints()
|
||||||
|
|
||||||
|
#################################################
|
||||||
|
# ADDING EPISODES
|
||||||
|
#################################################
|
||||||
|
# Episodes are the primary units of information
|
||||||
|
# in Graphiti. They can be text or structured JSON
|
||||||
|
# and are automatically processed to extract entities
|
||||||
|
# and relationships.
|
||||||
|
#################################################
|
||||||
|
|
||||||
|
# Example: Add Episodes
|
||||||
|
# Episodes list containing both text and JSON episodes
|
||||||
|
episodes = [
|
||||||
|
{
|
||||||
|
'content': 'Kamala Harris is the Attorney General of California. She was previously '
|
||||||
|
'the district attorney for San Francisco.',
|
||||||
|
'type': EpisodeType.text,
|
||||||
|
'description': 'podcast transcript',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017',
|
||||||
|
'type': EpisodeType.text,
|
||||||
|
'description': 'podcast transcript',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'content': {
|
||||||
|
'name': 'Gavin Newsom',
|
||||||
|
'position': 'Governor',
|
||||||
|
'state': 'California',
|
||||||
|
'previous_role': 'Lieutenant Governor',
|
||||||
|
'previous_location': 'San Francisco',
|
||||||
|
},
|
||||||
|
'type': EpisodeType.json,
|
||||||
|
'description': 'podcast metadata',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'content': {
|
||||||
|
'name': 'Gavin Newsom',
|
||||||
|
'position': 'Governor',
|
||||||
|
'term_start': 'January 7, 2019',
|
||||||
|
'term_end': 'Present',
|
||||||
|
},
|
||||||
|
'type': EpisodeType.json,
|
||||||
|
'description': 'podcast metadata',
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add episodes to the graph
|
||||||
|
for i, episode in enumerate(episodes):
|
||||||
|
await graphiti.add_episode(
|
||||||
|
name=f'Freakonomics Radio {i}',
|
||||||
|
episode_body=episode['content']
|
||||||
|
if isinstance(episode['content'], str)
|
||||||
|
else json.dumps(episode['content']),
|
||||||
|
source=episode['type'],
|
||||||
|
source_description=episode['description'],
|
||||||
|
reference_time=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
print(f'Added episode: Freakonomics Radio {i} ({episode["type"].value})')
|
||||||
|
|
||||||
|
#################################################
|
||||||
|
# BASIC SEARCH
|
||||||
|
#################################################
|
||||||
|
# The simplest way to retrieve relationships (edges)
|
||||||
|
# from Graphiti is using the search method, which
|
||||||
|
# performs a hybrid search combining semantic
|
||||||
|
# similarity and BM25 text retrieval.
|
||||||
|
#################################################
|
||||||
|
|
||||||
|
# Perform a hybrid search combining semantic similarity and BM25 retrieval
|
||||||
|
print("\nSearching for: 'Who was the California Attorney General?'")
|
||||||
|
results = await graphiti.search('Who was the California Attorney General?')
|
||||||
|
|
||||||
|
# Print search results
|
||||||
|
print('\nSearch Results:')
|
||||||
|
for result in results:
|
||||||
|
print(f'UUID: {result.uuid}')
|
||||||
|
print(f'Fact: {result.fact}')
|
||||||
|
if hasattr(result, 'valid_at') and result.valid_at:
|
||||||
|
print(f'Valid from: {result.valid_at}')
|
||||||
|
if hasattr(result, 'invalid_at') and result.invalid_at:
|
||||||
|
print(f'Valid until: {result.invalid_at}')
|
||||||
|
print('---')
|
||||||
|
|
||||||
|
#################################################
|
||||||
|
# CENTER NODE SEARCH
|
||||||
|
#################################################
|
||||||
|
# For more contextually relevant results, you can
|
||||||
|
# use a center node to rerank search results based
|
||||||
|
# on their graph distance to a specific node
|
||||||
|
#################################################
|
||||||
|
|
||||||
|
# Use the top search result's UUID as the center node for reranking
|
||||||
|
if results and len(results) > 0:
|
||||||
|
# Get the source node UUID from the top result
|
||||||
|
center_node_uuid = results[0].source_node_uuid
|
||||||
|
|
||||||
|
print('\nReranking search results based on graph distance:')
|
||||||
|
print(f'Using center node UUID: {center_node_uuid}')
|
||||||
|
|
||||||
|
reranked_results = await graphiti.search(
|
||||||
|
'Who was the California Attorney General?', center_node_uuid=center_node_uuid
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print reranked search results
|
||||||
|
print('\nReranked Search Results:')
|
||||||
|
for result in reranked_results:
|
||||||
|
print(f'UUID: {result.uuid}')
|
||||||
|
print(f'Fact: {result.fact}')
|
||||||
|
if hasattr(result, 'valid_at') and result.valid_at:
|
||||||
|
print(f'Valid from: {result.valid_at}')
|
||||||
|
if hasattr(result, 'invalid_at') and result.invalid_at:
|
||||||
|
print(f'Valid until: {result.invalid_at}')
|
||||||
|
print('---')
|
||||||
|
else:
|
||||||
|
print('No results found in the initial search to use as center node.')
|
||||||
|
|
||||||
|
#################################################
|
||||||
|
# NODE SEARCH USING SEARCH RECIPES
|
||||||
|
#################################################
|
||||||
|
# Graphiti provides predefined search recipes
|
||||||
|
# optimized for different search scenarios.
|
||||||
|
# Here we use NODE_HYBRID_SEARCH_RRF for retrieving
|
||||||
|
# nodes directly instead of edges.
|
||||||
|
#################################################
|
||||||
|
|
||||||
|
# Example: Perform a node search using _search method with standard recipes
|
||||||
|
print(
|
||||||
|
'\nPerforming node search using _search method with standard recipe NODE_HYBRID_SEARCH_RRF:'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use a predefined search configuration recipe and modify its limit
|
||||||
|
node_search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
|
||||||
|
node_search_config.limit = 5 # Limit to 5 results
|
||||||
|
|
||||||
|
# Execute the node search
|
||||||
|
node_search_results = await graphiti._search(
|
||||||
|
query='California Governor',
|
||||||
|
config=node_search_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print node search results
|
||||||
|
print('\nNode Search Results:')
|
||||||
|
for node in node_search_results.nodes:
|
||||||
|
print(f'Node UUID: {node.uuid}')
|
||||||
|
print(f'Node Name: {node.name}')
|
||||||
|
node_summary = node.summary[:100] + '...' if len(node.summary) > 100 else node.summary
|
||||||
|
print(f'Content Summary: {node_summary}')
|
||||||
|
print(f'Node Labels: {", ".join(node.labels)}')
|
||||||
|
print(f'Created At: {node.created_at}')
|
||||||
|
if hasattr(node, 'attributes') and node.attributes:
|
||||||
|
print('Attributes:')
|
||||||
|
for key, value in node.attributes.items():
|
||||||
|
print(f' {key}: {value}')
|
||||||
|
print('---')
|
||||||
|
|
||||||
|
finally:
|
||||||
|
#################################################
|
||||||
|
# CLEANUP
|
||||||
|
#################################################
|
||||||
|
# Always close the connection to FalkorDB when
|
||||||
|
# finished to properly release resources
|
||||||
|
#################################################
|
||||||
|
|
||||||
|
# Close the connection
|
||||||
|
await graphiti.close()
|
||||||
|
print('\nConnection closed')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
asyncio.run(main())
|
||||||
17
graphiti_core/driver/__init__.py
Normal file
17
graphiti_core/driver/__init__.py
Normal file
|
|
@ -0,0 +1,17 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__all__ = ['GraphDriver', 'Neo4jDriver', 'FalkorDriver']
|
||||||
81
graphiti_core/driver/driver.py
Normal file
81
graphiti_core/driver/driver.py
Normal file
|
|
@ -0,0 +1,81 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Coroutine
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from graphiti_core.helpers import DEFAULT_DATABASE
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GraphDriverSession(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
async def run(self, query: str, **kwargs: Any) -> Any:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class GraphDriver(ABC):
|
||||||
|
provider: str
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def session(self, database: str) -> GraphDriverSession:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def close(self) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
# class GraphDriver:
|
||||||
|
# _driver: GraphClient
|
||||||
|
#
|
||||||
|
# def __init__(
|
||||||
|
# self,
|
||||||
|
# uri: str,
|
||||||
|
# user: str,
|
||||||
|
# password: str,
|
||||||
|
# ):
|
||||||
|
# if uri.startswith('falkor'):
|
||||||
|
# # FalkorDB
|
||||||
|
# self._driver = FalkorClient(uri, user, password)
|
||||||
|
# self.provider = 'falkordb'
|
||||||
|
# else:
|
||||||
|
# # Neo4j
|
||||||
|
# self._driver = Neo4jClient(uri, user, password)
|
||||||
|
# self.provider = 'neo4j'
|
||||||
|
#
|
||||||
|
# def execute_query(self, cypher_query_, **kwargs: Any) -> Coroutine:
|
||||||
|
# return self._driver.execute_query(cypher_query_, **kwargs)
|
||||||
|
#
|
||||||
|
# async def close(self):
|
||||||
|
# return await self._driver.close()
|
||||||
|
#
|
||||||
|
# def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
|
||||||
|
# return self._driver.delete_all_indexes(database_)
|
||||||
|
#
|
||||||
|
# def session(self, database: str) -> GraphClientSession:
|
||||||
|
# return self._driver.session(database)
|
||||||
132
graphiti_core/driver/falkordb_driver.py
Normal file
132
graphiti_core/driver/falkordb_driver.py
Normal file
|
|
@ -0,0 +1,132 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Coroutine
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from falkordb import Graph as FalkorGraph
|
||||||
|
from falkordb.asyncio import FalkorDB
|
||||||
|
|
||||||
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
|
||||||
|
from graphiti_core.helpers import DEFAULT_DATABASE
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FalkorClientSession(GraphDriverSession):
|
||||||
|
def __init__(self, graph: FalkorGraph):
|
||||||
|
self.graph = graph
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
# No cleanup needed for Falkor, but method must exist
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
# No explicit close needed for FalkorDB, but method must exist
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute_write(self, func, *args, **kwargs):
|
||||||
|
# Directly await the provided async function with `self` as the transaction/session
|
||||||
|
return await func(self, *args, **kwargs)
|
||||||
|
|
||||||
|
async def run(self, cypher_query_: str | list, **kwargs: Any) -> Any:
|
||||||
|
# FalkorDB does not support argument for Label Set, so it's converted into an array of queries
|
||||||
|
if isinstance(cypher_query_, list):
|
||||||
|
for cypher, params in cypher_query_:
|
||||||
|
params = convert_datetimes_to_strings(params)
|
||||||
|
await self.graph.query(str(cypher), params)
|
||||||
|
else:
|
||||||
|
params = dict(kwargs)
|
||||||
|
params = convert_datetimes_to_strings(params)
|
||||||
|
await self.graph.query(str(cypher_query_), params)
|
||||||
|
# Assuming `graph.query` is async (ideal); otherwise, wrap in executor
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class FalkorDriver(GraphDriver):
|
||||||
|
provider: str = 'falkordb'
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
uri: str,
|
||||||
|
user: str,
|
||||||
|
password: str,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if user and password:
|
||||||
|
uri_parts = uri.split('://', 1)
|
||||||
|
uri = f'{uri_parts[0]}://{user}:{password}@{uri_parts[1]}'
|
||||||
|
|
||||||
|
self.client = FalkorDB.from_url(
|
||||||
|
url=uri,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_graph(self, graph_name: str) -> FalkorGraph:
|
||||||
|
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "DEFAULT_DATABASE"
|
||||||
|
if graph_name is None:
|
||||||
|
graph_name = 'DEFAULT_DATABASE'
|
||||||
|
return self.client.select_graph(graph_name)
|
||||||
|
|
||||||
|
async def execute_query(self, cypher_query_, **kwargs: Any):
|
||||||
|
graph_name = kwargs.pop('database_', DEFAULT_DATABASE)
|
||||||
|
graph = self._get_graph(graph_name)
|
||||||
|
|
||||||
|
# Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly)
|
||||||
|
params = convert_datetimes_to_strings(dict(kwargs))
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await graph.query(cypher_query_, params)
|
||||||
|
except Exception as e:
|
||||||
|
if 'already indexed' in str(e):
|
||||||
|
# check if index already exists
|
||||||
|
logger.info(f'Index already exists: {e}')
|
||||||
|
return None
|
||||||
|
logger.error(f'Error executing FalkorDB query: {e}')
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Convert the result header to a list of strings
|
||||||
|
header = [h[1].decode('utf-8') for h in result.header]
|
||||||
|
return result.result_set, header, None
|
||||||
|
|
||||||
|
def session(self, database: str) -> GraphDriverSession:
|
||||||
|
return FalkorClientSession(self._get_graph(database))
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
await self.client.connection.close()
|
||||||
|
|
||||||
|
def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
|
||||||
|
return self.execute_query(
|
||||||
|
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||||
|
database_=database_,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_datetimes_to_strings(obj):
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return {k: convert_datetimes_to_strings(v) for k, v in obj.items()}
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
return [convert_datetimes_to_strings(item) for item in obj]
|
||||||
|
elif isinstance(obj, tuple):
|
||||||
|
return tuple(convert_datetimes_to_strings(item) for item in obj)
|
||||||
|
elif isinstance(obj, datetime):
|
||||||
|
return obj.isoformat()
|
||||||
|
else:
|
||||||
|
return obj
|
||||||
60
graphiti_core/driver/neo4j_driver.py
Normal file
60
graphiti_core/driver/neo4j_driver.py
Normal file
|
|
@ -0,0 +1,60 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Coroutine
|
||||||
|
from typing import Any, LiteralString
|
||||||
|
|
||||||
|
from neo4j import AsyncGraphDatabase
|
||||||
|
|
||||||
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
|
||||||
|
from graphiti_core.helpers import DEFAULT_DATABASE
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Neo4jDriver(GraphDriver):
|
||||||
|
provider: str = 'neo4j'
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
uri: str,
|
||||||
|
user: str,
|
||||||
|
password: str,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.client = AsyncGraphDatabase.driver(
|
||||||
|
uri=uri,
|
||||||
|
auth=(user, password),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> Coroutine:
|
||||||
|
params = kwargs.pop('params', None)
|
||||||
|
result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def session(self, database: str) -> GraphDriverSession:
|
||||||
|
return self.client.session(database=database) # type: ignore
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
return await self.client.close()
|
||||||
|
|
||||||
|
def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
|
||||||
|
return self.client.execute_query(
|
||||||
|
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||||
|
database_=database_,
|
||||||
|
)
|
||||||
|
|
@ -21,10 +21,10 @@ from time import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
|
from graphiti_core.driver.driver import GraphDriver
|
||||||
from graphiti_core.embedder import EmbedderClient
|
from graphiti_core.embedder import EmbedderClient
|
||||||
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
||||||
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
|
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
|
||||||
|
|
@ -62,9 +62,9 @@ class Edge(BaseModel, ABC):
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def save(self, driver: AsyncDriver): ...
|
async def save(self, driver: GraphDriver): ...
|
||||||
|
|
||||||
async def delete(self, driver: AsyncDriver):
|
async def delete(self, driver: GraphDriver):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
|
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
|
||||||
|
|
@ -87,11 +87,11 @@ class Edge(BaseModel, ABC):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
||||||
|
|
||||||
|
|
||||||
class EpisodicEdge(Edge):
|
class EpisodicEdge(Edge):
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: GraphDriver):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
EPISODIC_EDGE_SAVE,
|
EPISODIC_EDGE_SAVE,
|
||||||
episode_uuid=self.source_node_uuid,
|
episode_uuid=self.source_node_uuid,
|
||||||
|
|
@ -102,12 +102,12 @@ class EpisodicEdge(Edge):
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
logger.debug(f'Saved edge to Graph: {self.uuid}')
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
||||||
|
|
@ -130,7 +130,7 @@ class EpisodicEdge(Edge):
|
||||||
return edges[0]
|
return edges[0]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
||||||
|
|
@ -156,7 +156,7 @@ class EpisodicEdge(Edge):
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(
|
async def get_by_group_ids(
|
||||||
cls,
|
cls,
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
group_ids: list[str],
|
group_ids: list[str],
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
uuid_cursor: str | None = None,
|
uuid_cursor: str | None = None,
|
||||||
|
|
@ -226,7 +226,7 @@ class EntityEdge(Edge):
|
||||||
|
|
||||||
return self.fact_embedding
|
return self.fact_embedding
|
||||||
|
|
||||||
async def load_fact_embedding(self, driver: AsyncDriver):
|
async def load_fact_embedding(self, driver: GraphDriver):
|
||||||
query: LiteralString = """
|
query: LiteralString = """
|
||||||
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
||||||
RETURN e.fact_embedding AS fact_embedding
|
RETURN e.fact_embedding AS fact_embedding
|
||||||
|
|
@ -240,7 +240,7 @@ class EntityEdge(Edge):
|
||||||
|
|
||||||
self.fact_embedding = records[0]['fact_embedding']
|
self.fact_embedding = records[0]['fact_embedding']
|
||||||
|
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: GraphDriver):
|
||||||
edge_data: dict[str, Any] = {
|
edge_data: dict[str, Any] = {
|
||||||
'source_uuid': self.source_node_uuid,
|
'source_uuid': self.source_node_uuid,
|
||||||
'target_uuid': self.target_node_uuid,
|
'target_uuid': self.target_node_uuid,
|
||||||
|
|
@ -264,12 +264,12 @@ class EntityEdge(Edge):
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
logger.debug(f'Saved edge to Graph: {self.uuid}')
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
||||||
|
|
@ -287,7 +287,7 @@ class EntityEdge(Edge):
|
||||||
return edges[0]
|
return edges[0]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
||||||
if len(uuids) == 0:
|
if len(uuids) == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -309,7 +309,7 @@ class EntityEdge(Edge):
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(
|
async def get_by_group_ids(
|
||||||
cls,
|
cls,
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
group_ids: list[str],
|
group_ids: list[str],
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
uuid_cursor: str | None = None,
|
uuid_cursor: str | None = None,
|
||||||
|
|
@ -342,11 +342,11 @@ class EntityEdge(Edge):
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
|
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
|
||||||
query: LiteralString = (
|
query: LiteralString = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
||||||
"""
|
"""
|
||||||
+ ENTITY_EDGE_RETURN
|
+ ENTITY_EDGE_RETURN
|
||||||
)
|
)
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -359,7 +359,7 @@ class EntityEdge(Edge):
|
||||||
|
|
||||||
|
|
||||||
class CommunityEdge(Edge):
|
class CommunityEdge(Edge):
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: GraphDriver):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
COMMUNITY_EDGE_SAVE,
|
COMMUNITY_EDGE_SAVE,
|
||||||
community_uuid=self.source_node_uuid,
|
community_uuid=self.source_node_uuid,
|
||||||
|
|
@ -370,12 +370,12 @@ class CommunityEdge(Edge):
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
logger.debug(f'Saved edge to Graph: {self.uuid}')
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
|
MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
|
||||||
|
|
@ -396,7 +396,7 @@ class CommunityEdge(Edge):
|
||||||
return edges[0]
|
return edges[0]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
|
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
|
||||||
|
|
@ -420,7 +420,7 @@ class CommunityEdge(Edge):
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(
|
async def get_by_group_ids(
|
||||||
cls,
|
cls,
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
group_ids: list[str],
|
group_ids: list[str],
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
uuid_cursor: str | None = None,
|
uuid_cursor: str | None = None,
|
||||||
|
|
@ -463,7 +463,7 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
|
||||||
group_id=record['group_id'],
|
group_id=record['group_id'],
|
||||||
source_node_uuid=record['source_node_uuid'],
|
source_node_uuid=record['source_node_uuid'],
|
||||||
target_node_uuid=record['target_node_uuid'],
|
target_node_uuid=record['target_node_uuid'],
|
||||||
created_at=record['created_at'].to_native(),
|
created_at=parse_db_date(record['created_at']),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -476,7 +476,7 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
||||||
name=record['name'],
|
name=record['name'],
|
||||||
group_id=record['group_id'],
|
group_id=record['group_id'],
|
||||||
episodes=record['episodes'],
|
episodes=record['episodes'],
|
||||||
created_at=record['created_at'].to_native(),
|
created_at=parse_db_date(record['created_at']),
|
||||||
expired_at=parse_db_date(record['expired_at']),
|
expired_at=parse_db_date(record['expired_at']),
|
||||||
valid_at=parse_db_date(record['valid_at']),
|
valid_at=parse_db_date(record['valid_at']),
|
||||||
invalid_at=parse_db_date(record['invalid_at']),
|
invalid_at=parse_db_date(record['invalid_at']),
|
||||||
|
|
@ -504,7 +504,7 @@ def get_community_edge_from_record(record: Any):
|
||||||
group_id=record['group_id'],
|
group_id=record['group_id'],
|
||||||
source_node_uuid=record['source_node_uuid'],
|
source_node_uuid=record['source_node_uuid'],
|
||||||
target_node_uuid=record['target_node_uuid'],
|
target_node_uuid=record['target_node_uuid'],
|
||||||
created_at=record['created_at'].to_native(),
|
created_at=parse_db_date(record['created_at']),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
147
graphiti_core/graph_queries.py
Normal file
147
graphiti_core/graph_queries.py
Normal file
|
|
@ -0,0 +1,147 @@
|
||||||
|
"""
|
||||||
|
Database query utilities for different graph database backends.
|
||||||
|
|
||||||
|
This module provides database-agnostic query generation for Neo4j and FalkorDB,
|
||||||
|
supporting index creation, fulltext search, and bulk operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
|
from graphiti_core.models.edges.edge_db_queries import (
|
||||||
|
ENTITY_EDGE_SAVE_BULK,
|
||||||
|
)
|
||||||
|
from graphiti_core.models.nodes.node_db_queries import (
|
||||||
|
ENTITY_NODE_SAVE_BULK,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mapping from Neo4j fulltext index names to FalkorDB node labels
|
||||||
|
NEO4J_TO_FALKORDB_MAPPING = {
|
||||||
|
'node_name_and_summary': 'Entity',
|
||||||
|
'community_name': 'Community',
|
||||||
|
'episode_content': 'Episodic',
|
||||||
|
'edge_name_and_fact': 'RELATES_TO',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]:
|
||||||
|
if db_type == 'falkordb':
|
||||||
|
return [
|
||||||
|
# Entity node
|
||||||
|
'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)',
|
||||||
|
# Episodic node
|
||||||
|
'CREATE INDEX FOR (n:Episodic) ON (n.uuid, n.group_id, n.created_at, n.valid_at)',
|
||||||
|
# Community node
|
||||||
|
'CREATE INDEX FOR (n:Community) ON (n.uuid)',
|
||||||
|
# RELATES_TO edge
|
||||||
|
'CREATE INDEX FOR ()-[e:RELATES_TO]-() ON (e.uuid, e.group_id, e.name, e.created_at, e.expired_at, e.valid_at, e.invalid_at)',
|
||||||
|
# MENTIONS edge
|
||||||
|
'CREATE INDEX FOR ()-[e:MENTIONS]-() ON (e.uuid, e.group_id)',
|
||||||
|
# HAS_MEMBER edge
|
||||||
|
'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return [
|
||||||
|
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
|
||||||
|
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
|
||||||
|
'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
|
||||||
|
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
|
||||||
|
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
|
||||||
|
'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
||||||
|
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
|
||||||
|
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
|
||||||
|
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
|
||||||
|
'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
|
||||||
|
'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
|
||||||
|
'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
|
||||||
|
'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
|
||||||
|
'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
|
||||||
|
'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
|
||||||
|
'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
|
||||||
|
'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
|
||||||
|
'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
|
||||||
|
'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_fulltext_indices(db_type: str = 'neo4j') -> list[LiteralString]:
|
||||||
|
if db_type == 'falkordb':
|
||||||
|
return [
|
||||||
|
"""CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""",
|
||||||
|
"""CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""",
|
||||||
|
"""CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""",
|
||||||
|
"""CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return [
|
||||||
|
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
|
||||||
|
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
|
||||||
|
"""CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
|
||||||
|
FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
|
||||||
|
"""CREATE FULLTEXT INDEX community_name IF NOT EXISTS
|
||||||
|
FOR (n:Community) ON EACH [n.name, n.group_id]""",
|
||||||
|
"""CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
|
||||||
|
FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_nodes_query(db_type: str = 'neo4j', name: str = None, query: str = None) -> str:
|
||||||
|
if db_type == 'falkordb':
|
||||||
|
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
||||||
|
return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
|
||||||
|
else:
|
||||||
|
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
|
||||||
|
|
||||||
|
|
||||||
|
def get_vector_cosine_func_query(vec1, vec2, db_type: str = 'neo4j') -> str:
|
||||||
|
if db_type == 'falkordb':
|
||||||
|
# FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity
|
||||||
|
return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
|
||||||
|
else:
|
||||||
|
return f'vector.similarity.cosine({vec1}, {vec2})'
|
||||||
|
|
||||||
|
|
||||||
|
def get_relationships_query(db_type: str = 'neo4j', name: str = None, query: str = None) -> str:
|
||||||
|
if db_type == 'falkordb':
|
||||||
|
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
||||||
|
return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
|
||||||
|
else:
|
||||||
|
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
|
||||||
|
|
||||||
|
|
||||||
|
def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str:
|
||||||
|
if db_type == 'falkordb':
|
||||||
|
queries = []
|
||||||
|
for node in nodes:
|
||||||
|
for label in node['labels']:
|
||||||
|
queries.append(
|
||||||
|
(
|
||||||
|
f"""
|
||||||
|
UNWIND $nodes AS node
|
||||||
|
MERGE (n:Entity {{uuid: node.uuid}})
|
||||||
|
SET n:{label}
|
||||||
|
SET n = node
|
||||||
|
WITH n, node
|
||||||
|
SET n.name_embedding = vecf32(node.name_embedding)
|
||||||
|
RETURN n.uuid AS uuid
|
||||||
|
""",
|
||||||
|
{'nodes': [node]},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return queries
|
||||||
|
else:
|
||||||
|
return ENTITY_NODE_SAVE_BULK
|
||||||
|
|
||||||
|
|
||||||
|
def get_entity_edge_save_bulk_query(db_type: str = 'neo4j') -> str:
|
||||||
|
if db_type == 'falkordb':
|
||||||
|
return """
|
||||||
|
UNWIND $entity_edges AS edge
|
||||||
|
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
||||||
|
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
||||||
|
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
|
||||||
|
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
|
||||||
|
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
|
||||||
|
WITH r, edge
|
||||||
|
RETURN edge.uuid AS uuid"""
|
||||||
|
else:
|
||||||
|
return ENTITY_EDGE_SAVE_BULK
|
||||||
|
|
@ -19,12 +19,13 @@ from datetime import datetime
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from neo4j import AsyncGraphDatabase
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
||||||
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
||||||
|
from graphiti_core.driver.driver import GraphDriver
|
||||||
|
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
||||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||||
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
||||||
from graphiti_core.graphiti_types import GraphitiClients
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
|
|
@ -94,12 +95,13 @@ class Graphiti:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
uri: str,
|
uri: str,
|
||||||
user: str,
|
user: str = None,
|
||||||
password: str,
|
password: str = None,
|
||||||
llm_client: LLMClient | None = None,
|
llm_client: LLMClient | None = None,
|
||||||
embedder: EmbedderClient | None = None,
|
embedder: EmbedderClient | None = None,
|
||||||
cross_encoder: CrossEncoderClient | None = None,
|
cross_encoder: CrossEncoderClient | None = None,
|
||||||
store_raw_episode_content: bool = True,
|
store_raw_episode_content: bool = True,
|
||||||
|
graph_driver: GraphDriver = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize a Graphiti instance.
|
Initialize a Graphiti instance.
|
||||||
|
|
@ -137,7 +139,9 @@ class Graphiti:
|
||||||
Make sure to set the OPENAI_API_KEY environment variable before initializing
|
Make sure to set the OPENAI_API_KEY environment variable before initializing
|
||||||
Graphiti if you're using the default OpenAIClient.
|
Graphiti if you're using the default OpenAIClient.
|
||||||
"""
|
"""
|
||||||
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
|
|
||||||
|
self.driver = graph_driver if graph_driver else Neo4jDriver(uri, user, password)
|
||||||
|
|
||||||
self.database = DEFAULT_DATABASE
|
self.database = DEFAULT_DATABASE
|
||||||
self.store_raw_episode_content = store_raw_episode_content
|
self.store_raw_episode_content = store_raw_episode_content
|
||||||
if llm_client:
|
if llm_client:
|
||||||
|
|
|
||||||
|
|
@ -14,16 +14,16 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from graphiti_core.cross_encoder import CrossEncoderClient
|
from graphiti_core.cross_encoder import CrossEncoderClient
|
||||||
|
from graphiti_core.driver.driver import GraphDriver
|
||||||
from graphiti_core.embedder import EmbedderClient
|
from graphiti_core.embedder import EmbedderClient
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
|
|
||||||
|
|
||||||
class GraphitiClients(BaseModel):
|
class GraphitiClients(BaseModel):
|
||||||
driver: AsyncDriver
|
driver: GraphDriver
|
||||||
llm_client: LLMClient
|
llm_client: LLMClient
|
||||||
embedder: EmbedderClient
|
embedder: EmbedderClient
|
||||||
cross_encoder: CrossEncoderClient
|
cross_encoder: CrossEncoderClient
|
||||||
|
|
|
||||||
|
|
@ -38,8 +38,14 @@ RUNTIME_QUERY: LiteralString = (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None:
|
||||||
return neo_date.to_native() if neo_date else None
|
return (
|
||||||
|
neo_date.to_native()
|
||||||
|
if isinstance(neo_date, neo4j_time.DateTime)
|
||||||
|
else datetime.fromisoformat(neo_date)
|
||||||
|
if neo_date
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def lucene_sanitize(query: str) -> str:
|
def lucene_sanitize(query: str) -> str:
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,19 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
from .client import LLMClient
|
from .client import LLMClient
|
||||||
from .config import LLMConfig
|
from .config import LLMConfig
|
||||||
from .errors import RateLimitError
|
from .errors import RateLimitError
|
||||||
|
|
|
||||||
|
|
@ -22,13 +22,13 @@ from time import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
|
from graphiti_core.driver.driver import GraphDriver
|
||||||
from graphiti_core.embedder import EmbedderClient
|
from graphiti_core.embedder import EmbedderClient
|
||||||
from graphiti_core.errors import NodeNotFoundError
|
from graphiti_core.errors import NodeNotFoundError
|
||||||
from graphiti_core.helpers import DEFAULT_DATABASE
|
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
|
||||||
from graphiti_core.models.nodes.node_db_queries import (
|
from graphiti_core.models.nodes.node_db_queries import (
|
||||||
COMMUNITY_NODE_SAVE,
|
COMMUNITY_NODE_SAVE,
|
||||||
ENTITY_NODE_SAVE,
|
ENTITY_NODE_SAVE,
|
||||||
|
|
@ -94,9 +94,9 @@ class Node(BaseModel, ABC):
|
||||||
created_at: datetime = Field(default_factory=lambda: utc_now())
|
created_at: datetime = Field(default_factory=lambda: utc_now())
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def save(self, driver: AsyncDriver): ...
|
async def save(self, driver: GraphDriver): ...
|
||||||
|
|
||||||
async def delete(self, driver: AsyncDriver):
|
async def delete(self, driver: GraphDriver):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
|
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
|
||||||
|
|
@ -119,7 +119,7 @@ class Node(BaseModel, ABC):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def delete_by_group_id(cls, driver: AsyncDriver, group_id: str):
|
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str):
|
||||||
await driver.execute_query(
|
await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
|
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
|
||||||
|
|
@ -132,10 +132,10 @@ class Node(BaseModel, ABC):
|
||||||
return 'SUCCESS'
|
return 'SUCCESS'
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): ...
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ...
|
||||||
|
|
||||||
|
|
||||||
class EpisodicNode(Node):
|
class EpisodicNode(Node):
|
||||||
|
|
@ -150,7 +150,7 @@ class EpisodicNode(Node):
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: GraphDriver):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
EPISODIC_NODE_SAVE,
|
EPISODIC_NODE_SAVE,
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
|
|
@ -165,12 +165,12 @@ class EpisodicNode(Node):
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (e:Episodic {uuid: $uuid})
|
MATCH (e:Episodic {uuid: $uuid})
|
||||||
|
|
@ -197,7 +197,7 @@ class EpisodicNode(Node):
|
||||||
return episodes[0]
|
return episodes[0]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (e:Episodic) WHERE e.uuid IN $uuids
|
MATCH (e:Episodic) WHERE e.uuid IN $uuids
|
||||||
|
|
@ -224,7 +224,7 @@ class EpisodicNode(Node):
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(
|
async def get_by_group_ids(
|
||||||
cls,
|
cls,
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
group_ids: list[str],
|
group_ids: list[str],
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
uuid_cursor: str | None = None,
|
uuid_cursor: str | None = None,
|
||||||
|
|
@ -263,7 +263,7 @@ class EpisodicNode(Node):
|
||||||
return episodes
|
return episodes
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_entity_node_uuid(cls, driver: AsyncDriver, entity_node_uuid: str):
|
async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
|
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
|
||||||
|
|
@ -304,7 +304,7 @@ class EntityNode(Node):
|
||||||
|
|
||||||
return self.name_embedding
|
return self.name_embedding
|
||||||
|
|
||||||
async def load_name_embedding(self, driver: AsyncDriver):
|
async def load_name_embedding(self, driver: GraphDriver):
|
||||||
query: LiteralString = """
|
query: LiteralString = """
|
||||||
MATCH (n:Entity {uuid: $uuid})
|
MATCH (n:Entity {uuid: $uuid})
|
||||||
RETURN n.name_embedding AS name_embedding
|
RETURN n.name_embedding AS name_embedding
|
||||||
|
|
@ -318,7 +318,7 @@ class EntityNode(Node):
|
||||||
|
|
||||||
self.name_embedding = records[0]['name_embedding']
|
self.name_embedding = records[0]['name_embedding']
|
||||||
|
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: GraphDriver):
|
||||||
entity_data: dict[str, Any] = {
|
entity_data: dict[str, Any] = {
|
||||||
'uuid': self.uuid,
|
'uuid': self.uuid,
|
||||||
'name': self.name,
|
'name': self.name,
|
||||||
|
|
@ -337,16 +337,16 @@ class EntityNode(Node):
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity {uuid: $uuid})
|
MATCH (n:Entity {uuid: $uuid})
|
||||||
"""
|
"""
|
||||||
+ ENTITY_NODE_RETURN
|
+ ENTITY_NODE_RETURN
|
||||||
)
|
)
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -364,7 +364,7 @@ class EntityNode(Node):
|
||||||
return nodes[0]
|
return nodes[0]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity) WHERE n.uuid IN $uuids
|
MATCH (n:Entity) WHERE n.uuid IN $uuids
|
||||||
|
|
@ -382,7 +382,7 @@ class EntityNode(Node):
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(
|
async def get_by_group_ids(
|
||||||
cls,
|
cls,
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
group_ids: list[str],
|
group_ids: list[str],
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
uuid_cursor: str | None = None,
|
uuid_cursor: str | None = None,
|
||||||
|
|
@ -416,7 +416,7 @@ class CommunityNode(Node):
|
||||||
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
||||||
summary: str = Field(description='region summary of member nodes', default_factory=str)
|
summary: str = Field(description='region summary of member nodes', default_factory=str)
|
||||||
|
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: GraphDriver):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
COMMUNITY_NODE_SAVE,
|
COMMUNITY_NODE_SAVE,
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
|
|
@ -428,7 +428,7 @@ class CommunityNode(Node):
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -441,7 +441,7 @@ class CommunityNode(Node):
|
||||||
|
|
||||||
return self.name_embedding
|
return self.name_embedding
|
||||||
|
|
||||||
async def load_name_embedding(self, driver: AsyncDriver):
|
async def load_name_embedding(self, driver: GraphDriver):
|
||||||
query: LiteralString = """
|
query: LiteralString = """
|
||||||
MATCH (c:Community {uuid: $uuid})
|
MATCH (c:Community {uuid: $uuid})
|
||||||
RETURN c.name_embedding AS name_embedding
|
RETURN c.name_embedding AS name_embedding
|
||||||
|
|
@ -456,7 +456,7 @@ class CommunityNode(Node):
|
||||||
self.name_embedding = records[0]['name_embedding']
|
self.name_embedding = records[0]['name_embedding']
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Community {uuid: $uuid})
|
MATCH (n:Community {uuid: $uuid})
|
||||||
|
|
@ -480,7 +480,7 @@ class CommunityNode(Node):
|
||||||
return nodes[0]
|
return nodes[0]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Community) WHERE n.uuid IN $uuids
|
MATCH (n:Community) WHERE n.uuid IN $uuids
|
||||||
|
|
@ -503,7 +503,7 @@ class CommunityNode(Node):
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_group_ids(
|
async def get_by_group_ids(
|
||||||
cls,
|
cls,
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
group_ids: list[str],
|
group_ids: list[str],
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
uuid_cursor: str | None = None,
|
uuid_cursor: str | None = None,
|
||||||
|
|
@ -542,8 +542,8 @@ class CommunityNode(Node):
|
||||||
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
||||||
return EpisodicNode(
|
return EpisodicNode(
|
||||||
content=record['content'],
|
content=record['content'],
|
||||||
created_at=record['created_at'].to_native().timestamp(),
|
created_at=parse_db_date(record['created_at']).timestamp(),
|
||||||
valid_at=(record['valid_at'].to_native()),
|
valid_at=(parse_db_date(record['valid_at'])),
|
||||||
uuid=record['uuid'],
|
uuid=record['uuid'],
|
||||||
group_id=record['group_id'],
|
group_id=record['group_id'],
|
||||||
source=EpisodeType.from_str(record['source']),
|
source=EpisodeType.from_str(record['source']),
|
||||||
|
|
@ -559,7 +559,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
|
||||||
name=record['name'],
|
name=record['name'],
|
||||||
group_id=record['group_id'],
|
group_id=record['group_id'],
|
||||||
labels=record['labels'],
|
labels=record['labels'],
|
||||||
created_at=record['created_at'].to_native(),
|
created_at=parse_db_date(record['created_at']),
|
||||||
summary=record['summary'],
|
summary=record['summary'],
|
||||||
attributes=record['attributes'],
|
attributes=record['attributes'],
|
||||||
)
|
)
|
||||||
|
|
@ -580,7 +580,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
|
||||||
name=record['name'],
|
name=record['name'],
|
||||||
group_id=record['group_id'],
|
group_id=record['group_id'],
|
||||||
name_embedding=record['name_embedding'],
|
name_embedding=record['name_embedding'],
|
||||||
created_at=record['created_at'].to_native(),
|
created_at=parse_db_date(record['created_at']),
|
||||||
summary=record['summary'],
|
summary=record['summary'],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,9 +18,8 @@ import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
|
||||||
|
|
||||||
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
||||||
|
from graphiti_core.driver.driver import GraphDriver
|
||||||
from graphiti_core.edges import EntityEdge
|
from graphiti_core.edges import EntityEdge
|
||||||
from graphiti_core.errors import SearchRerankerError
|
from graphiti_core.errors import SearchRerankerError
|
||||||
from graphiti_core.graphiti_types import GraphitiClients
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
|
|
@ -94,7 +93,7 @@ async def search(
|
||||||
)
|
)
|
||||||
|
|
||||||
# if group_ids is empty, set it to None
|
# if group_ids is empty, set it to None
|
||||||
group_ids = group_ids if group_ids else None
|
group_ids = group_ids if group_ids and group_ids != [''] else None
|
||||||
edges, nodes, episodes, communities = await semaphore_gather(
|
edges, nodes, episodes, communities = await semaphore_gather(
|
||||||
edge_search(
|
edge_search(
|
||||||
driver,
|
driver,
|
||||||
|
|
@ -160,7 +159,7 @@ async def search(
|
||||||
|
|
||||||
|
|
||||||
async def edge_search(
|
async def edge_search(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
cross_encoder: CrossEncoderClient,
|
cross_encoder: CrossEncoderClient,
|
||||||
query: str,
|
query: str,
|
||||||
query_vector: list[float],
|
query_vector: list[float],
|
||||||
|
|
@ -174,7 +173,6 @@ async def edge_search(
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
if config is None:
|
if config is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
search_results: list[list[EntityEdge]] = list(
|
search_results: list[list[EntityEdge]] = list(
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
*[
|
*[
|
||||||
|
|
@ -261,7 +259,7 @@ async def edge_search(
|
||||||
|
|
||||||
|
|
||||||
async def node_search(
|
async def node_search(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
cross_encoder: CrossEncoderClient,
|
cross_encoder: CrossEncoderClient,
|
||||||
query: str,
|
query: str,
|
||||||
query_vector: list[float],
|
query_vector: list[float],
|
||||||
|
|
@ -275,7 +273,6 @@ async def node_search(
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
if config is None:
|
if config is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
search_results: list[list[EntityNode]] = list(
|
search_results: list[list[EntityNode]] = list(
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
*[
|
*[
|
||||||
|
|
@ -344,7 +341,7 @@ async def node_search(
|
||||||
|
|
||||||
|
|
||||||
async def episode_search(
|
async def episode_search(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
cross_encoder: CrossEncoderClient,
|
cross_encoder: CrossEncoderClient,
|
||||||
query: str,
|
query: str,
|
||||||
_query_vector: list[float],
|
_query_vector: list[float],
|
||||||
|
|
@ -356,7 +353,6 @@ async def episode_search(
|
||||||
) -> list[EpisodicNode]:
|
) -> list[EpisodicNode]:
|
||||||
if config is None:
|
if config is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
search_results: list[list[EpisodicNode]] = list(
|
search_results: list[list[EpisodicNode]] = list(
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
*[
|
*[
|
||||||
|
|
@ -392,7 +388,7 @@ async def episode_search(
|
||||||
|
|
||||||
|
|
||||||
async def community_search(
|
async def community_search(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
cross_encoder: CrossEncoderClient,
|
cross_encoder: CrossEncoderClient,
|
||||||
query: str,
|
query: str,
|
||||||
query_vector: list[float],
|
query_vector: list[float],
|
||||||
|
|
|
||||||
|
|
@ -20,11 +20,16 @@ from time import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from neo4j import AsyncDriver, Query
|
|
||||||
from numpy._typing import NDArray
|
from numpy._typing import NDArray
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
|
from graphiti_core.driver.driver import GraphDriver
|
||||||
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
||||||
|
from graphiti_core.graph_queries import (
|
||||||
|
get_nodes_query,
|
||||||
|
get_relationships_query,
|
||||||
|
get_vector_cosine_func_query,
|
||||||
|
)
|
||||||
from graphiti_core.helpers import (
|
from graphiti_core.helpers import (
|
||||||
DEFAULT_DATABASE,
|
DEFAULT_DATABASE,
|
||||||
RUNTIME_QUERY,
|
RUNTIME_QUERY,
|
||||||
|
|
@ -58,7 +63,7 @@ MAX_QUERY_LENGTH = 32
|
||||||
|
|
||||||
def fulltext_query(query: str, group_ids: list[str] | None = None):
|
def fulltext_query(query: str, group_ids: list[str] | None = None):
|
||||||
group_ids_filter_list = (
|
group_ids_filter_list = (
|
||||||
[f'group_id:"{lucene_sanitize(g)}"' for g in group_ids] if group_ids is not None else []
|
[f"group_id-'{lucene_sanitize(g)}'" for g in group_ids] if group_ids is not None else []
|
||||||
)
|
)
|
||||||
group_ids_filter = ''
|
group_ids_filter = ''
|
||||||
for f in group_ids_filter_list:
|
for f in group_ids_filter_list:
|
||||||
|
|
@ -77,7 +82,7 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
|
||||||
|
|
||||||
|
|
||||||
async def get_episodes_by_mentions(
|
async def get_episodes_by_mentions(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
nodes: list[EntityNode],
|
nodes: list[EntityNode],
|
||||||
edges: list[EntityEdge],
|
edges: list[EntityEdge],
|
||||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
|
|
@ -92,11 +97,11 @@ async def get_episodes_by_mentions(
|
||||||
|
|
||||||
|
|
||||||
async def get_mentioned_nodes(
|
async def get_mentioned_nodes(
|
||||||
driver: AsyncDriver, episodes: list[EpisodicNode]
|
driver: GraphDriver, episodes: list[EpisodicNode]
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
episode_uuids = [episode.uuid for episode in episodes]
|
episode_uuids = [episode.uuid for episode in episodes]
|
||||||
records, _, _ = await driver.execute_query(
|
|
||||||
"""
|
query = """
|
||||||
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
|
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
|
||||||
RETURN DISTINCT
|
RETURN DISTINCT
|
||||||
n.uuid As uuid,
|
n.uuid As uuid,
|
||||||
|
|
@ -106,7 +111,10 @@ async def get_mentioned_nodes(
|
||||||
n.summary AS summary,
|
n.summary AS summary,
|
||||||
labels(n) AS labels,
|
labels(n) AS labels,
|
||||||
properties(n) AS attributes
|
properties(n) AS attributes
|
||||||
""",
|
"""
|
||||||
|
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
query,
|
||||||
uuids=episode_uuids,
|
uuids=episode_uuids,
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
|
|
@ -118,11 +126,11 @@ async def get_mentioned_nodes(
|
||||||
|
|
||||||
|
|
||||||
async def get_communities_by_nodes(
|
async def get_communities_by_nodes(
|
||||||
driver: AsyncDriver, nodes: list[EntityNode]
|
driver: GraphDriver, nodes: list[EntityNode]
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
node_uuids = [node.uuid for node in nodes]
|
node_uuids = [node.uuid for node in nodes]
|
||||||
records, _, _ = await driver.execute_query(
|
|
||||||
"""
|
query = """
|
||||||
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
|
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
|
||||||
RETURN DISTINCT
|
RETURN DISTINCT
|
||||||
c.uuid As uuid,
|
c.uuid As uuid,
|
||||||
|
|
@ -130,7 +138,10 @@ async def get_communities_by_nodes(
|
||||||
c.name AS name,
|
c.name AS name,
|
||||||
c.created_at AS created_at,
|
c.created_at AS created_at,
|
||||||
c.summary AS summary
|
c.summary AS summary
|
||||||
""",
|
"""
|
||||||
|
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
query,
|
||||||
uuids=node_uuids,
|
uuids=node_uuids,
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
|
|
@ -142,7 +153,7 @@ async def get_communities_by_nodes(
|
||||||
|
|
||||||
|
|
||||||
async def edge_fulltext_search(
|
async def edge_fulltext_search(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
query: str,
|
query: str,
|
||||||
search_filter: SearchFilters,
|
search_filter: SearchFilters,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
|
|
@ -155,34 +166,35 @@ async def edge_fulltext_search(
|
||||||
|
|
||||||
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
||||||
|
|
||||||
cypher_query = Query(
|
query = (
|
||||||
"""
|
get_relationships_query(driver.provider, 'edge_name_and_fact', '$query')
|
||||||
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query, {limit: $limit})
|
+ """
|
||||||
YIELD relationship AS rel, score
|
YIELD relationship AS rel, score
|
||||||
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
||||||
WHERE r.group_id IN $group_ids"""
|
WHERE r.group_id IN $group_ids """
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """\nWITH r, score, startNode(r) AS n, endNode(r) AS m
|
+ """
|
||||||
RETURN
|
WITH r, score, startNode(r) AS n, endNode(r) AS m
|
||||||
r.uuid AS uuid,
|
RETURN
|
||||||
r.group_id AS group_id,
|
r.uuid AS uuid,
|
||||||
n.uuid AS source_node_uuid,
|
r.group_id AS group_id,
|
||||||
m.uuid AS target_node_uuid,
|
n.uuid AS source_node_uuid,
|
||||||
r.created_at AS created_at,
|
m.uuid AS target_node_uuid,
|
||||||
r.name AS name,
|
r.created_at AS created_at,
|
||||||
r.fact AS fact,
|
r.name AS name,
|
||||||
r.episodes AS episodes,
|
r.fact AS fact,
|
||||||
r.expired_at AS expired_at,
|
r.episodes AS episodes,
|
||||||
r.valid_at AS valid_at,
|
r.expired_at AS expired_at,
|
||||||
r.invalid_at AS invalid_at,
|
r.valid_at AS valid_at,
|
||||||
properties(r) AS attributes
|
r.invalid_at AS invalid_at,
|
||||||
ORDER BY score DESC LIMIT $limit
|
properties(r) AS attributes
|
||||||
"""
|
ORDER BY score DESC LIMIT $limit
|
||||||
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
cypher_query,
|
query,
|
||||||
filter_params,
|
params=filter_params,
|
||||||
query=fuzzy_query,
|
query=fuzzy_query,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
|
@ -196,7 +208,7 @@ async def edge_fulltext_search(
|
||||||
|
|
||||||
|
|
||||||
async def edge_similarity_search(
|
async def edge_similarity_search(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
search_vector: list[float],
|
search_vector: list[float],
|
||||||
source_node_uuid: str | None,
|
source_node_uuid: str | None,
|
||||||
target_node_uuid: str | None,
|
target_node_uuid: str | None,
|
||||||
|
|
@ -224,36 +236,38 @@ async def edge_similarity_search(
|
||||||
if target_node_uuid is not None:
|
if target_node_uuid is not None:
|
||||||
group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
|
group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
|
||||||
|
|
||||||
query: LiteralString = (
|
query = (
|
||||||
RUNTIME_QUERY
|
RUNTIME_QUERY
|
||||||
+ """
|
+ """
|
||||||
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
||||||
"""
|
"""
|
||||||
+ group_filter_query
|
+ group_filter_query
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
+ """
|
||||||
WHERE score > $min_score
|
WITH DISTINCT r, """
|
||||||
RETURN
|
+ get_vector_cosine_func_query('r.fact_embedding', '$search_vector', driver.provider)
|
||||||
r.uuid AS uuid,
|
+ """ AS score
|
||||||
r.group_id AS group_id,
|
WHERE score > $min_score
|
||||||
startNode(r).uuid AS source_node_uuid,
|
RETURN
|
||||||
endNode(r).uuid AS target_node_uuid,
|
r.uuid AS uuid,
|
||||||
r.created_at AS created_at,
|
r.group_id AS group_id,
|
||||||
r.name AS name,
|
startNode(r).uuid AS source_node_uuid,
|
||||||
r.fact AS fact,
|
endNode(r).uuid AS target_node_uuid,
|
||||||
r.episodes AS episodes,
|
r.created_at AS created_at,
|
||||||
r.expired_at AS expired_at,
|
r.name AS name,
|
||||||
r.valid_at AS valid_at,
|
r.fact AS fact,
|
||||||
r.invalid_at AS invalid_at,
|
r.episodes AS episodes,
|
||||||
properties(r) AS attributes
|
r.expired_at AS expired_at,
|
||||||
ORDER BY score DESC
|
r.valid_at AS valid_at,
|
||||||
LIMIT $limit
|
r.invalid_at AS invalid_at,
|
||||||
|
properties(r) AS attributes
|
||||||
|
ORDER BY score DESC
|
||||||
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
records, header, _ = await driver.execute_query(
|
||||||
records, _, _ = await driver.execute_query(
|
|
||||||
query,
|
query,
|
||||||
query_params,
|
params=query_params,
|
||||||
search_vector=search_vector,
|
search_vector=search_vector,
|
||||||
source_uuid=source_node_uuid,
|
source_uuid=source_node_uuid,
|
||||||
target_uuid=target_node_uuid,
|
target_uuid=target_node_uuid,
|
||||||
|
|
@ -264,13 +278,16 @@ async def edge_similarity_search(
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if driver.provider == 'falkordb':
|
||||||
|
records = [dict(zip(header, row, strict=True)) for row in records]
|
||||||
|
|
||||||
edges = [get_entity_edge_from_record(record) for record in records]
|
edges = [get_entity_edge_from_record(record) for record in records]
|
||||||
|
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
|
|
||||||
async def edge_bfs_search(
|
async def edge_bfs_search(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
bfs_origin_node_uuids: list[str] | None,
|
bfs_origin_node_uuids: list[str] | None,
|
||||||
bfs_max_depth: int,
|
bfs_max_depth: int,
|
||||||
search_filter: SearchFilters,
|
search_filter: SearchFilters,
|
||||||
|
|
@ -282,14 +299,14 @@ async def edge_bfs_search(
|
||||||
|
|
||||||
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
||||||
|
|
||||||
query = Query(
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
||||||
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
||||||
UNWIND relationships(path) AS rel
|
UNWIND relationships(path) AS rel
|
||||||
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
|
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
|
||||||
WHERE r.uuid = rel.uuid
|
WHERE r.uuid = rel.uuid
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT
|
RETURN DISTINCT
|
||||||
|
|
@ -311,7 +328,7 @@ async def edge_bfs_search(
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
query,
|
query,
|
||||||
filter_params,
|
params=filter_params,
|
||||||
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
||||||
depth=bfs_max_depth,
|
depth=bfs_max_depth,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
|
@ -325,7 +342,7 @@ async def edge_bfs_search(
|
||||||
|
|
||||||
|
|
||||||
async def node_fulltext_search(
|
async def node_fulltext_search(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
query: str,
|
query: str,
|
||||||
search_filter: SearchFilters,
|
search_filter: SearchFilters,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
|
|
@ -335,38 +352,41 @@ async def node_fulltext_search(
|
||||||
fuzzy_query = fulltext_query(query, group_ids)
|
fuzzy_query = fulltext_query(query, group_ids)
|
||||||
if fuzzy_query == '':
|
if fuzzy_query == '':
|
||||||
return []
|
return []
|
||||||
|
|
||||||
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
|
get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
|
||||||
|
+ """
|
||||||
|
YIELD node AS n, score
|
||||||
|
WITH n, score
|
||||||
|
LIMIT $limit
|
||||||
|
WHERE n:Entity
|
||||||
"""
|
"""
|
||||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
|
||||||
YIELD node AS n, score
|
|
||||||
WHERE n:Entity
|
|
||||||
"""
|
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ ENTITY_NODE_RETURN
|
+ ENTITY_NODE_RETURN
|
||||||
+ """
|
+ """
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
records, header, _ = await driver.execute_query(
|
||||||
records, _, _ = await driver.execute_query(
|
|
||||||
query,
|
query,
|
||||||
filter_params,
|
params=filter_params,
|
||||||
query=fuzzy_query,
|
query=fuzzy_query,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
if driver.provider == 'falkordb':
|
||||||
|
records = [dict(zip(header, row, strict=True)) for row in records]
|
||||||
|
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
|
|
||||||
async def node_similarity_search(
|
async def node_similarity_search(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
search_vector: list[float],
|
search_vector: list[float],
|
||||||
search_filter: SearchFilters,
|
search_filter: SearchFilters,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
|
|
@ -384,22 +404,28 @@ async def node_similarity_search(
|
||||||
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
||||||
query_params.update(filter_params)
|
query_params.update(filter_params)
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
query = (
|
||||||
RUNTIME_QUERY
|
RUNTIME_QUERY
|
||||||
+ """
|
+ """
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
"""
|
"""
|
||||||
+ group_filter_query
|
+ group_filter_query
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
|
WITH n, """
|
||||||
WHERE score > $min_score"""
|
+ get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
|
||||||
|
+ """ AS score
|
||||||
|
WHERE score > $min_score"""
|
||||||
+ ENTITY_NODE_RETURN
|
+ ENTITY_NODE_RETURN
|
||||||
+ """
|
+ """
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
""",
|
"""
|
||||||
query_params,
|
)
|
||||||
|
|
||||||
|
records, header, _ = await driver.execute_query(
|
||||||
|
query,
|
||||||
|
params=query_params,
|
||||||
search_vector=search_vector,
|
search_vector=search_vector,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
|
@ -407,13 +433,15 @@ async def node_similarity_search(
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
if driver.provider == 'falkordb':
|
||||||
|
records = [dict(zip(header, row, strict=True)) for row in records]
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
|
|
||||||
async def node_bfs_search(
|
async def node_bfs_search(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
bfs_origin_node_uuids: list[str] | None,
|
bfs_origin_node_uuids: list[str] | None,
|
||||||
search_filter: SearchFilters,
|
search_filter: SearchFilters,
|
||||||
bfs_max_depth: int,
|
bfs_max_depth: int,
|
||||||
|
|
@ -425,18 +453,21 @@ async def node_bfs_search(
|
||||||
|
|
||||||
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
query = (
|
||||||
"""
|
"""
|
||||||
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
||||||
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
||||||
WHERE n.group_id = origin.group_id
|
WHERE n.group_id = origin.group_id
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ ENTITY_NODE_RETURN
|
+ ENTITY_NODE_RETURN
|
||||||
+ """
|
+ """
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
""",
|
"""
|
||||||
filter_params,
|
)
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
query,
|
||||||
|
params=filter_params,
|
||||||
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
||||||
depth=bfs_max_depth,
|
depth=bfs_max_depth,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
|
@ -449,7 +480,7 @@ async def node_bfs_search(
|
||||||
|
|
||||||
|
|
||||||
async def episode_fulltext_search(
|
async def episode_fulltext_search(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
query: str,
|
query: str,
|
||||||
_search_filter: SearchFilters,
|
_search_filter: SearchFilters,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
|
|
@ -460,9 +491,9 @@ async def episode_fulltext_search(
|
||||||
if fuzzy_query == '':
|
if fuzzy_query == '':
|
||||||
return []
|
return []
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
query = (
|
||||||
"""
|
get_nodes_query(driver.provider, 'episode_content', '$query')
|
||||||
CALL db.index.fulltext.queryNodes("episode_content", $query, {limit: $limit})
|
+ """
|
||||||
YIELD node AS episode, score
|
YIELD node AS episode, score
|
||||||
MATCH (e:Episodic)
|
MATCH (e:Episodic)
|
||||||
WHERE e.uuid = episode.uuid
|
WHERE e.uuid = episode.uuid
|
||||||
|
|
@ -478,7 +509,11 @@ async def episode_fulltext_search(
|
||||||
e.entity_edges AS entity_edges
|
e.entity_edges AS entity_edges
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
""",
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
query,
|
||||||
query=fuzzy_query,
|
query=fuzzy_query,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
|
@ -491,7 +526,7 @@ async def episode_fulltext_search(
|
||||||
|
|
||||||
|
|
||||||
async def community_fulltext_search(
|
async def community_fulltext_search(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
query: str,
|
query: str,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
|
|
@ -501,9 +536,9 @@ async def community_fulltext_search(
|
||||||
if fuzzy_query == '':
|
if fuzzy_query == '':
|
||||||
return []
|
return []
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
query = (
|
||||||
"""
|
get_nodes_query(driver.provider, 'community_name', '$query')
|
||||||
CALL db.index.fulltext.queryNodes("community_name", $query, {limit: $limit})
|
+ """
|
||||||
YIELD node AS comm, score
|
YIELD node AS comm, score
|
||||||
RETURN
|
RETURN
|
||||||
comm.uuid AS uuid,
|
comm.uuid AS uuid,
|
||||||
|
|
@ -513,7 +548,11 @@ async def community_fulltext_search(
|
||||||
comm.summary AS summary
|
comm.summary AS summary
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
""",
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
query,
|
||||||
query=fuzzy_query,
|
query=fuzzy_query,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
|
@ -526,7 +565,7 @@ async def community_fulltext_search(
|
||||||
|
|
||||||
|
|
||||||
async def community_similarity_search(
|
async def community_similarity_search(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
search_vector: list[float],
|
search_vector: list[float],
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
|
|
@ -540,14 +579,16 @@ async def community_similarity_search(
|
||||||
group_filter_query += 'WHERE comm.group_id IN $group_ids'
|
group_filter_query += 'WHERE comm.group_id IN $group_ids'
|
||||||
query_params['group_ids'] = group_ids
|
query_params['group_ids'] = group_ids
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
query = (
|
||||||
RUNTIME_QUERY
|
RUNTIME_QUERY
|
||||||
+ """
|
+ """
|
||||||
MATCH (comm:Community)
|
MATCH (comm:Community)
|
||||||
"""
|
"""
|
||||||
+ group_filter_query
|
+ group_filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
|
WITH comm, """
|
||||||
|
+ get_vector_cosine_func_query('comm.name_embedding', '$search_vector', driver.provider)
|
||||||
|
+ """ AS score
|
||||||
WHERE score > $min_score
|
WHERE score > $min_score
|
||||||
RETURN
|
RETURN
|
||||||
comm.uuid As uuid,
|
comm.uuid As uuid,
|
||||||
|
|
@ -557,7 +598,11 @@ async def community_similarity_search(
|
||||||
comm.summary AS summary
|
comm.summary AS summary
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
""",
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
query,
|
||||||
search_vector=search_vector,
|
search_vector=search_vector,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
|
@ -573,7 +618,7 @@ async def community_similarity_search(
|
||||||
async def hybrid_node_search(
|
async def hybrid_node_search(
|
||||||
queries: list[str],
|
queries: list[str],
|
||||||
embeddings: list[list[float]],
|
embeddings: list[list[float]],
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
search_filter: SearchFilters,
|
search_filter: SearchFilters,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
|
|
@ -590,7 +635,7 @@ async def hybrid_node_search(
|
||||||
A list of text queries to search for.
|
A list of text queries to search for.
|
||||||
embeddings : list[list[float]]
|
embeddings : list[list[float]]
|
||||||
A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed.
|
A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed.
|
||||||
driver : AsyncDriver
|
driver : GraphDriver
|
||||||
The Neo4j driver instance for database operations.
|
The Neo4j driver instance for database operations.
|
||||||
group_ids : list[str] | None, optional
|
group_ids : list[str] | None, optional
|
||||||
The list of group ids to retrieve nodes from.
|
The list of group ids to retrieve nodes from.
|
||||||
|
|
@ -645,7 +690,7 @@ async def hybrid_node_search(
|
||||||
|
|
||||||
|
|
||||||
async def get_relevant_nodes(
|
async def get_relevant_nodes(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
nodes: list[EntityNode],
|
nodes: list[EntityNode],
|
||||||
search_filter: SearchFilters,
|
search_filter: SearchFilters,
|
||||||
min_score: float = DEFAULT_MIN_SCORE,
|
min_score: float = DEFAULT_MIN_SCORE,
|
||||||
|
|
@ -664,29 +709,33 @@ async def get_relevant_nodes(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
RUNTIME_QUERY
|
RUNTIME_QUERY
|
||||||
+ """UNWIND $nodes AS node
|
+ """
|
||||||
MATCH (n:Entity {group_id: $group_id})
|
UNWIND $nodes AS node
|
||||||
"""
|
MATCH (n:Entity {group_id: $group_id})
|
||||||
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH node, n, vector.similarity.cosine(n.name_embedding, node.name_embedding) AS score
|
WITH node, n, """
|
||||||
|
+ get_vector_cosine_func_query('n.name_embedding', 'node.name_embedding', driver.provider)
|
||||||
|
+ """ AS score
|
||||||
WHERE score > $min_score
|
WHERE score > $min_score
|
||||||
WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
|
WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
|
||||||
|
"""
|
||||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", node.fulltext_query, {limit: $limit})
|
+ get_nodes_query(driver.provider, 'node_name_and_summary', 'node.fulltext_query')
|
||||||
|
+ """
|
||||||
YIELD node AS m
|
YIELD node AS m
|
||||||
WHERE m.group_id = $group_id
|
WHERE m.group_id = $group_id
|
||||||
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
|
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
|
||||||
|
|
||||||
WITH node,
|
WITH node,
|
||||||
top_vector_nodes,
|
top_vector_nodes,
|
||||||
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
|
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
|
||||||
|
|
||||||
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
|
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
|
||||||
|
|
||||||
UNWIND combined_nodes AS combined_node
|
UNWIND combined_nodes AS combined_node
|
||||||
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
|
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
|
||||||
|
|
||||||
RETURN
|
RETURN
|
||||||
node.uuid AS search_node_uuid,
|
node.uuid AS search_node_uuid,
|
||||||
[x IN deduped_nodes | {
|
[x IN deduped_nodes | {
|
||||||
|
|
@ -714,7 +763,7 @@ async def get_relevant_nodes(
|
||||||
|
|
||||||
results, _, _ = await driver.execute_query(
|
results, _, _ = await driver.execute_query(
|
||||||
query,
|
query,
|
||||||
query_params,
|
params=query_params,
|
||||||
nodes=query_nodes,
|
nodes=query_nodes,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
|
@ -736,7 +785,7 @@ async def get_relevant_nodes(
|
||||||
|
|
||||||
|
|
||||||
async def get_relevant_edges(
|
async def get_relevant_edges(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
edges: list[EntityEdge],
|
edges: list[EntityEdge],
|
||||||
search_filter: SearchFilters,
|
search_filter: SearchFilters,
|
||||||
min_score: float = DEFAULT_MIN_SCORE,
|
min_score: float = DEFAULT_MIN_SCORE,
|
||||||
|
|
@ -752,43 +801,47 @@ async def get_relevant_edges(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
RUNTIME_QUERY
|
RUNTIME_QUERY
|
||||||
+ """UNWIND $edges AS edge
|
+ """
|
||||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
UNWIND $edges AS edge
|
||||||
"""
|
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||||
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH e, edge, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
|
WITH e, edge, """
|
||||||
WHERE score > $min_score
|
+ get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
|
||||||
WITH edge, e, score
|
+ """ AS score
|
||||||
ORDER BY score DESC
|
WHERE score > $min_score
|
||||||
RETURN edge.uuid AS search_edge_uuid,
|
WITH edge, e, score
|
||||||
collect({
|
ORDER BY score DESC
|
||||||
uuid: e.uuid,
|
RETURN edge.uuid AS search_edge_uuid,
|
||||||
source_node_uuid: startNode(e).uuid,
|
collect({
|
||||||
target_node_uuid: endNode(e).uuid,
|
uuid: e.uuid,
|
||||||
created_at: e.created_at,
|
source_node_uuid: startNode(e).uuid,
|
||||||
name: e.name,
|
target_node_uuid: endNode(e).uuid,
|
||||||
group_id: e.group_id,
|
created_at: e.created_at,
|
||||||
fact: e.fact,
|
name: e.name,
|
||||||
fact_embedding: e.fact_embedding,
|
group_id: e.group_id,
|
||||||
episodes: e.episodes,
|
fact: e.fact,
|
||||||
expired_at: e.expired_at,
|
fact_embedding: e.fact_embedding,
|
||||||
valid_at: e.valid_at,
|
episodes: e.episodes,
|
||||||
invalid_at: e.invalid_at,
|
expired_at: e.expired_at,
|
||||||
attributes: properties(e)
|
valid_at: e.valid_at,
|
||||||
})[..$limit] AS matches
|
invalid_at: e.invalid_at,
|
||||||
|
attributes: properties(e)
|
||||||
|
})[..$limit] AS matches
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
results, _, _ = await driver.execute_query(
|
results, _, _ = await driver.execute_query(
|
||||||
query,
|
query,
|
||||||
query_params,
|
params=query_params,
|
||||||
edges=[edge.model_dump() for edge in edges],
|
edges=[edge.model_dump() for edge in edges],
|
||||||
limit=limit,
|
limit=limit,
|
||||||
min_score=min_score,
|
min_score=min_score,
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
||||||
relevant_edges_dict: dict[str, list[EntityEdge]] = {
|
relevant_edges_dict: dict[str, list[EntityEdge]] = {
|
||||||
result['search_edge_uuid']: [
|
result['search_edge_uuid']: [
|
||||||
get_entity_edge_from_record(record) for record in result['matches']
|
get_entity_edge_from_record(record) for record in result['matches']
|
||||||
|
|
@ -802,7 +855,7 @@ async def get_relevant_edges(
|
||||||
|
|
||||||
|
|
||||||
async def get_edge_invalidation_candidates(
|
async def get_edge_invalidation_candidates(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
edges: list[EntityEdge],
|
edges: list[EntityEdge],
|
||||||
search_filter: SearchFilters,
|
search_filter: SearchFilters,
|
||||||
min_score: float = DEFAULT_MIN_SCORE,
|
min_score: float = DEFAULT_MIN_SCORE,
|
||||||
|
|
@ -818,38 +871,41 @@ async def get_edge_invalidation_candidates(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
RUNTIME_QUERY
|
RUNTIME_QUERY
|
||||||
+ """UNWIND $edges AS edge
|
+ """
|
||||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
UNWIND $edges AS edge
|
||||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||||
"""
|
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||||
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH edge, e, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
|
WITH edge, e, """
|
||||||
WHERE score > $min_score
|
+ get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
|
||||||
WITH edge, e, score
|
+ """ AS score
|
||||||
ORDER BY score DESC
|
WHERE score > $min_score
|
||||||
RETURN edge.uuid AS search_edge_uuid,
|
WITH edge, e, score
|
||||||
collect({
|
ORDER BY score DESC
|
||||||
uuid: e.uuid,
|
RETURN edge.uuid AS search_edge_uuid,
|
||||||
source_node_uuid: startNode(e).uuid,
|
collect({
|
||||||
target_node_uuid: endNode(e).uuid,
|
uuid: e.uuid,
|
||||||
created_at: e.created_at,
|
source_node_uuid: startNode(e).uuid,
|
||||||
name: e.name,
|
target_node_uuid: endNode(e).uuid,
|
||||||
group_id: e.group_id,
|
created_at: e.created_at,
|
||||||
fact: e.fact,
|
name: e.name,
|
||||||
fact_embedding: e.fact_embedding,
|
group_id: e.group_id,
|
||||||
episodes: e.episodes,
|
fact: e.fact,
|
||||||
expired_at: e.expired_at,
|
fact_embedding: e.fact_embedding,
|
||||||
valid_at: e.valid_at,
|
episodes: e.episodes,
|
||||||
invalid_at: e.invalid_at,
|
expired_at: e.expired_at,
|
||||||
attributes: properties(e)
|
valid_at: e.valid_at,
|
||||||
})[..$limit] AS matches
|
invalid_at: e.invalid_at,
|
||||||
|
attributes: properties(e)
|
||||||
|
})[..$limit] AS matches
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
results, _, _ = await driver.execute_query(
|
results, _, _ = await driver.execute_query(
|
||||||
query,
|
query,
|
||||||
query_params,
|
params=query_params,
|
||||||
edges=[edge.model_dump() for edge in edges],
|
edges=[edge.model_dump() for edge in edges],
|
||||||
limit=limit,
|
limit=limit,
|
||||||
min_score=min_score,
|
min_score=min_score,
|
||||||
|
|
@ -884,7 +940,7 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[st
|
||||||
|
|
||||||
|
|
||||||
async def node_distance_reranker(
|
async def node_distance_reranker(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
node_uuids: list[str],
|
node_uuids: list[str],
|
||||||
center_node_uuid: str,
|
center_node_uuid: str,
|
||||||
min_score: float = 0,
|
min_score: float = 0,
|
||||||
|
|
@ -894,21 +950,22 @@ async def node_distance_reranker(
|
||||||
scores: dict[str, float] = {center_node_uuid: 0.0}
|
scores: dict[str, float] = {center_node_uuid: 0.0}
|
||||||
|
|
||||||
# Find the shortest path to center node
|
# Find the shortest path to center node
|
||||||
query = Query("""
|
query = """
|
||||||
UNWIND $node_uuids AS node_uuid
|
UNWIND $node_uuids AS node_uuid
|
||||||
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: node_uuid})
|
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
|
||||||
RETURN length(p) AS score, node_uuid AS uuid
|
RETURN 1 AS score, node_uuid AS uuid
|
||||||
""")
|
"""
|
||||||
|
results, header, _ = await driver.execute_query(
|
||||||
path_results, _, _ = await driver.execute_query(
|
|
||||||
query,
|
query,
|
||||||
node_uuids=filtered_uuids,
|
node_uuids=filtered_uuids,
|
||||||
center_uuid=center_node_uuid,
|
center_uuid=center_node_uuid,
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
if driver.provider == 'falkordb':
|
||||||
|
results = [dict(zip(header, row, strict=True)) for row in results]
|
||||||
|
|
||||||
for result in path_results:
|
for result in results:
|
||||||
uuid = result['uuid']
|
uuid = result['uuid']
|
||||||
score = result['score']
|
score = result['score']
|
||||||
scores[uuid] = score
|
scores[uuid] = score
|
||||||
|
|
@ -929,19 +986,18 @@ async def node_distance_reranker(
|
||||||
|
|
||||||
|
|
||||||
async def episode_mentions_reranker(
|
async def episode_mentions_reranker(
|
||||||
driver: AsyncDriver, node_uuids: list[list[str]], min_score: float = 0
|
driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
# use rrf as a preliminary ranker
|
# use rrf as a preliminary ranker
|
||||||
sorted_uuids = rrf(node_uuids)
|
sorted_uuids = rrf(node_uuids)
|
||||||
scores: dict[str, float] = {}
|
scores: dict[str, float] = {}
|
||||||
|
|
||||||
# Find the shortest path to center node
|
# Find the shortest path to center node
|
||||||
query = Query("""
|
query = """
|
||||||
UNWIND $node_uuids AS node_uuid
|
UNWIND $node_uuids AS node_uuid
|
||||||
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
|
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
|
||||||
RETURN count(*) AS score, n.uuid AS uuid
|
RETURN count(*) AS score, n.uuid AS uuid
|
||||||
""")
|
"""
|
||||||
|
|
||||||
results, _, _ = await driver.execute_query(
|
results, _, _ = await driver.execute_query(
|
||||||
query,
|
query,
|
||||||
node_uuids=sorted_uuids,
|
node_uuids=sorted_uuids,
|
||||||
|
|
@ -998,7 +1054,7 @@ def maximal_marginal_relevance(
|
||||||
|
|
||||||
|
|
||||||
async def get_embeddings_for_nodes(
|
async def get_embeddings_for_nodes(
|
||||||
driver: AsyncDriver, nodes: list[EntityNode]
|
driver: GraphDriver, nodes: list[EntityNode]
|
||||||
) -> dict[str, list[float]]:
|
) -> dict[str, list[float]]:
|
||||||
query: LiteralString = """MATCH (n:Entity)
|
query: LiteralString = """MATCH (n:Entity)
|
||||||
WHERE n.uuid IN $node_uuids
|
WHERE n.uuid IN $node_uuids
|
||||||
|
|
@ -1022,7 +1078,7 @@ async def get_embeddings_for_nodes(
|
||||||
|
|
||||||
|
|
||||||
async def get_embeddings_for_communities(
|
async def get_embeddings_for_communities(
|
||||||
driver: AsyncDriver, communities: list[CommunityNode]
|
driver: GraphDriver, communities: list[CommunityNode]
|
||||||
) -> dict[str, list[float]]:
|
) -> dict[str, list[float]]:
|
||||||
query: LiteralString = """MATCH (c:Community)
|
query: LiteralString = """MATCH (c:Community)
|
||||||
WHERE c.uuid IN $community_uuids
|
WHERE c.uuid IN $community_uuids
|
||||||
|
|
@ -1049,7 +1105,7 @@ async def get_embeddings_for_communities(
|
||||||
|
|
||||||
|
|
||||||
async def get_embeddings_for_edges(
|
async def get_embeddings_for_edges(
|
||||||
driver: AsyncDriver, edges: list[EntityEdge]
|
driver: GraphDriver, edges: list[EntityEdge]
|
||||||
) -> dict[str, list[float]]:
|
) -> dict[str, list[float]]:
|
||||||
query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
||||||
WHERE e.uuid IN $edge_uuids
|
WHERE e.uuid IN $edge_uuids
|
||||||
|
|
|
||||||
|
|
@ -20,22 +20,24 @@ from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from math import ceil
|
from math import ceil
|
||||||
|
|
||||||
from neo4j import AsyncDriver, AsyncManagedTransaction
|
|
||||||
from numpy import dot, sqrt
|
from numpy import dot, sqrt
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing_extensions import Any
|
from typing_extensions import Any
|
||||||
|
|
||||||
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
|
||||||
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
|
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
|
||||||
from graphiti_core.embedder import EmbedderClient
|
from graphiti_core.embedder import EmbedderClient
|
||||||
|
from graphiti_core.graph_queries import (
|
||||||
|
get_entity_edge_save_bulk_query,
|
||||||
|
get_entity_node_save_bulk_query,
|
||||||
|
)
|
||||||
from graphiti_core.graphiti_types import GraphitiClients
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
from graphiti_core.models.edges.edge_db_queries import (
|
from graphiti_core.models.edges.edge_db_queries import (
|
||||||
ENTITY_EDGE_SAVE_BULK,
|
|
||||||
EPISODIC_EDGE_SAVE_BULK,
|
EPISODIC_EDGE_SAVE_BULK,
|
||||||
)
|
)
|
||||||
from graphiti_core.models.nodes.node_db_queries import (
|
from graphiti_core.models.nodes.node_db_queries import (
|
||||||
ENTITY_NODE_SAVE_BULK,
|
|
||||||
EPISODIC_NODE_SAVE_BULK,
|
EPISODIC_NODE_SAVE_BULK,
|
||||||
)
|
)
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||||
|
|
@ -73,7 +75,7 @@ class RawEpisode(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
async def retrieve_previous_episodes_bulk(
|
async def retrieve_previous_episodes_bulk(
|
||||||
driver: AsyncDriver, episodes: list[EpisodicNode]
|
driver: GraphDriver, episodes: list[EpisodicNode]
|
||||||
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
|
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
|
||||||
previous_episodes_list = await semaphore_gather(
|
previous_episodes_list = await semaphore_gather(
|
||||||
*[
|
*[
|
||||||
|
|
@ -91,14 +93,15 @@ async def retrieve_previous_episodes_bulk(
|
||||||
|
|
||||||
|
|
||||||
async def add_nodes_and_edges_bulk(
|
async def add_nodes_and_edges_bulk(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
episodic_nodes: list[EpisodicNode],
|
episodic_nodes: list[EpisodicNode],
|
||||||
episodic_edges: list[EpisodicEdge],
|
episodic_edges: list[EpisodicEdge],
|
||||||
entity_nodes: list[EntityNode],
|
entity_nodes: list[EntityNode],
|
||||||
entity_edges: list[EntityEdge],
|
entity_edges: list[EntityEdge],
|
||||||
embedder: EmbedderClient,
|
embedder: EmbedderClient,
|
||||||
):
|
):
|
||||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
session = driver.session(database=DEFAULT_DATABASE)
|
||||||
|
try:
|
||||||
await session.execute_write(
|
await session.execute_write(
|
||||||
add_nodes_and_edges_bulk_tx,
|
add_nodes_and_edges_bulk_tx,
|
||||||
episodic_nodes,
|
episodic_nodes,
|
||||||
|
|
@ -106,16 +109,20 @@ async def add_nodes_and_edges_bulk(
|
||||||
entity_nodes,
|
entity_nodes,
|
||||||
entity_edges,
|
entity_edges,
|
||||||
embedder,
|
embedder,
|
||||||
|
driver=driver,
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
async def add_nodes_and_edges_bulk_tx(
|
async def add_nodes_and_edges_bulk_tx(
|
||||||
tx: AsyncManagedTransaction,
|
tx: GraphDriverSession,
|
||||||
episodic_nodes: list[EpisodicNode],
|
episodic_nodes: list[EpisodicNode],
|
||||||
episodic_edges: list[EpisodicEdge],
|
episodic_edges: list[EpisodicEdge],
|
||||||
entity_nodes: list[EntityNode],
|
entity_nodes: list[EntityNode],
|
||||||
entity_edges: list[EntityEdge],
|
entity_edges: list[EntityEdge],
|
||||||
embedder: EmbedderClient,
|
embedder: EmbedderClient,
|
||||||
|
driver: GraphDriver,
|
||||||
):
|
):
|
||||||
episodes = [dict(episode) for episode in episodic_nodes]
|
episodes = [dict(episode) for episode in episodic_nodes]
|
||||||
for episode in episodes:
|
for episode in episodes:
|
||||||
|
|
@ -160,11 +167,13 @@ async def add_nodes_and_edges_bulk_tx(
|
||||||
edges.append(edge_data)
|
edges.append(edge_data)
|
||||||
|
|
||||||
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
|
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
|
||||||
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
|
entity_node_save_bulk = get_entity_node_save_bulk_query(nodes, driver.provider)
|
||||||
|
await tx.run(entity_node_save_bulk, nodes=nodes)
|
||||||
await tx.run(
|
await tx.run(
|
||||||
EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
|
EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
|
||||||
)
|
)
|
||||||
await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=edges)
|
entity_edge_save_bulk = get_entity_edge_save_bulk_query(driver.provider)
|
||||||
|
await tx.run(entity_edge_save_bulk, entity_edges=edges)
|
||||||
|
|
||||||
|
|
||||||
async def extract_nodes_and_edges_bulk(
|
async def extract_nodes_and_edges_bulk(
|
||||||
|
|
@ -211,7 +220,7 @@ async def extract_nodes_and_edges_bulk(
|
||||||
|
|
||||||
|
|
||||||
async def dedupe_nodes_bulk(
|
async def dedupe_nodes_bulk(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
extracted_nodes: list[EntityNode],
|
extracted_nodes: list[EntityNode],
|
||||||
) -> tuple[list[EntityNode], dict[str, str]]:
|
) -> tuple[list[EntityNode], dict[str, str]]:
|
||||||
|
|
@ -247,7 +256,7 @@ async def dedupe_nodes_bulk(
|
||||||
|
|
||||||
|
|
||||||
async def dedupe_edges_bulk(
|
async def dedupe_edges_bulk(
|
||||||
driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
|
driver: GraphDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
# First compress edges
|
# First compress edges
|
||||||
compressed_edges = await compress_edges(llm_client, extracted_edges)
|
compressed_edges = await compress_edges(llm_client, extracted_edges)
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,9 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from graphiti_core.driver.driver import GraphDriver
|
||||||
from graphiti_core.edges import CommunityEdge
|
from graphiti_core.edges import CommunityEdge
|
||||||
from graphiti_core.embedder import EmbedderClient
|
from graphiti_core.embedder import EmbedderClient
|
||||||
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
||||||
|
|
@ -26,7 +26,7 @@ class Neighbor(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
async def get_community_clusters(
|
async def get_community_clusters(
|
||||||
driver: AsyncDriver, group_ids: list[str] | None
|
driver: GraphDriver, group_ids: list[str] | None
|
||||||
) -> list[list[EntityNode]]:
|
) -> list[list[EntityNode]]:
|
||||||
community_clusters: list[list[EntityNode]] = []
|
community_clusters: list[list[EntityNode]] = []
|
||||||
|
|
||||||
|
|
@ -95,7 +95,6 @@ def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
|
||||||
community_candidates: dict[int, int] = defaultdict(int)
|
community_candidates: dict[int, int] = defaultdict(int)
|
||||||
for neighbor in neighbors:
|
for neighbor in neighbors:
|
||||||
community_candidates[community_map[neighbor.node_uuid]] += neighbor.edge_count
|
community_candidates[community_map[neighbor.node_uuid]] += neighbor.edge_count
|
||||||
|
|
||||||
community_lst = [
|
community_lst = [
|
||||||
(count, community) for community, count in community_candidates.items()
|
(count, community) for community, count in community_candidates.items()
|
||||||
]
|
]
|
||||||
|
|
@ -194,7 +193,7 @@ async def build_community(
|
||||||
|
|
||||||
|
|
||||||
async def build_communities(
|
async def build_communities(
|
||||||
driver: AsyncDriver, llm_client: LLMClient, group_ids: list[str] | None
|
driver: GraphDriver, llm_client: LLMClient, group_ids: list[str] | None
|
||||||
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
||||||
community_clusters = await get_community_clusters(driver, group_ids)
|
community_clusters = await get_community_clusters(driver, group_ids)
|
||||||
|
|
||||||
|
|
@ -219,7 +218,7 @@ async def build_communities(
|
||||||
return community_nodes, community_edges
|
return community_nodes, community_edges
|
||||||
|
|
||||||
|
|
||||||
async def remove_communities(driver: AsyncDriver):
|
async def remove_communities(driver: GraphDriver):
|
||||||
await driver.execute_query(
|
await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (c:Community)
|
MATCH (c:Community)
|
||||||
|
|
@ -230,10 +229,10 @@ async def remove_communities(driver: AsyncDriver):
|
||||||
|
|
||||||
|
|
||||||
async def determine_entity_community(
|
async def determine_entity_community(
|
||||||
driver: AsyncDriver, entity: EntityNode
|
driver: GraphDriver, entity: EntityNode
|
||||||
) -> tuple[CommunityNode | None, bool]:
|
) -> tuple[CommunityNode | None, bool]:
|
||||||
# Check if the node is already part of a community
|
# Check if the node is already part of a community
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
|
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
|
||||||
RETURN
|
RETURN
|
||||||
|
|
@ -251,7 +250,7 @@ async def determine_entity_community(
|
||||||
return get_community_node_from_record(records[0]), False
|
return get_community_node_from_record(records[0]), False
|
||||||
|
|
||||||
# If the node has no community, add it to the mode community of surrounding entities
|
# If the node has no community, add it to the mode community of surrounding entities
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
|
MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
|
||||||
RETURN
|
RETURN
|
||||||
|
|
@ -291,7 +290,7 @@ async def determine_entity_community(
|
||||||
|
|
||||||
|
|
||||||
async def update_community(
|
async def update_community(
|
||||||
driver: AsyncDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode
|
driver: GraphDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode
|
||||||
):
|
):
|
||||||
community, is_new = await determine_entity_community(driver, entity)
|
community, is_new = await determine_entity_community(driver, entity)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -260,7 +260,6 @@ async def resolve_extracted_edges(
|
||||||
driver = clients.driver
|
driver = clients.driver
|
||||||
llm_client = clients.llm_client
|
llm_client = clients.llm_client
|
||||||
embedder = clients.embedder
|
embedder = clients.embedder
|
||||||
|
|
||||||
await create_entity_edge_embeddings(embedder, extracted_edges)
|
await create_entity_edge_embeddings(embedder, extracted_edges)
|
||||||
|
|
||||||
search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
|
search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
|
||||||
|
|
|
||||||
|
|
@ -17,9 +17,10 @@ limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
|
from graphiti_core.driver.driver import GraphDriver
|
||||||
|
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
||||||
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
||||||
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
||||||
|
|
||||||
|
|
@ -28,7 +29,7 @@ EPISODE_WINDOW_LEN = 3
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bool = False):
|
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
|
||||||
if delete_existing:
|
if delete_existing:
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
|
|
@ -47,39 +48,9 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
|
||||||
for name in index_names
|
for name in index_names
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
range_indices: list[LiteralString] = get_range_indices(driver.provider)
|
||||||
|
|
||||||
range_indices: list[LiteralString] = [
|
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
|
||||||
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
|
|
||||||
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
|
|
||||||
'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
|
|
||||||
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
|
|
||||||
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
|
|
||||||
'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
|
||||||
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
|
|
||||||
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
|
|
||||||
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
|
|
||||||
'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
|
|
||||||
'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
|
|
||||||
'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
|
|
||||||
'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
|
|
||||||
'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
|
|
||||||
'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
|
|
||||||
'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
|
|
||||||
'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
|
|
||||||
'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
|
|
||||||
'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
|
|
||||||
]
|
|
||||||
|
|
||||||
fulltext_indices: list[LiteralString] = [
|
|
||||||
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
|
|
||||||
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
|
|
||||||
"""CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
|
|
||||||
FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
|
|
||||||
"""CREATE FULLTEXT INDEX community_name IF NOT EXISTS
|
|
||||||
FOR (n:Community) ON EACH [n.name, n.group_id]""",
|
|
||||||
"""CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
|
|
||||||
FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
|
|
||||||
]
|
|
||||||
|
|
||||||
index_queries: list[LiteralString] = range_indices + fulltext_indices
|
index_queries: list[LiteralString] = range_indices + fulltext_indices
|
||||||
|
|
||||||
|
|
@ -94,7 +65,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def clear_data(driver: AsyncDriver, group_ids: list[str] | None = None):
|
async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
|
||||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||||
|
|
||||||
async def delete_all(tx):
|
async def delete_all(tx):
|
||||||
|
|
@ -113,7 +84,7 @@ async def clear_data(driver: AsyncDriver, group_ids: list[str] | None = None):
|
||||||
|
|
||||||
|
|
||||||
async def retrieve_episodes(
|
async def retrieve_episodes(
|
||||||
driver: AsyncDriver,
|
driver: GraphDriver,
|
||||||
reference_time: datetime,
|
reference_time: datetime,
|
||||||
last_n: int = EPISODE_WINDOW_LEN,
|
last_n: int = EPISODE_WINDOW_LEN,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
|
|
@ -123,7 +94,7 @@ async def retrieve_episodes(
|
||||||
Retrieve the last n episodic nodes from the graph.
|
Retrieve the last n episodic nodes from the graph.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
driver (AsyncDriver): The Neo4j driver instance.
|
driver (Driver): The Neo4j driver instance.
|
||||||
reference_time (datetime): The reference time to filter episodes. Only episodes with a valid_at timestamp
|
reference_time (datetime): The reference time to filter episodes. Only episodes with a valid_at timestamp
|
||||||
less than or equal to this reference_time will be retrieved. This allows for
|
less than or equal to this reference_time will be retrieved. This allows for
|
||||||
querying the graph's state at a specific point in time.
|
querying the graph's state at a specific point in time.
|
||||||
|
|
@ -140,8 +111,8 @@ async def retrieve_episodes(
|
||||||
|
|
||||||
query: LiteralString = (
|
query: LiteralString = (
|
||||||
"""
|
"""
|
||||||
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
|
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
|
||||||
"""
|
"""
|
||||||
+ group_id_filter
|
+ group_id_filter
|
||||||
+ source_filter
|
+ source_filter
|
||||||
+ """
|
+ """
|
||||||
|
|
@ -157,8 +128,7 @@ async def retrieve_episodes(
|
||||||
LIMIT $num_episodes
|
LIMIT $num_episodes
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
result, _, _ = await driver.execute_query(
|
||||||
result = await driver.execute_query(
|
|
||||||
query,
|
query,
|
||||||
reference_time=reference_time,
|
reference_time=reference_time,
|
||||||
source=source.name if source is not None else None,
|
source=source.name if source is not None else None,
|
||||||
|
|
@ -166,6 +136,7 @@ async def retrieve_episodes(
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
episodes = [
|
episodes = [
|
||||||
EpisodicNode(
|
EpisodicNode(
|
||||||
content=record['content'],
|
content=record['content'],
|
||||||
|
|
@ -179,6 +150,6 @@ async def retrieve_episodes(
|
||||||
name=record['name'],
|
name=record['name'],
|
||||||
source_description=record['source_description'],
|
source_description=record['source_description'],
|
||||||
)
|
)
|
||||||
for record in result.records
|
for record in result
|
||||||
]
|
]
|
||||||
return list(reversed(episodes)) # Return in chronological order
|
return list(reversed(episodes)) # Return in chronological order
|
||||||
|
|
|
||||||
|
|
@ -326,7 +326,6 @@ async def extract_attributes_from_nodes(
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
llm_client = clients.llm_client
|
llm_client = clients.llm_client
|
||||||
embedder = clients.embedder
|
embedder = clients.embedder
|
||||||
|
|
||||||
updated_nodes: list[EntityNode] = await semaphore_gather(
|
updated_nodes: list[EntityNode] = await semaphore_gather(
|
||||||
*[
|
*[
|
||||||
extract_attributes_from_node(
|
extract_attributes_from_node(
|
||||||
|
|
|
||||||
46
poetry.lock
generated
46
poetry.lock
generated
|
|
@ -1,4 +1,4 @@
|
||||||
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "aiohappyeyeballs"
|
name = "aiohappyeyeballs"
|
||||||
|
|
@ -332,12 +332,12 @@ version = "5.0.1"
|
||||||
description = "Timeout context manager for asyncio programs"
|
description = "Timeout context manager for asyncio programs"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
groups = ["dev"]
|
groups = ["main", "dev"]
|
||||||
markers = "python_version < \"3.11\""
|
|
||||||
files = [
|
files = [
|
||||||
{file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"},
|
{file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"},
|
||||||
{file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"},
|
{file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"},
|
||||||
]
|
]
|
||||||
|
markers = {main = "python_full_version < \"3.11.3\"", dev = "python_version == \"3.10\""}
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "attrs"
|
name = "attrs"
|
||||||
|
|
@ -759,7 +759,7 @@ description = "Backport of PEP 654 (exception groups)"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
groups = ["main", "dev"]
|
groups = ["main", "dev"]
|
||||||
markers = "python_version < \"3.11\""
|
markers = "python_version == \"3.10\""
|
||||||
files = [
|
files = [
|
||||||
{file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"},
|
{file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"},
|
||||||
{file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
|
{file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
|
||||||
|
|
@ -798,6 +798,20 @@ files = [
|
||||||
[package.extras]
|
[package.extras]
|
||||||
tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich ; python_version >= \"3.11\""]
|
tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich ; python_version >= \"3.11\""]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "falkordb"
|
||||||
|
version = "1.1.2"
|
||||||
|
description = "Python client for interacting with FalkorDB database"
|
||||||
|
optional = false
|
||||||
|
python-versions = "<4.0,>=3.8"
|
||||||
|
groups = ["main"]
|
||||||
|
files = [
|
||||||
|
{file = "falkordb-1.1.2.tar.gz", hash = "sha256:db76c97efe14a56c3d65c61b966a42b874e1c78a8fb6808de3f61f4314b04023"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
redis = ">=5.0.1,<6.0.0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastjsonschema"
|
name = "fastjsonschema"
|
||||||
version = "2.21.1"
|
version = "2.21.1"
|
||||||
|
|
@ -2665,7 +2679,6 @@ description = "Fast, correct Python JSON library supporting dataclasses, datetim
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
groups = ["dev"]
|
groups = ["dev"]
|
||||||
markers = "platform_python_implementation != \"PyPy\""
|
|
||||||
files = [
|
files = [
|
||||||
{file = "orjson-3.10.16-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4cb473b8e79154fa778fb56d2d73763d977be3dcc140587e07dbc545bbfc38f8"},
|
{file = "orjson-3.10.16-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4cb473b8e79154fa778fb56d2d73763d977be3dcc140587e07dbc545bbfc38f8"},
|
||||||
{file = "orjson-3.10.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:622a8e85eeec1948690409a19ca1c7d9fd8ff116f4861d261e6ae2094fe59a00"},
|
{file = "orjson-3.10.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:622a8e85eeec1948690409a19ca1c7d9fd8ff116f4861d261e6ae2094fe59a00"},
|
||||||
|
|
@ -3691,6 +3704,25 @@ files = [
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
cffi = {version = "*", markers = "implementation_name == \"pypy\""}
|
cffi = {version = "*", markers = "implementation_name == \"pypy\""}
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "redis"
|
||||||
|
version = "5.2.1"
|
||||||
|
description = "Python client for Redis database and key-value store"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
groups = ["main"]
|
||||||
|
files = [
|
||||||
|
{file = "redis-5.2.1-py3-none-any.whl", hash = "sha256:ee7e1056b9aea0f04c6c2ed59452947f34c4940ee025f5dd83e6a6418b6989e4"},
|
||||||
|
{file = "redis-5.2.1.tar.gz", hash = "sha256:16f2e22dff21d5125e8481515e386711a34cbec50f0e44413dd7d9c060a54e0f"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""}
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
hiredis = ["hiredis (>=3.0.0)"]
|
||||||
|
ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "referencing"
|
name = "referencing"
|
||||||
version = "0.36.2"
|
version = "0.36.2"
|
||||||
|
|
@ -4498,7 +4530,7 @@ description = "A lil' TOML parser"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
groups = ["dev"]
|
groups = ["dev"]
|
||||||
markers = "python_version < \"3.11\""
|
markers = "python_version == \"3.10\""
|
||||||
files = [
|
files = [
|
||||||
{file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
|
{file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
|
||||||
{file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
|
{file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
|
||||||
|
|
@ -5356,4 +5388,4 @@ groq = ["groq"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.1"
|
lock-version = "2.1"
|
||||||
python-versions = ">=3.10,<4"
|
python-versions = ">=3.10,<4"
|
||||||
content-hash = "814d067fd2959bfe2db58a22637d86580b66d96f34c433852c67d02089d750ab"
|
content-hash = "2e02a10a6493f7564b86d5d0d09b4cf718004808e115af39550b9ee87c296fb4"
|
||||||
|
|
@ -19,6 +19,7 @@ dependencies = [
|
||||||
"tenacity>=9.0.0",
|
"tenacity>=9.0.0",
|
||||||
"numpy>=1.0.0",
|
"numpy>=1.0.0",
|
||||||
"python-dotenv>=1.0.1",
|
"python-dotenv>=1.0.1",
|
||||||
|
"falkordb (>=1.1.2,<2.0.0)",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue