Add support for falkordb (#575)
* [wip] add support for falkordb
* updates
* fix-async
* progress
* fix-issues
* rm-date-handler
* red-code
* rm-uns-try
* fix-exm
* rm-un-lines
* fix-comments
* fix-se-utils
* fix-falkor-readme
* fix-falkor-cosine-score
* update-falkor-ver
* fix-vec-sim
* min-updates
* make format
* update graph driver abstraction
* poetry lock
* updates
* linter
* Update graphiti_core/search/search_utils.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

---------

Co-authored-by: Dudi Zimberknopf <zimber.dudi@gmail.com>
Co-authored-by: Gal Shubeli <galshubeli93@gmail.com>
Co-authored-by: Gal Shubeli <124919062+galshubeli@users.noreply.github.com>
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent 3d7e1a4b79
commit 14146dc46f
27 changed files with 1131 additions and 348 deletions
@@ -1,8 +1,17 @@
 OPENAI_API_KEY=
+
+# Neo4j database connection
 NEO4J_URI=
 NEO4J_PORT=
 NEO4J_USER=
 NEO4J_PASSWORD=
+
+# FalkorDB database connection
+FALKORDB_URI=
+FALKORDB_PORT=
+FALKORDB_USER=
+FALKORDB_PASSWORD=
+
 DEFAULT_DATABASE=
 USE_PARALLEL_RUNTIME=
 SEMAPHORE_LIMIT=
@@ -64,10 +64,11 @@ Once you've found an issue tagged with "good first issue" or "help wanted," or p
 export TEST_OPENAI_API_KEY=...
 export TEST_OPENAI_MODEL=...
 export TEST_ANTHROPIC_API_KEY=...

-export NEO4J_URI=neo4j://...
-export NEO4J_USER=...
-export NEO4J_PASSWORD=...
+# For Neo4j
+export TEST_URI=neo4j://...
+export TEST_USER=...
+export TEST_PASSWORD=...
 ```

 ## Making Changes
@@ -105,7 +105,7 @@ Graphiti is specifically designed to address the challenges of dynamic and frequ
 Requirements:

 - Python 3.10 or higher
-- Neo4j 5.26 or higher (serves as the embeddings storage backend)
+- Neo4j 5.26 / FalkorDB 1.1.2 or higher (serves as the embeddings storage backend)
 - OpenAI API key (for LLM inference and embedding)

 > [!IMPORTANT]
@@ -76,9 +76,7 @@ async def main():
     group_id = str(uuid4())

     for i, message in enumerate(messages[3:14]):
-        episodes = await client.retrieve_episodes(
-            message.actual_timestamp, 3, group_ids=['podcast']
-        )
+        episodes = await client.retrieve_episodes(message.actual_timestamp, 3, group_ids=[group_id])
         episode_uuids = [episode.uuid for episode in episodes]

         await client.add_episode(
@@ -2,7 +2,7 @@

 This example demonstrates the basic functionality of Graphiti, including:

-1. Connecting to a Neo4j database
+1. Connecting to a Neo4j or FalkorDB database
 2. Initializing Graphiti indices and constraints
 3. Adding episodes to the graph
 4. Searching the graph with semantic and keyword matching

@@ -11,10 +11,14 @@ This example demonstrates the basic functionality of Graphiti, including:

 ## Prerequisites

-- Neo4j Desktop installed and running
-- A local DBMS created and started in Neo4j Desktop
-- Python 3.9+
-- OpenAI API key (set as `OPENAI_API_KEY` environment variable)
+- Python 3.9+
+- OpenAI API key (set as `OPENAI_API_KEY` environment variable)
+- **For Neo4j**:
+  - Neo4j Desktop installed and running
+  - A local DBMS created and started in Neo4j Desktop
+- **For FalkorDB**:
+  - FalkorDB server running (see [FalkorDB documentation](https://falkordb.com/docs/) for setup)


 ## Setup Instructions

@@ -34,17 +38,23 @@ export OPENAI_API_KEY=your_openai_api_key
 export NEO4J_URI=bolt://localhost:7687
 export NEO4J_USER=neo4j
 export NEO4J_PASSWORD=password
+
+# Optional FalkorDB connection parameters (defaults shown)
+export FALKORDB_URI=falkor://localhost:6379
 ```

 3. Run the example:

 ```bash
-python quickstart.py
+python quickstart_neo4j.py
+
+# For FalkorDB
+python quickstart_falkordb.py
 ```

 ## What This Example Demonstrates

-- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j
+- **Graph Initialization**: Setting up the Graphiti indices and constraints in Neo4j or FalkorDB
 - **Adding Episodes**: Adding text content that will be analyzed and converted into knowledge graph nodes and edges
 - **Edge Search Functionality**: Performing hybrid searches that combine semantic similarity and BM25 retrieval to find relationships (edges)
 - **Graph-Aware Search**: Using the source node UUID from the top search result to rerank additional search results based on graph distance
examples/quickstart/quickstart_falkordb.py (Normal file, 240 lines)
@@ -0,0 +1,240 @@
"""
Copyright 2025, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
import json
import logging
import os
from datetime import datetime, timezone
from logging import INFO

from dotenv import load_dotenv

from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF

#################################################
# CONFIGURATION
#################################################
# Set up logging and environment variables for
# connecting to FalkorDB database
#################################################

# Configure logging
logging.basicConfig(
    level=INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)

load_dotenv()

# FalkorDB connection parameters
# Make sure FalkorDB on premises is running, see https://docs.falkordb.com/
falkor_uri = os.environ.get('FALKORDB_URI', 'falkor://localhost:6379')

if not falkor_uri:
    raise ValueError('FALKORDB_URI must be set')


async def main():
    #################################################
    # INITIALIZATION
    #################################################
    # Connect to FalkorDB and set up Graphiti indices
    # This is required before using other Graphiti
    # functionality
    #################################################

    # Initialize Graphiti with FalkorDB connection
    graphiti = Graphiti(falkor_uri)

    try:
        # Initialize the graph database with graphiti's indices. This only needs to be done once.
        await graphiti.build_indices_and_constraints()

        #################################################
        # ADDING EPISODES
        #################################################
        # Episodes are the primary units of information
        # in Graphiti. They can be text or structured JSON
        # and are automatically processed to extract entities
        # and relationships.
        #################################################

        # Example: Add Episodes
        # Episodes list containing both text and JSON episodes
        episodes = [
            {
                'content': 'Kamala Harris is the Attorney General of California. She was previously '
                'the district attorney for San Francisco.',
                'type': EpisodeType.text,
                'description': 'podcast transcript',
            },
            {
                'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017',
                'type': EpisodeType.text,
                'description': 'podcast transcript',
            },
            {
                'content': {
                    'name': 'Gavin Newsom',
                    'position': 'Governor',
                    'state': 'California',
                    'previous_role': 'Lieutenant Governor',
                    'previous_location': 'San Francisco',
                },
                'type': EpisodeType.json,
                'description': 'podcast metadata',
            },
            {
                'content': {
                    'name': 'Gavin Newsom',
                    'position': 'Governor',
                    'term_start': 'January 7, 2019',
                    'term_end': 'Present',
                },
                'type': EpisodeType.json,
                'description': 'podcast metadata',
            },
        ]

        # Add episodes to the graph
        for i, episode in enumerate(episodes):
            await graphiti.add_episode(
                name=f'Freakonomics Radio {i}',
                episode_body=episode['content']
                if isinstance(episode['content'], str)
                else json.dumps(episode['content']),
                source=episode['type'],
                source_description=episode['description'],
                reference_time=datetime.now(timezone.utc),
            )
            print(f'Added episode: Freakonomics Radio {i} ({episode["type"].value})')

        #################################################
        # BASIC SEARCH
        #################################################
        # The simplest way to retrieve relationships (edges)
        # from Graphiti is using the search method, which
        # performs a hybrid search combining semantic
        # similarity and BM25 text retrieval.
        #################################################

        # Perform a hybrid search combining semantic similarity and BM25 retrieval
        print("\nSearching for: 'Who was the California Attorney General?'")
        results = await graphiti.search('Who was the California Attorney General?')

        # Print search results
        print('\nSearch Results:')
        for result in results:
            print(f'UUID: {result.uuid}')
            print(f'Fact: {result.fact}')
            if hasattr(result, 'valid_at') and result.valid_at:
                print(f'Valid from: {result.valid_at}')
            if hasattr(result, 'invalid_at') and result.invalid_at:
                print(f'Valid until: {result.invalid_at}')
            print('---')

        #################################################
        # CENTER NODE SEARCH
        #################################################
        # For more contextually relevant results, you can
        # use a center node to rerank search results based
        # on their graph distance to a specific node
        #################################################

        # Use the top search result's UUID as the center node for reranking
        if results and len(results) > 0:
            # Get the source node UUID from the top result
            center_node_uuid = results[0].source_node_uuid

            print('\nReranking search results based on graph distance:')
            print(f'Using center node UUID: {center_node_uuid}')

            reranked_results = await graphiti.search(
                'Who was the California Attorney General?', center_node_uuid=center_node_uuid
            )

            # Print reranked search results
            print('\nReranked Search Results:')
            for result in reranked_results:
                print(f'UUID: {result.uuid}')
                print(f'Fact: {result.fact}')
                if hasattr(result, 'valid_at') and result.valid_at:
                    print(f'Valid from: {result.valid_at}')
                if hasattr(result, 'invalid_at') and result.invalid_at:
                    print(f'Valid until: {result.invalid_at}')
                print('---')
        else:
            print('No results found in the initial search to use as center node.')

        #################################################
        # NODE SEARCH USING SEARCH RECIPES
        #################################################
        # Graphiti provides predefined search recipes
        # optimized for different search scenarios.
        # Here we use NODE_HYBRID_SEARCH_RRF for retrieving
        # nodes directly instead of edges.
        #################################################

        # Example: Perform a node search using _search method with standard recipes
        print(
            '\nPerforming node search using _search method with standard recipe NODE_HYBRID_SEARCH_RRF:'
        )

        # Use a predefined search configuration recipe and modify its limit
        node_search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
        node_search_config.limit = 5  # Limit to 5 results

        # Execute the node search
        node_search_results = await graphiti._search(
            query='California Governor',
            config=node_search_config,
        )

        # Print node search results
        print('\nNode Search Results:')
        for node in node_search_results.nodes:
            print(f'Node UUID: {node.uuid}')
            print(f'Node Name: {node.name}')
            node_summary = node.summary[:100] + '...' if len(node.summary) > 100 else node.summary
            print(f'Content Summary: {node_summary}')
            print(f'Node Labels: {", ".join(node.labels)}')
            print(f'Created At: {node.created_at}')
            if hasattr(node, 'attributes') and node.attributes:
                print('Attributes:')
                for key, value in node.attributes.items():
                    print(f'  {key}: {value}')
            print('---')

    finally:
        #################################################
        # CLEANUP
        #################################################
        # Always close the connection to FalkorDB when
        # finished to properly release resources
        #################################################

        # Close the connection
        await graphiti.close()
        print('\nConnection closed')


if __name__ == '__main__':
    asyncio.run(main())
graphiti_core/driver/__init__.py (Normal file, 17 lines)
@@ -0,0 +1,17 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

__all__ = ['GraphDriver', 'Neo4jDriver', 'FalkorDriver']
graphiti_core/driver/driver.py (Normal file, 81 lines)
@@ -0,0 +1,81 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from typing import Any

from graphiti_core.helpers import DEFAULT_DATABASE

logger = logging.getLogger(__name__)


class GraphDriverSession(ABC):
    @abstractmethod
    async def run(self, query: str, **kwargs: Any) -> Any:
        raise NotImplementedError()


class GraphDriver(ABC):
    provider: str

    @abstractmethod
    def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
        raise NotImplementedError()

    @abstractmethod
    def session(self, database: str) -> GraphDriverSession:
        raise NotImplementedError()

    @abstractmethod
    def close(self) -> None:
        raise NotImplementedError()

    @abstractmethod
    def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
        raise NotImplementedError()


# class GraphDriver:
#     _driver: GraphClient
#
#     def __init__(
#         self,
#         uri: str,
#         user: str,
#         password: str,
#     ):
#         if uri.startswith('falkor'):
#             # FalkorDB
#             self._driver = FalkorClient(uri, user, password)
#             self.provider = 'falkordb'
#         else:
#             # Neo4j
#             self._driver = Neo4jClient(uri, user, password)
#             self.provider = 'neo4j'
#
#     def execute_query(self, cypher_query_, **kwargs: Any) -> Coroutine:
#         return self._driver.execute_query(cypher_query_, **kwargs)
#
#     async def close(self):
#         return await self._driver.close()
#
#     def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
#         return self._driver.delete_all_indexes(database_)
#
#     def session(self, database: str) -> GraphClientSession:
#         return self._driver.session(database)
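The `GraphDriver`/`GraphDriverSession` pair above is the seam that keeps the rest of the codebase backend-agnostic. As a rough sketch of what a third implementation would have to provide (hypothetical; the `MemoryDriver` name and its no-op behavior are invented here for illustration and are not part of this commit):

```python
# Hypothetical sketch: a minimal no-op GraphDriver implementation
# showing the contract callers rely on. Not part of this change.
from typing import Any

from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
from graphiti_core.helpers import DEFAULT_DATABASE


class MemorySession(GraphDriverSession):
    async def run(self, query: str, **kwargs: Any) -> Any:
        # A real backend would execute the Cypher here.
        return None


class MemoryDriver(GraphDriver):
    provider: str = 'memory'  # invented provider tag

    async def execute_query(self, cypher_query_: str, **kwargs: Any):
        # Callers unpack a (records, header, summary) triple.
        return [], [], None

    def session(self, database: str) -> GraphDriverSession:
        return MemorySession()

    async def close(self) -> None:
        pass

    def delete_all_indexes(self, database_: str = DEFAULT_DATABASE):
        return self.execute_query('CALL db.indexes()', database_=database_)
```

The key design point is that `execute_query` must hand back the same `(records, header, summary)` shape the Neo4j driver returns, so call sites like `records, _, _ = await driver.execute_query(...)` work unchanged.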
graphiti_core/driver/falkordb_driver.py (Normal file, 132 lines)
@@ -0,0 +1,132 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
from collections.abc import Coroutine
from datetime import datetime
from typing import Any

from falkordb import Graph as FalkorGraph
from falkordb.asyncio import FalkorDB

from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
from graphiti_core.helpers import DEFAULT_DATABASE

logger = logging.getLogger(__name__)


class FalkorClientSession(GraphDriverSession):
    def __init__(self, graph: FalkorGraph):
        self.graph = graph

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # No cleanup needed for Falkor, but method must exist
        pass

    async def close(self):
        # No explicit close needed for FalkorDB, but method must exist
        pass

    async def execute_write(self, func, *args, **kwargs):
        # Directly await the provided async function with `self` as the transaction/session
        return await func(self, *args, **kwargs)

    async def run(self, cypher_query_: str | list, **kwargs: Any) -> Any:
        # FalkorDB does not support argument for Label Set, so it's converted into an array of queries
        if isinstance(cypher_query_, list):
            for cypher, params in cypher_query_:
                params = convert_datetimes_to_strings(params)
                await self.graph.query(str(cypher), params)
        else:
            params = dict(kwargs)
            params = convert_datetimes_to_strings(params)
            await self.graph.query(str(cypher_query_), params)
        # Assuming `graph.query` is async (ideal); otherwise, wrap in executor
        return None


class FalkorDriver(GraphDriver):
    provider: str = 'falkordb'

    def __init__(
        self,
        uri: str,
        user: str,
        password: str,
    ):
        super().__init__()
        if user and password:
            uri_parts = uri.split('://', 1)
            uri = f'{uri_parts[0]}://{user}:{password}@{uri_parts[1]}'

        self.client = FalkorDB.from_url(
            url=uri,
        )

    def _get_graph(self, graph_name: str) -> FalkorGraph:
        # FalkorDB requires a non-None database name for multi-tenant graphs; the default is "DEFAULT_DATABASE"
        if graph_name is None:
            graph_name = 'DEFAULT_DATABASE'
        return self.client.select_graph(graph_name)

    async def execute_query(self, cypher_query_, **kwargs: Any):
        graph_name = kwargs.pop('database_', DEFAULT_DATABASE)
        graph = self._get_graph(graph_name)

        # Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly)
        params = convert_datetimes_to_strings(dict(kwargs))

        try:
            result = await graph.query(cypher_query_, params)
        except Exception as e:
            if 'already indexed' in str(e):
                # check if index already exists
                logger.info(f'Index already exists: {e}')
                return None
            logger.error(f'Error executing FalkorDB query: {e}')
            raise

        # Convert the result header to a list of strings
        header = [h[1].decode('utf-8') for h in result.header]
        return result.result_set, header, None

    def session(self, database: str) -> GraphDriverSession:
        return FalkorClientSession(self._get_graph(database))

    async def close(self) -> None:
        await self.client.connection.close()

    def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
        return self.execute_query(
            'CALL db.indexes() YIELD name DROP INDEX name',
            database_=database_,
        )


def convert_datetimes_to_strings(obj):
    if isinstance(obj, dict):
        return {k: convert_datetimes_to_strings(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [convert_datetimes_to_strings(item) for item in obj]
    elif isinstance(obj, tuple):
        return tuple(convert_datetimes_to_strings(item) for item in obj)
    elif isinstance(obj, datetime):
        return obj.isoformat()
    else:
        return obj
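A minimal sketch of driving Graphiti through the new class explicitly, together with the datetime serialization it relies on. The URI and empty credentials are illustrative assumptions for a local, unauthenticated server; the `graph_driver` keyword is the one added to `Graphiti.__init__` later in this diff:

```python
import asyncio
from datetime import datetime, timezone

from graphiti_core import Graphiti
from graphiti_core.driver.falkordb_driver import FalkorDriver, convert_datetimes_to_strings

# The conversion is recursive over dicts, lists, and tuples, and
# serializes datetimes to ISO-8601 strings that FalkorDB can store.
params = {'valid_at': datetime(2025, 1, 1, tzinfo=timezone.utc)}
assert convert_datetimes_to_strings(params) == {'valid_at': '2025-01-01T00:00:00+00:00'}


async def main() -> None:
    # Empty credentials skip the user:password@ URI rewrite in __init__.
    driver = FalkorDriver('falkor://localhost:6379', '', '')
    graphiti = Graphiti('falkor://localhost:6379', graph_driver=driver)
    try:
        await graphiti.build_indices_and_constraints()
    finally:
        await graphiti.close()


if __name__ == '__main__':
    asyncio.run(main())
```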
graphiti_core/driver/neo4j_driver.py (Normal file, 60 lines)
@@ -0,0 +1,60 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
from collections.abc import Coroutine
from typing import Any, LiteralString

from neo4j import AsyncGraphDatabase

from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
from graphiti_core.helpers import DEFAULT_DATABASE

logger = logging.getLogger(__name__)


class Neo4jDriver(GraphDriver):
    provider: str = 'neo4j'

    def __init__(
        self,
        uri: str,
        user: str,
        password: str,
    ):
        super().__init__()
        self.client = AsyncGraphDatabase.driver(
            uri=uri,
            auth=(user, password),
        )

    async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> Coroutine:
        params = kwargs.pop('params', None)
        result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)

        return result

    def session(self, database: str) -> GraphDriverSession:
        return self.client.session(database=database)  # type: ignore

    async def close(self) -> None:
        return await self.client.close()

    def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
        return self.client.execute_query(
            'CALL db.indexes() YIELD name DROP INDEX name',
            database_=database_,
        )
@@ -21,10 +21,10 @@ from time import time
 from typing import Any
 from uuid import uuid4

-from neo4j import AsyncDriver
 from pydantic import BaseModel, Field
 from typing_extensions import LiteralString

+from graphiti_core.driver.driver import GraphDriver
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
 from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date

@@ -62,9 +62,9 @@ class Edge(BaseModel, ABC):
     created_at: datetime

     @abstractmethod
-    async def save(self, driver: AsyncDriver): ...
+    async def save(self, driver: GraphDriver): ...

-    async def delete(self, driver: AsyncDriver):
+    async def delete(self, driver: GraphDriver):
         result = await driver.execute_query(
             """
         MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)

@@ -87,11 +87,11 @@ class Edge(BaseModel, ABC):
         return False

     @classmethod
-    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
+    async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...


 class EpisodicEdge(Edge):
-    async def save(self, driver: AsyncDriver):
+    async def save(self, driver: GraphDriver):
         result = await driver.execute_query(
             EPISODIC_EDGE_SAVE,
             episode_uuid=self.source_node_uuid,

@@ -102,12 +102,12 @@ class EpisodicEdge(Edge):
             database_=DEFAULT_DATABASE,
         )

-        logger.debug(f'Saved edge to neo4j: {self.uuid}')
+        logger.debug(f'Saved edge to Graph: {self.uuid}')

         return result

     @classmethod
-    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
+    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)

@@ -130,7 +130,7 @@ class EpisodicEdge(Edge):
         return edges[0]

     @classmethod
-    async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
+    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)

@@ -156,7 +156,7 @@ class EpisodicEdge(Edge):
     @classmethod
     async def get_by_group_ids(
         cls,
-        driver: AsyncDriver,
+        driver: GraphDriver,
         group_ids: list[str],
         limit: int | None = None,
         uuid_cursor: str | None = None,

@@ -226,7 +226,7 @@ class EntityEdge(Edge):

         return self.fact_embedding

-    async def load_fact_embedding(self, driver: AsyncDriver):
+    async def load_fact_embedding(self, driver: GraphDriver):
         query: LiteralString = """
         MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
         RETURN e.fact_embedding AS fact_embedding

@@ -240,7 +240,7 @@ class EntityEdge(Edge):

         self.fact_embedding = records[0]['fact_embedding']

-    async def save(self, driver: AsyncDriver):
+    async def save(self, driver: GraphDriver):
         edge_data: dict[str, Any] = {
             'source_uuid': self.source_node_uuid,
             'target_uuid': self.target_node_uuid,

@@ -264,12 +264,12 @@ class EntityEdge(Edge):
             database_=DEFAULT_DATABASE,
         )

-        logger.debug(f'Saved edge to neo4j: {self.uuid}')
+        logger.debug(f'Saved edge to Graph: {self.uuid}')

         return result

     @classmethod
-    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
+    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)

@@ -287,7 +287,7 @@ class EntityEdge(Edge):
         return edges[0]

     @classmethod
-    async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
+    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
         if len(uuids) == 0:
             return []

@@ -309,7 +309,7 @@ class EntityEdge(Edge):
     @classmethod
     async def get_by_group_ids(
         cls,
-        driver: AsyncDriver,
+        driver: GraphDriver,
         group_ids: list[str],
         limit: int | None = None,
         uuid_cursor: str | None = None,

@@ -342,11 +342,11 @@ class EntityEdge(Edge):
         return edges

     @classmethod
-    async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
+    async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
         query: LiteralString = (
             """
-        MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
-        """
+            MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
+            """
             + ENTITY_EDGE_RETURN
         )
         records, _, _ = await driver.execute_query(

@@ -359,7 +359,7 @@ class EntityEdge(Edge):


 class CommunityEdge(Edge):
-    async def save(self, driver: AsyncDriver):
+    async def save(self, driver: GraphDriver):
         result = await driver.execute_query(
             COMMUNITY_EDGE_SAVE,
             community_uuid=self.source_node_uuid,

@@ -370,12 +370,12 @@ class CommunityEdge(Edge):
             database_=DEFAULT_DATABASE,
         )

-        logger.debug(f'Saved edge to neo4j: {self.uuid}')
+        logger.debug(f'Saved edge to Graph: {self.uuid}')

         return result

     @classmethod
-    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
+    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)

@@ -396,7 +396,7 @@ class CommunityEdge(Edge):
         return edges[0]

     @classmethod
-    async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
+    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)

@@ -420,7 +420,7 @@ class CommunityEdge(Edge):
     @classmethod
     async def get_by_group_ids(
         cls,
-        driver: AsyncDriver,
+        driver: GraphDriver,
         group_ids: list[str],
         limit: int | None = None,
         uuid_cursor: str | None = None,

@@ -463,7 +463,7 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
         group_id=record['group_id'],
         source_node_uuid=record['source_node_uuid'],
         target_node_uuid=record['target_node_uuid'],
-        created_at=record['created_at'].to_native(),
+        created_at=parse_db_date(record['created_at']),
     )

@@ -476,7 +476,7 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
         name=record['name'],
         group_id=record['group_id'],
         episodes=record['episodes'],
-        created_at=record['created_at'].to_native(),
+        created_at=parse_db_date(record['created_at']),
         expired_at=parse_db_date(record['expired_at']),
         valid_at=parse_db_date(record['valid_at']),
         invalid_at=parse_db_date(record['invalid_at']),

@@ -504,7 +504,7 @@ def get_community_edge_from_record(record: Any):
         group_id=record['group_id'],
         source_node_uuid=record['source_node_uuid'],
         target_node_uuid=record['target_node_uuid'],
-        created_at=record['created_at'].to_native(),
+        created_at=parse_db_date(record['created_at']),
     )
graphiti_core/graph_queries.py (Normal file, 147 lines)
@@ -0,0 +1,147 @@
"""
Database query utilities for different graph database backends.

This module provides database-agnostic query generation for Neo4j and FalkorDB,
supporting index creation, fulltext search, and bulk operations.
"""

from typing_extensions import LiteralString

from graphiti_core.models.edges.edge_db_queries import (
    ENTITY_EDGE_SAVE_BULK,
)
from graphiti_core.models.nodes.node_db_queries import (
    ENTITY_NODE_SAVE_BULK,
)

# Mapping from Neo4j fulltext index names to FalkorDB node labels
NEO4J_TO_FALKORDB_MAPPING = {
    'node_name_and_summary': 'Entity',
    'community_name': 'Community',
    'episode_content': 'Episodic',
    'edge_name_and_fact': 'RELATES_TO',
}


def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]:
    if db_type == 'falkordb':
        return [
            # Entity node
            'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)',
            # Episodic node
            'CREATE INDEX FOR (n:Episodic) ON (n.uuid, n.group_id, n.created_at, n.valid_at)',
            # Community node
            'CREATE INDEX FOR (n:Community) ON (n.uuid)',
            # RELATES_TO edge
            'CREATE INDEX FOR ()-[e:RELATES_TO]-() ON (e.uuid, e.group_id, e.name, e.created_at, e.expired_at, e.valid_at, e.invalid_at)',
            # MENTIONS edge
            'CREATE INDEX FOR ()-[e:MENTIONS]-() ON (e.uuid, e.group_id)',
            # HAS_MEMBER edge
            'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
        ]
    else:
        return [
            'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
            'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
            'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
            'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
            'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
            'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
            'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
            'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
            'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
            'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
            'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
            'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
            'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
            'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
            'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
            'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
            'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
            'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
            'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
        ]


def get_fulltext_indices(db_type: str = 'neo4j') -> list[LiteralString]:
    if db_type == 'falkordb':
        return [
            """CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""",
            """CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""",
            """CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""",
            """CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
        ]
    else:
        return [
            """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
            FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
            """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
            FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
            """CREATE FULLTEXT INDEX community_name IF NOT EXISTS
            FOR (n:Community) ON EACH [n.name, n.group_id]""",
            """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
            FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
        ]


def get_nodes_query(db_type: str = 'neo4j', name: str = None, query: str = None) -> str:
    if db_type == 'falkordb':
        label = NEO4J_TO_FALKORDB_MAPPING[name]
        return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
    else:
        return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'


def get_vector_cosine_func_query(vec1, vec2, db_type: str = 'neo4j') -> str:
    if db_type == 'falkordb':
        # FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity
        return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
    else:
        return f'vector.similarity.cosine({vec1}, {vec2})'


def get_relationships_query(db_type: str = 'neo4j', name: str = None, query: str = None) -> str:
    if db_type == 'falkordb':
        label = NEO4J_TO_FALKORDB_MAPPING[name]
        return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
    else:
        return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'


def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str:
    if db_type == 'falkordb':
        queries = []
        for node in nodes:
            for label in node['labels']:
                queries.append(
                    (
                        f"""
                        UNWIND $nodes AS node
                        MERGE (n:Entity {{uuid: node.uuid}})
                        SET n:{label}
                        SET n = node
                        WITH n, node
                        SET n.name_embedding = vecf32(node.name_embedding)
                        RETURN n.uuid AS uuid
                        """,
                        {'nodes': [node]},
                    )
                )
        return queries
    else:
        return ENTITY_NODE_SAVE_BULK


def get_entity_edge_save_bulk_query(db_type: str = 'neo4j') -> str:
    if db_type == 'falkordb':
        return """
        UNWIND $entity_edges AS edge
        MATCH (source:Entity {uuid: edge.source_node_uuid})
        MATCH (target:Entity {uuid: edge.target_node_uuid})
        MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
        SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
        created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
        WITH r, edge
        RETURN edge.uuid AS uuid"""
    else:
        return ENTITY_EDGE_SAVE_BULK
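The cosine rewrite is worth a second look. FalkorDB's `vec.cosineDistance` returns the plain distance d = 1 - cos θ, which lies in [0, 2], while Neo4j's `vector.similarity.cosine` returns the normalized score (1 + cos θ)/2 in [0, 1]. Substituting gives (2 - d)/2 = (2 - (1 - cos θ))/2 = (1 + cos θ)/2, so the two branches produce the same score and the same ranking. A quick numeric check (illustrative, plain Python):

```python
# Sanity check that (2 - cosineDistance) / 2 equals Neo4j's
# normalized cosine score (1 + cos_theta) / 2 for the same vectors.
import math


def cos_theta(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))


a, b = [1.0, 0.0], [0.6, 0.8]
falkor_distance = 1 - cos_theta(a, b)           # vec.cosineDistance, in [0, 2]
falkor_score = (2 - falkor_distance) / 2        # the generated FalkorDB expression
neo4j_score = (1 + cos_theta(a, b)) / 2         # Neo4j's normalized similarity
assert math.isclose(falkor_score, neo4j_score)  # both 0.8
```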
@@ -19,12 +19,13 @@ from datetime import datetime
 from time import time

 from dotenv import load_dotenv
-from neo4j import AsyncGraphDatabase
 from pydantic import BaseModel
 from typing_extensions import LiteralString

 from graphiti_core.cross_encoder.client import CrossEncoderClient
 from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
+from graphiti_core.driver.driver import GraphDriver
+from graphiti_core.driver.neo4j_driver import Neo4jDriver
 from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
 from graphiti_core.graphiti_types import GraphitiClients

@@ -94,12 +95,13 @@ class Graphiti:
     def __init__(
         self,
         uri: str,
-        user: str,
-        password: str,
+        user: str = None,
+        password: str = None,
         llm_client: LLMClient | None = None,
         embedder: EmbedderClient | None = None,
         cross_encoder: CrossEncoderClient | None = None,
         store_raw_episode_content: bool = True,
+        graph_driver: GraphDriver = None,
     ):
         """
         Initialize a Graphiti instance.

@@ -137,7 +139,9 @@ class Graphiti:
         Make sure to set the OPENAI_API_KEY environment variable before initializing
         Graphiti if you're using the default OpenAIClient.
         """
-        self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
+
+        self.driver = graph_driver if graph_driver else Neo4jDriver(uri, user, password)
+
         self.database = DEFAULT_DATABASE
         self.store_raw_episode_content = store_raw_episode_content
         if llm_client:
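With this change there are two construction paths; a brief sketch (connection values are illustrative):

```python
from graphiti_core import Graphiti
from graphiti_core.driver.falkordb_driver import FalkorDriver

# Path 1: no graph_driver, so a Neo4jDriver is built from uri/user/password.
graphiti_neo4j = Graphiti('bolt://localhost:7687', 'neo4j', 'password')

# Path 2: an injected driver wins; uri/user/password are then unused.
graphiti_falkor = Graphiti(
    'falkor://localhost:6379',
    graph_driver=FalkorDriver('falkor://localhost:6379', '', ''),
)
```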
@@ -14,16 +14,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-from neo4j import AsyncDriver
 from pydantic import BaseModel, ConfigDict

 from graphiti_core.cross_encoder import CrossEncoderClient
+from graphiti_core.driver.driver import GraphDriver
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.llm_client import LLMClient


 class GraphitiClients(BaseModel):
-    driver: AsyncDriver
+    driver: GraphDriver
     llm_client: LLMClient
     embedder: EmbedderClient
     cross_encoder: CrossEncoderClient
@@ -38,8 +38,14 @@ RUNTIME_QUERY: LiteralString = (
 )


-def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
-    return neo_date.to_native() if neo_date else None
+def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None:
+    return (
+        neo_date.to_native()
+        if isinstance(neo_date, neo4j_time.DateTime)
+        else datetime.fromisoformat(neo_date)
+        if neo_date
+        else None
+    )


 def lucene_sanitize(query: str) -> str:
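The widened signature keeps Neo4j's native `DateTime` path and adds ISO-8601 strings, which is what FalkorDB rows carry after `convert_datetimes_to_strings`. A small illustration of the three cases (assuming the standard `neo4j.time` module used elsewhere in helpers.py):

```python
from datetime import datetime

from neo4j import time as neo4j_time

from graphiti_core.helpers import parse_db_date

assert parse_db_date(None) is None
# FalkorDB rows carry ISO strings:
assert parse_db_date('2025-01-01T00:00:00+00:00') == datetime.fromisoformat('2025-01-01T00:00:00+00:00')
# Neo4j rows carry neo4j.time.DateTime, converted via to_native():
assert isinstance(parse_db_date(neo4j_time.DateTime(2025, 1, 1)), datetime)
```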
@@ -1,3 +1,19 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 from .client import LLMClient
 from .config import LLMConfig
 from .errors import RateLimitError
@@ -22,13 +22,13 @@ from time import time
 from typing import Any
 from uuid import uuid4

-from neo4j import AsyncDriver
 from pydantic import BaseModel, Field
 from typing_extensions import LiteralString

+from graphiti_core.driver.driver import GraphDriver
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.errors import NodeNotFoundError
-from graphiti_core.helpers import DEFAULT_DATABASE
+from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
 from graphiti_core.models.nodes.node_db_queries import (
     COMMUNITY_NODE_SAVE,
     ENTITY_NODE_SAVE,

@@ -94,9 +94,9 @@ class Node(BaseModel, ABC):
     created_at: datetime = Field(default_factory=lambda: utc_now())

     @abstractmethod
-    async def save(self, driver: AsyncDriver): ...
+    async def save(self, driver: GraphDriver): ...

-    async def delete(self, driver: AsyncDriver):
+    async def delete(self, driver: GraphDriver):
         result = await driver.execute_query(
             """
         MATCH (n:Entity|Episodic|Community {uuid: $uuid})

@@ -119,7 +119,7 @@ class Node(BaseModel, ABC):
         return False

     @classmethod
-    async def delete_by_group_id(cls, driver: AsyncDriver, group_id: str):
+    async def delete_by_group_id(cls, driver: GraphDriver, group_id: str):
         await driver.execute_query(
             """
         MATCH (n:Entity|Episodic|Community {group_id: $group_id})

@@ -132,10 +132,10 @@ class Node(BaseModel, ABC):
         return 'SUCCESS'

     @classmethod
-    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
+    async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...

     @classmethod
-    async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): ...
+    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ...


 class EpisodicNode(Node):

@@ -150,7 +150,7 @@ class EpisodicNode(Node):
         default_factory=list,
     )

-    async def save(self, driver: AsyncDriver):
+    async def save(self, driver: GraphDriver):
         result = await driver.execute_query(
             EPISODIC_NODE_SAVE,
             uuid=self.uuid,

@@ -165,12 +165,12 @@ class EpisodicNode(Node):
             database_=DEFAULT_DATABASE,
         )

-        logger.debug(f'Saved Node to neo4j: {self.uuid}')
+        logger.debug(f'Saved Node to Graph: {self.uuid}')

         return result

     @classmethod
-    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
+    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
         records, _, _ = await driver.execute_query(
             """
         MATCH (e:Episodic {uuid: $uuid})

@@ -197,7 +197,7 @@ class EpisodicNode(Node):
         return episodes[0]

     @classmethod
-    async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
+    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
         records, _, _ = await driver.execute_query(
             """
         MATCH (e:Episodic) WHERE e.uuid IN $uuids

@@ -224,7 +224,7 @@ class EpisodicNode(Node):
     @classmethod
     async def get_by_group_ids(
         cls,
-        driver: AsyncDriver,
+        driver: GraphDriver,
         group_ids: list[str],
         limit: int | None = None,
         uuid_cursor: str | None = None,

@@ -263,7 +263,7 @@ class EpisodicNode(Node):
         return episodes

     @classmethod
-    async def get_by_entity_node_uuid(cls, driver: AsyncDriver, entity_node_uuid: str):
+    async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
         records, _, _ = await driver.execute_query(
             """
         MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})

@@ -304,7 +304,7 @@ class EntityNode(Node):

         return self.name_embedding

-    async def load_name_embedding(self, driver: AsyncDriver):
+    async def load_name_embedding(self, driver: GraphDriver):
         query: LiteralString = """
         MATCH (n:Entity {uuid: $uuid})
         RETURN n.name_embedding AS name_embedding

@@ -318,7 +318,7 @@ class EntityNode(Node):

         self.name_embedding = records[0]['name_embedding']

-    async def save(self, driver: AsyncDriver):
+    async def save(self, driver: GraphDriver):
         entity_data: dict[str, Any] = {
             'uuid': self.uuid,
             'name': self.name,

@@ -337,16 +337,16 @@ class EntityNode(Node):
             database_=DEFAULT_DATABASE,
         )

-        logger.debug(f'Saved Node to neo4j: {self.uuid}')
+        logger.debug(f'Saved Node to Graph: {self.uuid}')

         return result

     @classmethod
-    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
+    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
         query = (
             """
-        MATCH (n:Entity {uuid: $uuid})
-        """
+            MATCH (n:Entity {uuid: $uuid})
+            """
             + ENTITY_NODE_RETURN
         )
         records, _, _ = await driver.execute_query(

@@ -364,7 +364,7 @@ class EntityNode(Node):
         return nodes[0]

     @classmethod
-    async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
+    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Entity) WHERE n.uuid IN $uuids

@@ -382,7 +382,7 @@ class EntityNode(Node):
     @classmethod
     async def get_by_group_ids(
         cls,
-        driver: AsyncDriver,
+        driver: GraphDriver,
         group_ids: list[str],
         limit: int | None = None,
         uuid_cursor: str | None = None,

@@ -416,7 +416,7 @@ class CommunityNode(Node):
     name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
     summary: str = Field(description='region summary of member nodes', default_factory=str)

-    async def save(self, driver: AsyncDriver):
+    async def save(self, driver: GraphDriver):
         result = await driver.execute_query(
             COMMUNITY_NODE_SAVE,
             uuid=self.uuid,

@@ -428,7 +428,7 @@ class CommunityNode(Node):
             database_=DEFAULT_DATABASE,
         )

-        logger.debug(f'Saved Node to neo4j: {self.uuid}')
+        logger.debug(f'Saved Node to Graph: {self.uuid}')

         return result

@@ -441,7 +441,7 @@ class CommunityNode(Node):

         return self.name_embedding

-    async def load_name_embedding(self, driver: AsyncDriver):
+    async def load_name_embedding(self, driver: GraphDriver):
         query: LiteralString = """
         MATCH (c:Community {uuid: $uuid})
         RETURN c.name_embedding AS name_embedding

@@ -456,7 +456,7 @@ class CommunityNode(Node):
         self.name_embedding = records[0]['name_embedding']

     @classmethod
-    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
+    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Community {uuid: $uuid})

@@ -480,7 +480,7 @@ class CommunityNode(Node):
         return nodes[0]

     @classmethod
-    async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
+    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Community) WHERE n.uuid IN $uuids

@@ -503,7 +503,7 @@ class CommunityNode(Node):
     @classmethod
     async def get_by_group_ids(
         cls,
-        driver: AsyncDriver,
+        driver: GraphDriver,
         group_ids: list[str],
         limit: int | None = None,
         uuid_cursor: str | None = None,

@@ -542,8 +542,8 @@ class CommunityNode(Node):
 def get_episodic_node_from_record(record: Any) -> EpisodicNode:
     return EpisodicNode(
         content=record['content'],
-        created_at=record['created_at'].to_native().timestamp(),
-        valid_at=(record['valid_at'].to_native()),
+        created_at=parse_db_date(record['created_at']).timestamp(),
+        valid_at=(parse_db_date(record['valid_at'])),
         uuid=record['uuid'],
         group_id=record['group_id'],
         source=EpisodeType.from_str(record['source']),

@@ -559,7 +559,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
         name=record['name'],
         group_id=record['group_id'],
         labels=record['labels'],
-        created_at=record['created_at'].to_native(),
+        created_at=parse_db_date(record['created_at']),
         summary=record['summary'],
         attributes=record['attributes'],
     )

@@ -580,7 +580,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
         name=record['name'],
         group_id=record['group_id'],
         name_embedding=record['name_embedding'],
-        created_at=record['created_at'].to_native(),
+        created_at=parse_db_date(record['created_at']),
         summary=record['summary'],
     )
@@ -18,9 +18,8 @@ import logging
 from collections import defaultdict
 from time import time

-from neo4j import AsyncDriver
-
 from graphiti_core.cross_encoder.client import CrossEncoderClient
+from graphiti_core.driver.driver import GraphDriver
 from graphiti_core.edges import EntityEdge
 from graphiti_core.errors import SearchRerankerError
 from graphiti_core.graphiti_types import GraphitiClients

@@ -94,7 +93,7 @@ async def search(
     )

     # if group_ids is empty, set it to None
-    group_ids = group_ids if group_ids else None
+    group_ids = group_ids if group_ids and group_ids != [''] else None
     edges, nodes, episodes, communities = await semaphore_gather(
         edge_search(
             driver,

@@ -160,7 +159,7 @@ async def search(


 async def edge_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     cross_encoder: CrossEncoderClient,
     query: str,
     query_vector: list[float],

@@ -174,7 +173,6 @@ async def edge_search(
 ) -> list[EntityEdge]:
     if config is None:
         return []
-
     search_results: list[list[EntityEdge]] = list(
         await semaphore_gather(
             *[

@@ -261,7 +259,7 @@ async def edge_search(


 async def node_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     cross_encoder: CrossEncoderClient,
     query: str,
     query_vector: list[float],

@@ -275,7 +273,6 @@ async def node_search(
 ) -> list[EntityNode]:
     if config is None:
         return []
-
     search_results: list[list[EntityNode]] = list(
         await semaphore_gather(
             *[

@@ -344,7 +341,7 @@ async def node_search(


 async def episode_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     cross_encoder: CrossEncoderClient,
     query: str,
     _query_vector: list[float],

@@ -356,7 +353,6 @@ async def episode_search(
 ) -> list[EpisodicNode]:
     if config is None:
         return []
-
     search_results: list[list[EpisodicNode]] = list(
         await semaphore_gather(
             *[

@@ -392,7 +388,7 @@ async def episode_search(


 async def community_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     cross_encoder: CrossEncoderClient,
     query: str,
     query_vector: list[float],
@ -20,11 +20,16 @@ from time import time
|
|||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from neo4j import AsyncDriver, Query
|
||||
from numpy._typing import NDArray
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver
|
||||
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
||||
from graphiti_core.graph_queries import (
|
||||
get_nodes_query,
|
||||
get_relationships_query,
|
||||
get_vector_cosine_func_query,
|
||||
)
|
||||
from graphiti_core.helpers import (
|
||||
DEFAULT_DATABASE,
|
||||
RUNTIME_QUERY,
|
||||
|
|
@ -58,7 +63,7 @@ MAX_QUERY_LENGTH = 32
|
|||
|
||||
def fulltext_query(query: str, group_ids: list[str] | None = None):
|
||||
group_ids_filter_list = (
|
||||
[f'group_id:"{lucene_sanitize(g)}"' for g in group_ids] if group_ids is not None else []
|
||||
[f"group_id-'{lucene_sanitize(g)}'" for g in group_ids] if group_ids is not None else []
|
||||
)
|
||||
group_ids_filter = ''
|
||||
for f in group_ids_filter_list:
|
||||
|
|
@@ -77,7 +82,7 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
 
 
 async def get_episodes_by_mentions(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     nodes: list[EntityNode],
     edges: list[EntityEdge],
     limit: int = RELEVANT_SCHEMA_LIMIT,
@@ -92,11 +97,11 @@ async def get_episodes_by_mentions(
 
 
 async def get_mentioned_nodes(
-    driver: AsyncDriver, episodes: list[EpisodicNode]
+    driver: GraphDriver, episodes: list[EpisodicNode]
 ) -> list[EntityNode]:
     episode_uuids = [episode.uuid for episode in episodes]
-    records, _, _ = await driver.execute_query(
-        """
+
+    query = """
         MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
         RETURN DISTINCT
             n.uuid As uuid,
@@ -106,7 +111,10 @@ async def get_mentioned_nodes(
             n.summary AS summary,
             labels(n) AS labels,
             properties(n) AS attributes
-        """,
+        """
+
+    records, _, _ = await driver.execute_query(
+        query,
         uuids=episode_uuids,
         database_=DEFAULT_DATABASE,
         routing_='r',
@@ -118,11 +126,11 @@ async def get_mentioned_nodes(
 
 
 async def get_communities_by_nodes(
-    driver: AsyncDriver, nodes: list[EntityNode]
+    driver: GraphDriver, nodes: list[EntityNode]
 ) -> list[CommunityNode]:
     node_uuids = [node.uuid for node in nodes]
-    records, _, _ = await driver.execute_query(
-        """
+
+    query = """
         MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
         RETURN DISTINCT
             c.uuid As uuid,
@@ -130,7 +138,10 @@ async def get_communities_by_nodes(
             c.name AS name,
             c.created_at AS created_at,
             c.summary AS summary
-        """,
+        """
+
+    records, _, _ = await driver.execute_query(
+        query,
         uuids=node_uuids,
         database_=DEFAULT_DATABASE,
         routing_='r',
@@ -142,7 +153,7 @@ async def get_communities_by_nodes(
 
 
 async def edge_fulltext_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     query: str,
     search_filter: SearchFilters,
     group_ids: list[str] | None = None,
@@ -155,34 +166,35 @@ async def edge_fulltext_search(
     filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
 
-    cypher_query = Query(
-        """
-        CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query, {limit: $limit})
-        YIELD relationship AS rel, score
-        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
-        WHERE r.group_id IN $group_ids"""
+    query = (
+        get_relationships_query(driver.provider, 'edge_name_and_fact', '$query')
+        + """
+        YIELD relationship AS rel, score
+        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
+        WHERE r.group_id IN $group_ids """
         + filter_query
-        + """\nWITH r, score, startNode(r) AS n, endNode(r) AS m
-        RETURN
-            r.uuid AS uuid,
-            r.group_id AS group_id,
-            n.uuid AS source_node_uuid,
-            m.uuid AS target_node_uuid,
-            r.created_at AS created_at,
-            r.name AS name,
-            r.fact AS fact,
-            r.episodes AS episodes,
-            r.expired_at AS expired_at,
-            r.valid_at AS valid_at,
-            r.invalid_at AS invalid_at,
-            properties(r) AS attributes
-        ORDER BY score DESC LIMIT $limit
-        """
+        + """
+        WITH r, score, startNode(r) AS n, endNode(r) AS m
+        RETURN
+            r.uuid AS uuid,
+            r.group_id AS group_id,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at,
+            properties(r) AS attributes
+        ORDER BY score DESC LIMIT $limit
+        """
     )
 
     records, _, _ = await driver.execute_query(
-        cypher_query,
-        filter_params,
+        query,
+        params=filter_params,
         query=fuzzy_query,
         group_ids=group_ids,
         limit=limit,
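Note: raw neo4j.Query objects give way to plain strings, and filter parameters now travel under an explicit params keyword. A minimal sketch of how a driver wrapper might merge params with the remaining keyword parameters (an assumed shape, not the shipped GraphDriver code):

    # Illustrative sketch only; session here stands in for whichever backend session is active.
    async def execute_query(session, cypher: str, params: dict | None = None, **kwargs):
        # database_/routing_ are driver-control options, not Cypher parameters
        kwargs.pop('database_', None)
        kwargs.pop('routing_', None)
        return await session.run(cypher, {**(params or {}), **kwargs})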
@@ -196,7 +208,7 @@ async def edge_fulltext_search(
 
 
 async def edge_similarity_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     search_vector: list[float],
     source_node_uuid: str | None,
     target_node_uuid: str | None,
@@ -224,36 +236,38 @@ async def edge_similarity_search(
     if target_node_uuid is not None:
         group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
 
-    query: LiteralString = (
+    query = (
         RUNTIME_QUERY
         + """
-        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
-        """
+        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
+        """
         + group_filter_query
         + filter_query
-        + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
-        WHERE score > $min_score
-        RETURN
-            r.uuid AS uuid,
-            r.group_id AS group_id,
-            startNode(r).uuid AS source_node_uuid,
-            endNode(r).uuid AS target_node_uuid,
-            r.created_at AS created_at,
-            r.name AS name,
-            r.fact AS fact,
-            r.episodes AS episodes,
-            r.expired_at AS expired_at,
-            r.valid_at AS valid_at,
-            r.invalid_at AS invalid_at,
-            properties(r) AS attributes
-        ORDER BY score DESC
-        LIMIT $limit
+        + """
+        WITH DISTINCT r, """
+        + get_vector_cosine_func_query('r.fact_embedding', '$search_vector', driver.provider)
+        + """ AS score
+        WHERE score > $min_score
+        RETURN
+            r.uuid AS uuid,
+            r.group_id AS group_id,
+            startNode(r).uuid AS source_node_uuid,
+            endNode(r).uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at,
+            properties(r) AS attributes
+        ORDER BY score DESC
+        LIMIT $limit
         """
     )
 
-    records, _, _ = await driver.execute_query(
+    records, header, _ = await driver.execute_query(
         query,
-        query_params,
+        params=query_params,
         search_vector=search_vector,
         source_uuid=source_node_uuid,
         target_uuid=target_node_uuid,
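Note: get_vector_cosine_func_query splices a provider-appropriate similarity expression into the Cypher, since FalkorDB does not implement Neo4j's vector.similarity.cosine(). The exact FalkorDB expression is not visible in this hunk; the sketch below is an assumption based on FalkorDB's vec.cosineDistance, rescaling a distance in 0..2 to a similarity in 1..0 (which would line up with the "fix-falkor-cosine-score" commit message):

    # Illustrative sketch; the shipped builder lives in graphiti_core/graph_queries.py.
    def get_vector_cosine_func_query(v1: str, v2: str, provider: str) -> str:
        if provider == 'falkordb':
            # assumed: convert cosine distance to a cosine-similarity score
            return f'(2 - vec.cosineDistance({v1}, vecf32({v2}))) / 2'
        return f'vector.similarity.cosine({v1}, {v2})'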
@@ -264,13 +278,16 @@ async def edge_similarity_search(
         routing_='r',
     )
 
+    if driver.provider == 'falkordb':
+        records = [dict(zip(header, row, strict=True)) for row in records]
+
     edges = [get_entity_edge_from_record(record) for record in records]
 
     return edges
 
 
 async def edge_bfs_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     bfs_origin_node_uuids: list[str] | None,
     bfs_max_depth: int,
     search_filter: SearchFilters,
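Note: the Neo4j driver returns mapping-style records, while FalkorDB returns positional rows plus a separate header, which is why execute_query now also surfaces a header and several call sites zip the two back into dicts before the shared record parsers run. A standalone illustration with invented values:

    header = ['uuid', 'name', 'fact']
    rows = [['e1', 'KNOWS', 'Alice knows Bob']]
    records = [dict(zip(header, row, strict=True)) for row in rows]
    assert records[0]['uuid'] == 'e1'  # now shaped like a Neo4j record for get_entity_edge_from_record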
@@ -282,14 +299,14 @@ async def edge_bfs_search(
 
     filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
 
-    query = Query(
-        """
-        UNWIND $bfs_origin_node_uuids AS origin_uuid
-        MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
-        UNWIND relationships(path) AS rel
-        MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
-        WHERE r.uuid = rel.uuid
-        """
+    query = (
+        """
+        UNWIND $bfs_origin_node_uuids AS origin_uuid
+        MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
+        UNWIND relationships(path) AS rel
+        MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
+        WHERE r.uuid = rel.uuid
+        """
         + filter_query
         + """
         RETURN DISTINCT
@@ -311,7 +328,7 @@ async def edge_bfs_search(
 
     records, _, _ = await driver.execute_query(
         query,
-        filter_params,
+        params=filter_params,
         bfs_origin_node_uuids=bfs_origin_node_uuids,
         depth=bfs_max_depth,
         limit=limit,
@@ -325,7 +342,7 @@ async def edge_bfs_search(
 
 
 async def node_fulltext_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     query: str,
     search_filter: SearchFilters,
     group_ids: list[str] | None = None,
@@ -335,38 +352,41 @@ async def node_fulltext_search(
     fuzzy_query = fulltext_query(query, group_ids)
     if fuzzy_query == '':
         return []
 
     filter_query, filter_params = node_search_filter_query_constructor(search_filter)
 
     query = (
+        get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
+        + """
+        YIELD node AS n, score
+        WITH n, score
+        LIMIT $limit
+        WHERE n:Entity
         """
-        CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
-        YIELD node AS n, score
-        WHERE n:Entity
-        """
         + filter_query
         + ENTITY_NODE_RETURN
         + """
         ORDER BY score DESC
         """
     )
 
-    records, _, _ = await driver.execute_query(
+    records, header, _ = await driver.execute_query(
         query,
-        filter_params,
+        params=filter_params,
         query=fuzzy_query,
         group_ids=group_ids,
         limit=limit,
         database_=DEFAULT_DATABASE,
         routing_='r',
     )
+    if driver.provider == 'falkordb':
+        records = [dict(zip(header, row, strict=True)) for row in records]
+
     nodes = [get_entity_node_from_record(record) for record in records]
 
     return nodes
 
 
 async def node_similarity_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     search_vector: list[float],
     search_filter: SearchFilters,
     group_ids: list[str] | None = None,
@@ -384,22 +404,28 @@ async def node_similarity_search(
     filter_query, filter_params = node_search_filter_query_constructor(search_filter)
     query_params.update(filter_params)
 
-    records, _, _ = await driver.execute_query(
+    query = (
         RUNTIME_QUERY
         + """
-        MATCH (n:Entity)
-        """
+        MATCH (n:Entity)
+        """
         + group_filter_query
         + filter_query
         + """
-        WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
-        WHERE score > $min_score"""
+        WITH n, """
+        + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
+        + """ AS score
+        WHERE score > $min_score"""
         + ENTITY_NODE_RETURN
         + """
         ORDER BY score DESC
         LIMIT $limit
-        """,
-        query_params,
+        """
+    )
+
+    records, header, _ = await driver.execute_query(
+        query,
+        params=query_params,
         search_vector=search_vector,
         group_ids=group_ids,
         limit=limit,
@@ -407,13 +433,15 @@ async def node_similarity_search(
         database_=DEFAULT_DATABASE,
         routing_='r',
     )
+    if driver.provider == 'falkordb':
+        records = [dict(zip(header, row, strict=True)) for row in records]
     nodes = [get_entity_node_from_record(record) for record in records]
 
     return nodes
 
 
 async def node_bfs_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     bfs_origin_node_uuids: list[str] | None,
     search_filter: SearchFilters,
     bfs_max_depth: int,
@@ -425,18 +453,21 @@ async def node_bfs_search(
 
     filter_query, filter_params = node_search_filter_query_constructor(search_filter)
 
-    records, _, _ = await driver.execute_query(
+    query = (
         """
-        UNWIND $bfs_origin_node_uuids AS origin_uuid
-        MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
-        WHERE n.group_id = origin.group_id
-        """
+        UNWIND $bfs_origin_node_uuids AS origin_uuid
+        MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
+        WHERE n.group_id = origin.group_id
+        """
         + filter_query
         + ENTITY_NODE_RETURN
         + """
         LIMIT $limit
-        """,
-        filter_params,
+        """
+    )
+    records, _, _ = await driver.execute_query(
+        query,
+        params=filter_params,
         bfs_origin_node_uuids=bfs_origin_node_uuids,
         depth=bfs_max_depth,
         limit=limit,
@@ -449,7 +480,7 @@ async def node_bfs_search(
 
 
 async def episode_fulltext_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     query: str,
     _search_filter: SearchFilters,
     group_ids: list[str] | None = None,
@@ -460,9 +491,9 @@ async def episode_fulltext_search(
     if fuzzy_query == '':
         return []
 
-    records, _, _ = await driver.execute_query(
-        """
-        CALL db.index.fulltext.queryNodes("episode_content", $query, {limit: $limit})
+    query = (
+        get_nodes_query(driver.provider, 'episode_content', '$query')
+        + """
         YIELD node AS episode, score
         MATCH (e:Episodic)
         WHERE e.uuid = episode.uuid
@@ -478,7 +509,11 @@ async def episode_fulltext_search(
             e.entity_edges AS entity_edges
         ORDER BY score DESC
         LIMIT $limit
-        """,
+        """
+    )
+
+    records, _, _ = await driver.execute_query(
+        query,
         query=fuzzy_query,
         group_ids=group_ids,
         limit=limit,
@@ -491,7 +526,7 @@ async def episode_fulltext_search(
 
 
 async def community_fulltext_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     query: str,
     group_ids: list[str] | None = None,
     limit=RELEVANT_SCHEMA_LIMIT,
@@ -501,9 +536,9 @@ async def community_fulltext_search(
     if fuzzy_query == '':
         return []
 
-    records, _, _ = await driver.execute_query(
-        """
-        CALL db.index.fulltext.queryNodes("community_name", $query, {limit: $limit})
+    query = (
+        get_nodes_query(driver.provider, 'community_name', '$query')
+        + """
         YIELD node AS comm, score
         RETURN
             comm.uuid AS uuid,
@@ -513,7 +548,11 @@ async def community_fulltext_search(
             comm.summary AS summary
         ORDER BY score DESC
         LIMIT $limit
-        """,
+        """
+    )
+
+    records, _, _ = await driver.execute_query(
+        query,
         query=fuzzy_query,
         group_ids=group_ids,
         limit=limit,
@@ -526,7 +565,7 @@ async def community_fulltext_search(
 
 
 async def community_similarity_search(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     search_vector: list[float],
     group_ids: list[str] | None = None,
     limit=RELEVANT_SCHEMA_LIMIT,
@@ -540,14 +579,16 @@ async def community_similarity_search(
         group_filter_query += 'WHERE comm.group_id IN $group_ids'
         query_params['group_ids'] = group_ids
 
-    records, _, _ = await driver.execute_query(
+    query = (
         RUNTIME_QUERY
         + """
         MATCH (comm:Community)
         """
         + group_filter_query
         + """
-        WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
+        WITH comm, """
+        + get_vector_cosine_func_query('comm.name_embedding', '$search_vector', driver.provider)
+        + """ AS score
         WHERE score > $min_score
         RETURN
             comm.uuid As uuid,
@@ -557,7 +598,11 @@ async def community_similarity_search(
             comm.summary AS summary
         ORDER BY score DESC
         LIMIT $limit
-        """,
+        """
+    )
+
+    records, _, _ = await driver.execute_query(
+        query,
         search_vector=search_vector,
         group_ids=group_ids,
         limit=limit,
@@ -573,7 +618,7 @@ async def community_similarity_search(
 async def hybrid_node_search(
     queries: list[str],
     embeddings: list[list[float]],
-    driver: AsyncDriver,
+    driver: GraphDriver,
     search_filter: SearchFilters,
     group_ids: list[str] | None = None,
     limit: int = RELEVANT_SCHEMA_LIMIT,
@@ -590,7 +635,7 @@ async def hybrid_node_search(
         A list of text queries to search for.
     embeddings : list[list[float]]
         A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed.
-    driver : AsyncDriver
+    driver : GraphDriver
         The Neo4j driver instance for database operations.
     group_ids : list[str] | None, optional
         The list of group ids to retrieve nodes from.
@@ -645,7 +690,7 @@ async def hybrid_node_search(
 
 
 async def get_relevant_nodes(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     nodes: list[EntityNode],
     search_filter: SearchFilters,
     min_score: float = DEFAULT_MIN_SCORE,
@@ -664,29 +709,33 @@ async def get_relevant_nodes(
     query = (
         RUNTIME_QUERY
-        + """UNWIND $nodes AS node
-        MATCH (n:Entity {group_id: $group_id})
-        """
+        + """
+        UNWIND $nodes AS node
+        MATCH (n:Entity {group_id: $group_id})
+        """
         + filter_query
         + """
-        WITH node, n, vector.similarity.cosine(n.name_embedding, node.name_embedding) AS score
+        WITH node, n, """
+        + get_vector_cosine_func_query('n.name_embedding', 'node.name_embedding', driver.provider)
+        + """ AS score
         WHERE score > $min_score
         WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
 
-        CALL db.index.fulltext.queryNodes("node_name_and_summary", node.fulltext_query, {limit: $limit})
+        """
+        + get_nodes_query(driver.provider, 'node_name_and_summary', 'node.fulltext_query')
+        + """
        YIELD node AS m
        WHERE m.group_id = $group_id
        WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
 
        WITH node,
            top_vector_nodes,
            [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
 
        WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
 
        UNWIND combined_nodes AS combined_node
        WITH node, collect(DISTINCT combined_node) AS deduped_nodes
 
        RETURN
            node.uuid AS search_node_uuid,
            [x IN deduped_nodes | {
@@ -714,7 +763,7 @@ async def get_relevant_nodes(
 
     results, _, _ = await driver.execute_query(
         query,
-        query_params,
+        params=query_params,
         nodes=query_nodes,
         group_id=group_id,
         limit=limit,
@@ -736,7 +785,7 @@ async def get_relevant_nodes(
 
 
 async def get_relevant_edges(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     edges: list[EntityEdge],
     search_filter: SearchFilters,
     min_score: float = DEFAULT_MIN_SCORE,
@@ -752,43 +801,47 @@ async def get_relevant_edges(
     query = (
         RUNTIME_QUERY
-        + """UNWIND $edges AS edge
-        MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
-        """
+        + """
+        UNWIND $edges AS edge
+        MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
+        """
         + filter_query
         + """
-        WITH e, edge, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
-        WHERE score > $min_score
-        WITH edge, e, score
-        ORDER BY score DESC
-        RETURN edge.uuid AS search_edge_uuid,
-            collect({
-                uuid: e.uuid,
-                source_node_uuid: startNode(e).uuid,
-                target_node_uuid: endNode(e).uuid,
-                created_at: e.created_at,
-                name: e.name,
-                group_id: e.group_id,
-                fact: e.fact,
-                fact_embedding: e.fact_embedding,
-                episodes: e.episodes,
-                expired_at: e.expired_at,
-                valid_at: e.valid_at,
-                invalid_at: e.invalid_at,
-                attributes: properties(e)
-            })[..$limit] AS matches
+        WITH e, edge, """
+        + get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
+        + """ AS score
+        WHERE score > $min_score
+        WITH edge, e, score
+        ORDER BY score DESC
+        RETURN edge.uuid AS search_edge_uuid,
+            collect({
+                uuid: e.uuid,
+                source_node_uuid: startNode(e).uuid,
+                target_node_uuid: endNode(e).uuid,
+                created_at: e.created_at,
+                name: e.name,
+                group_id: e.group_id,
+                fact: e.fact,
+                fact_embedding: e.fact_embedding,
+                episodes: e.episodes,
+                expired_at: e.expired_at,
+                valid_at: e.valid_at,
+                invalid_at: e.invalid_at,
+                attributes: properties(e)
+            })[..$limit] AS matches
         """
     )
 
     results, _, _ = await driver.execute_query(
         query,
-        query_params,
+        params=query_params,
         edges=[edge.model_dump() for edge in edges],
         limit=limit,
         min_score=min_score,
         database_=DEFAULT_DATABASE,
         routing_='r',
     )
 
     relevant_edges_dict: dict[str, list[EntityEdge]] = {
         result['search_edge_uuid']: [
             get_entity_edge_from_record(record) for record in result['matches']
@@ -802,7 +855,7 @@ async def get_relevant_edges(
 
 
 async def get_edge_invalidation_candidates(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     edges: list[EntityEdge],
     search_filter: SearchFilters,
     min_score: float = DEFAULT_MIN_SCORE,
@@ -818,38 +871,41 @@ async def get_edge_invalidation_candidates(
     query = (
         RUNTIME_QUERY
-        + """UNWIND $edges AS edge
-        MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
-        WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
-        """
+        + """
+        UNWIND $edges AS edge
+        MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
+        WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
+        """
         + filter_query
         + """
-        WITH edge, e, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
-        WHERE score > $min_score
-        WITH edge, e, score
-        ORDER BY score DESC
-        RETURN edge.uuid AS search_edge_uuid,
-            collect({
-                uuid: e.uuid,
-                source_node_uuid: startNode(e).uuid,
-                target_node_uuid: endNode(e).uuid,
-                created_at: e.created_at,
-                name: e.name,
-                group_id: e.group_id,
-                fact: e.fact,
-                fact_embedding: e.fact_embedding,
-                episodes: e.episodes,
-                expired_at: e.expired_at,
-                valid_at: e.valid_at,
-                invalid_at: e.invalid_at,
-                attributes: properties(e)
-            })[..$limit] AS matches
+        WITH edge, e, """
+        + get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
+        + """ AS score
+        WHERE score > $min_score
+        WITH edge, e, score
+        ORDER BY score DESC
+        RETURN edge.uuid AS search_edge_uuid,
+            collect({
+                uuid: e.uuid,
+                source_node_uuid: startNode(e).uuid,
+                target_node_uuid: endNode(e).uuid,
+                created_at: e.created_at,
+                name: e.name,
+                group_id: e.group_id,
+                fact: e.fact,
+                fact_embedding: e.fact_embedding,
+                episodes: e.episodes,
+                expired_at: e.expired_at,
+                valid_at: e.valid_at,
+                invalid_at: e.invalid_at,
+                attributes: properties(e)
+            })[..$limit] AS matches
         """
     )
 
     results, _, _ = await driver.execute_query(
         query,
-        query_params,
+        params=query_params,
         edges=[edge.model_dump() for edge in edges],
         limit=limit,
         min_score=min_score,
@@ -884,7 +940,7 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]:
 
 
 async def node_distance_reranker(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     node_uuids: list[str],
     center_node_uuid: str,
     min_score: float = 0,
@@ -894,21 +950,22 @@ async def node_distance_reranker(
     scores: dict[str, float] = {center_node_uuid: 0.0}
 
     # Find the shortest path to center node
-    query = Query("""
+    query = """
         UNWIND $node_uuids AS node_uuid
-        MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: node_uuid})
-        RETURN length(p) AS score, node_uuid AS uuid
-        """)
-
-    path_results, _, _ = await driver.execute_query(
+        MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
+        RETURN 1 AS score, node_uuid AS uuid
+        """
+    results, header, _ = await driver.execute_query(
         query,
         node_uuids=filtered_uuids,
         center_uuid=center_node_uuid,
         database_=DEFAULT_DATABASE,
         routing_='r',
     )
+    if driver.provider == 'falkordb':
+        results = [dict(zip(header, row, strict=True)) for row in results]
 
-    for result in path_results:
+    for result in results:
         uuid = result['uuid']
         score = result['score']
         scores[uuid] = score
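Note: the quantified SHORTEST-path match is dropped here, presumably because FalkorDB does not support that Cypher construct, so reachable nodes are now scored as direct neighbors at a constant distance of 1. Nodes the query does not return would keep whatever fallback score the rest of the function assigns, which this hunk does not show. A compressed sketch of the effective rule (fallback value assumed):

    def distance_score(uuid: str, neighbor_uuids: set[str]) -> float:
        # direct neighbor of the center node -> distance 1; otherwise an assumed fallback
        return 1.0 if uuid in neighbor_uuids else float('inf')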
@@ -929,19 +986,18 @@ async def node_distance_reranker(
 
 
 async def episode_mentions_reranker(
-    driver: AsyncDriver, node_uuids: list[list[str]], min_score: float = 0
+    driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
 ) -> list[str]:
     # use rrf as a preliminary ranker
     sorted_uuids = rrf(node_uuids)
     scores: dict[str, float] = {}
 
     # Find the shortest path to center node
-    query = Query("""
+    query = """
         UNWIND $node_uuids AS node_uuid
         MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
         RETURN count(*) AS score, n.uuid AS uuid
-        """)
-
+        """
     results, _, _ = await driver.execute_query(
         query,
         node_uuids=sorted_uuids,
@@ -998,7 +1054,7 @@ def maximal_marginal_relevance(
 
 
 async def get_embeddings_for_nodes(
-    driver: AsyncDriver, nodes: list[EntityNode]
+    driver: GraphDriver, nodes: list[EntityNode]
 ) -> dict[str, list[float]]:
     query: LiteralString = """MATCH (n:Entity)
         WHERE n.uuid IN $node_uuids
@@ -1022,7 +1078,7 @@ async def get_embeddings_for_nodes(
 
 
 async def get_embeddings_for_communities(
-    driver: AsyncDriver, communities: list[CommunityNode]
+    driver: GraphDriver, communities: list[CommunityNode]
 ) -> dict[str, list[float]]:
     query: LiteralString = """MATCH (c:Community)
         WHERE c.uuid IN $community_uuids
@@ -1049,7 +1105,7 @@ async def get_embeddings_for_communities(
 
 
 async def get_embeddings_for_edges(
-    driver: AsyncDriver, edges: list[EntityEdge]
+    driver: GraphDriver, edges: list[EntityEdge]
 ) -> dict[str, list[float]]:
     query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
         WHERE e.uuid IN $edge_uuids
graphiti_core/utils/bulk_utils.py

@@ -20,22 +20,24 @@ from collections import defaultdict
 from datetime import datetime
 from math import ceil
 
-from neo4j import AsyncDriver, AsyncManagedTransaction
 from numpy import dot, sqrt
 from pydantic import BaseModel
 from typing_extensions import Any
 
+from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
 from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
 from graphiti_core.embedder import EmbedderClient
+from graphiti_core.graph_queries import (
+    get_entity_edge_save_bulk_query,
+    get_entity_node_save_bulk_query,
+)
 from graphiti_core.graphiti_types import GraphitiClients
 from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.models.edges.edge_db_queries import (
-    ENTITY_EDGE_SAVE_BULK,
     EPISODIC_EDGE_SAVE_BULK,
 )
 from graphiti_core.models.nodes.node_db_queries import (
-    ENTITY_NODE_SAVE_BULK,
     EPISODIC_NODE_SAVE_BULK,
 )
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
@@ -73,7 +75,7 @@ class RawEpisode(BaseModel):
 
 
 async def retrieve_previous_episodes_bulk(
-    driver: AsyncDriver, episodes: list[EpisodicNode]
+    driver: GraphDriver, episodes: list[EpisodicNode]
 ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
     previous_episodes_list = await semaphore_gather(
         *[
@@ -91,14 +93,15 @@ async def retrieve_previous_episodes_bulk(
 
 
 async def add_nodes_and_edges_bulk(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     episodic_nodes: list[EpisodicNode],
     episodic_edges: list[EpisodicEdge],
     entity_nodes: list[EntityNode],
     entity_edges: list[EntityEdge],
     embedder: EmbedderClient,
 ):
-    async with driver.session(database=DEFAULT_DATABASE) as session:
+    session = driver.session(database=DEFAULT_DATABASE)
+    try:
         await session.execute_write(
             add_nodes_and_edges_bulk_tx,
             episodic_nodes,
@@ -106,16 +109,20 @@ async def add_nodes_and_edges_bulk(
             entity_nodes,
             entity_edges,
             embedder,
+            driver=driver,
         )
+    finally:
+        await session.close()
 
 
 async def add_nodes_and_edges_bulk_tx(
-    tx: AsyncManagedTransaction,
+    tx: GraphDriverSession,
     episodic_nodes: list[EpisodicNode],
     episodic_edges: list[EpisodicEdge],
     entity_nodes: list[EntityNode],
     entity_edges: list[EntityEdge],
     embedder: EmbedderClient,
+    driver: GraphDriver,
 ):
     episodes = [dict(episode) for episode in episodic_nodes]
     for episode in episodes:
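Note: the async with block becomes an explicit try/finally, presumably because not every GraphDriverSession implementation is an async context manager. A minimal sketch of the resulting lifetime pattern (names and database value are placeholders, not from this diff):

    session = driver.session(database='my_db')  # 'my_db' is a placeholder
    try:
        await session.execute_write(some_tx_function, arg1, arg2)  # hypothetical callable/args
    finally:
        await session.close()  # runs even if the write raises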
@@ -160,11 +167,13 @@ async def add_nodes_and_edges_bulk_tx(
         edges.append(edge_data)
 
     await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
-    await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
+    entity_node_save_bulk = get_entity_node_save_bulk_query(nodes, driver.provider)
+    await tx.run(entity_node_save_bulk, nodes=nodes)
     await tx.run(
         EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
     )
-    await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=edges)
+    entity_edge_save_bulk = get_entity_edge_save_bulk_query(driver.provider)
+    await tx.run(entity_edge_save_bulk, entity_edges=edges)
 
 
 async def extract_nodes_and_edges_bulk(
@@ -211,7 +220,7 @@ async def extract_nodes_and_edges_bulk(
 
 
 async def dedupe_nodes_bulk(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     llm_client: LLMClient,
     extracted_nodes: list[EntityNode],
 ) -> tuple[list[EntityNode], dict[str, str]]:
@@ -247,7 +256,7 @@
 
 
 async def dedupe_edges_bulk(
-    driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
+    driver: GraphDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
 ) -> list[EntityEdge]:
     # First compress edges
     compressed_edges = await compress_edges(llm_client, extracted_edges)
graphiti_core/utils/maintenance/community_operations.py

@@ -2,9 +2,9 @@ import asyncio
 import logging
 from collections import defaultdict
 
-from neo4j import AsyncDriver
 from pydantic import BaseModel
 
+from graphiti_core.driver.driver import GraphDriver
 from graphiti_core.edges import CommunityEdge
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
@@ -26,7 +26,7 @@ class Neighbor(BaseModel):
 
 
 async def get_community_clusters(
-    driver: AsyncDriver, group_ids: list[str] | None
+    driver: GraphDriver, group_ids: list[str] | None
 ) -> list[list[EntityNode]]:
     community_clusters: list[list[EntityNode]] = []
 
@@ -95,7 +95,6 @@ def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
         community_candidates: dict[int, int] = defaultdict(int)
         for neighbor in neighbors:
             community_candidates[community_map[neighbor.node_uuid]] += neighbor.edge_count
-
         community_lst = [
             (count, community) for community, count in community_candidates.items()
         ]
@@ -194,7 +193,7 @@ async def build_community(
 
 
 async def build_communities(
-    driver: AsyncDriver, llm_client: LLMClient, group_ids: list[str] | None
+    driver: GraphDriver, llm_client: LLMClient, group_ids: list[str] | None
 ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
     community_clusters = await get_community_clusters(driver, group_ids)
 
@@ -219,7 +218,7 @@ async def build_communities(
     return community_nodes, community_edges
 
 
-async def remove_communities(driver: AsyncDriver):
+async def remove_communities(driver: GraphDriver):
     await driver.execute_query(
         """
         MATCH (c:Community)
@@ -230,10 +229,10 @@ async def remove_communities(driver: AsyncDriver):
 
 
 async def determine_entity_community(
-    driver: AsyncDriver, entity: EntityNode
+    driver: GraphDriver, entity: EntityNode
 ) -> tuple[CommunityNode | None, bool]:
     # Check if the node is already part of a community
-    records, _, _ = await driver.execute_query(
+    records, _, _ = driver.execute_query(
         """
         MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
         RETURN
@@ -251,7 +250,7 @@ async def determine_entity_community(
         return get_community_node_from_record(records[0]), False
 
     # If the node has no community, add it to the mode community of surrounding entities
-    records, _, _ = await driver.execute_query(
+    records, _, _ = driver.execute_query(
         """
         MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
         RETURN
@@ -291,7 +290,7 @@ async def determine_entity_community(
 
 
 async def update_community(
-    driver: AsyncDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode
+    driver: GraphDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode
 ):
     community, is_new = await determine_entity_community(driver, entity)
 
graphiti_core/utils/maintenance/edge_operations.py

@@ -260,7 +260,6 @@ async def resolve_extracted_edges(
     driver = clients.driver
     llm_client = clients.llm_client
     embedder = clients.embedder
-
     await create_entity_edge_embeddings(embedder, extracted_edges)
 
     search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
graphiti_core/utils/maintenance/graph_data_operations.py

@@ -17,9 +17,10 @@ limitations under the License.
 import logging
 from datetime import datetime, timezone
 
-from neo4j import AsyncDriver
 from typing_extensions import LiteralString
 
+from graphiti_core.driver.driver import GraphDriver
+from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
 from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
 from graphiti_core.nodes import EpisodeType, EpisodicNode
 
@@ -28,7 +29,7 @@ EPISODE_WINDOW_LEN = 3
 logger = logging.getLogger(__name__)
 
 
-async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bool = False):
+async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
     if delete_existing:
         records, _, _ = await driver.execute_query(
             """
@@ -47,39 +48,9 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bool = False):
             for name in index_names
         ]
     )
+    range_indices: list[LiteralString] = get_range_indices(driver.provider)
 
-    range_indices: list[LiteralString] = [
-        'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
-        'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
-        'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
-        'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
-        'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
-        'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
-        'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
-        'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
-        'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
-        'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
-        'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
-        'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
-        'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
-        'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
-        'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
-        'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
-        'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
-        'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
-        'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
-    ]
-
-    fulltext_indices: list[LiteralString] = [
-        """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
-        FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
-        """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
-        FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
-        """CREATE FULLTEXT INDEX community_name IF NOT EXISTS
-        FOR (n:Community) ON EACH [n.name, n.group_id]""",
-        """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
-        FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
-    ]
+    fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
 
     index_queries: list[LiteralString] = range_indices + fulltext_indices
 
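Note: the hardcoded Neo4j index DDL moves behind get_range_indices / get_fulltext_indices so each backend can emit its own statements; FalkorDB creates fulltext indexes through a procedure rather than CREATE FULLTEXT INDEX. A guess at the dispatch shape, not the shipped code:

    # Illustrative sketch; the real statements live in graphiti_core/graph_queries.py.
    def get_fulltext_indices(provider: str) -> list[str]:
        if provider == 'falkordb':
            # FalkorDB/RediSearch style (assumed): index creation via procedure call.
            return ["CALL db.idx.fulltext.createNodeIndex('Entity', 'name', 'summary', 'group_id')"]
        return [
            """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
            FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
        ]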
@@ -94,7 +65,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bool = False):
     )
 
 
-async def clear_data(driver: AsyncDriver, group_ids: list[str] | None = None):
+async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
     async with driver.session(database=DEFAULT_DATABASE) as session:
 
         async def delete_all(tx):
@@ -113,7 +84,7 @@ async def clear_data(driver: AsyncDriver, group_ids: list[str] | None = None):
 
 
 async def retrieve_episodes(
-    driver: AsyncDriver,
+    driver: GraphDriver,
     reference_time: datetime,
     last_n: int = EPISODE_WINDOW_LEN,
     group_ids: list[str] | None = None,
@@ -123,7 +94,7 @@ async def retrieve_episodes(
     Retrieve the last n episodic nodes from the graph.
 
     Args:
-        driver (AsyncDriver): The Neo4j driver instance.
+        driver (Driver): The Neo4j driver instance.
         reference_time (datetime): The reference time to filter episodes. Only episodes with a valid_at timestamp
             less than or equal to this reference_time will be retrieved. This allows for
             querying the graph's state at a specific point in time.
@@ -140,8 +111,8 @@ async def retrieve_episodes(
 
     query: LiteralString = (
-        """
-        MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
+        """
+        MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
         """
         + group_id_filter
         + source_filter
         + """
@@ -157,8 +128,7 @@ async def retrieve_episodes(
         LIMIT $num_episodes
         """
     )
-
-    result = await driver.execute_query(
+    result, _, _ = await driver.execute_query(
         query,
         reference_time=reference_time,
         source=source.name if source is not None else None,
@@ -166,6 +136,7 @@ async def retrieve_episodes(
         group_ids=group_ids,
         database_=DEFAULT_DATABASE,
     )
+
     episodes = [
         EpisodicNode(
             content=record['content'],
@@ -179,6 +150,6 @@ async def retrieve_episodes(
             name=record['name'],
             source_description=record['source_description'],
         )
-        for record in result.records
+        for record in result
     ]
     return list(reversed(episodes))  # Return in chronological order
graphiti_core/utils/maintenance/node_operations.py

@@ -326,7 +326,6 @@ async def extract_attributes_from_nodes(
 ) -> list[EntityNode]:
     llm_client = clients.llm_client
     embedder = clients.embedder
-
     updated_nodes: list[EntityNode] = await semaphore_gather(
         *[
             extract_attributes_from_node(
poetry.lock (generated; 46 changed lines)

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
 
 [[package]]
 name = "aiohappyeyeballs"
@@ -332,12 +332,12 @@ version = "5.0.1"
 description = "Timeout context manager for asyncio programs"
 optional = false
 python-versions = ">=3.8"
-groups = ["dev"]
-markers = "python_version < \"3.11\""
+groups = ["main", "dev"]
 files = [
     {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"},
     {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"},
 ]
+markers = {main = "python_full_version < \"3.11.3\"", dev = "python_version == \"3.10\""}
 
 [[package]]
 name = "attrs"
@@ -759,7 +759,7 @@ description = "Backport of PEP 654 (exception groups)"
 optional = false
 python-versions = ">=3.7"
 groups = ["main", "dev"]
-markers = "python_version < \"3.11\""
+markers = "python_version == \"3.10\""
 files = [
     {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"},
     {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
@@ -798,6 +798,20 @@ files = [
 [package.extras]
 tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich ; python_version >= \"3.11\""]
 
+[[package]]
+name = "falkordb"
+version = "1.1.2"
+description = "Python client for interacting with FalkorDB database"
+optional = false
+python-versions = "<4.0,>=3.8"
+groups = ["main"]
+files = [
+    {file = "falkordb-1.1.2.tar.gz", hash = "sha256:db76c97efe14a56c3d65c61b966a42b874e1c78a8fb6808de3f61f4314b04023"},
+]
+
+[package.dependencies]
+redis = ">=5.0.1,<6.0.0"
+
 [[package]]
 name = "fastjsonschema"
 version = "2.21.1"
@@ -2665,7 +2679,6 @@ description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy"
 optional = false
 python-versions = ">=3.9"
 groups = ["dev"]
-markers = "platform_python_implementation != \"PyPy\""
 files = [
     {file = "orjson-3.10.16-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4cb473b8e79154fa778fb56d2d73763d977be3dcc140587e07dbc545bbfc38f8"},
     {file = "orjson-3.10.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:622a8e85eeec1948690409a19ca1c7d9fd8ff116f4861d261e6ae2094fe59a00"},
@@ -3691,6 +3704,25 @@ files = [
 [package.dependencies]
 cffi = {version = "*", markers = "implementation_name == \"pypy\""}
 
+[[package]]
+name = "redis"
+version = "5.2.1"
+description = "Python client for Redis database and key-value store"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+    {file = "redis-5.2.1-py3-none-any.whl", hash = "sha256:ee7e1056b9aea0f04c6c2ed59452947f34c4940ee025f5dd83e6a6418b6989e4"},
+    {file = "redis-5.2.1.tar.gz", hash = "sha256:16f2e22dff21d5125e8481515e386711a34cbec50f0e44413dd7d9c060a54e0f"},
+]
+
+[package.dependencies]
+async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""}
+
+[package.extras]
+hiredis = ["hiredis (>=3.0.0)"]
+ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"]
+
 [[package]]
 name = "referencing"
 version = "0.36.2"
@@ -4498,7 +4530,7 @@ description = "A lil' TOML parser"
 optional = false
 python-versions = ">=3.8"
 groups = ["dev"]
-markers = "python_version < \"3.11\""
+markers = "python_version == \"3.10\""
 files = [
     {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
     {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
@@ -5356,4 +5388,4 @@ groq = ["groq"]
 [metadata]
 lock-version = "2.1"
 python-versions = ">=3.10,<4"
-content-hash = "814d067fd2959bfe2db58a22637d86580b66d96f34c433852c67d02089d750ab"
+content-hash = "2e02a10a6493f7564b86d5d0d09b4cf718004808e115af39550b9ee87c296fb4"
pyproject.toml

@@ -19,6 +19,7 @@ dependencies = [
     "tenacity>=9.0.0",
     "numpy>=1.0.0",
     "python-dotenv>=1.0.1",
+    "falkordb (>=1.1.2,<2.0.0)",
 ]
 
 [project.urls]
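Note: with falkordb declared as a main dependency (pulling in redis 5.x per the lockfile), a FalkorDB instance can be smoke-tested directly with the pinned client. Connection values below are placeholders:

    from falkordb import FalkorDB

    # Placeholder host/port; FalkorDB speaks the Redis protocol, default port 6379.
    db = FalkorDB(host='localhost', port=6379)
    graph = db.select_graph('graphiti_smoke_test')  # graph name invented for this example
    result = graph.query('RETURN 1 AS ok')
    print(result.result_set)  # [[1]] if the server is reachable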