FalkorDB Integration: Bug Fixes and Unit Tests (#607)
* fixes-and-tests * update-workflow * lint-fixes * mypy-fixes * fix-falkor-tests * Update poetry.lock after pyproject.toml changes * update-yml * fix-tests * comp-tests * typo * fix-tests --------- Co-authored-by: Guy Korland <gkorland@gmail.com>
This commit is contained in:
parent
19772aa5a1
commit
6e6115c134
17 changed files with 2301 additions and 1373 deletions
18
.github/workflows/unit_tests.yml
vendored
18
.github/workflows/unit_tests.yml
vendored
|
|
@ -11,6 +11,12 @@ jobs:
|
||||||
runs-on: depot-ubuntu-22.04
|
runs-on: depot-ubuntu-22.04
|
||||||
environment:
|
environment:
|
||||||
name: development
|
name: development
|
||||||
|
services:
|
||||||
|
falkordb:
|
||||||
|
image: falkordb/falkordb:latest
|
||||||
|
ports:
|
||||||
|
- 6379:6379
|
||||||
|
options: --health-cmd "redis-cli ping" --health-interval 10s --health-timeout 5s --health-retries 5
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
|
|
@ -21,6 +27,8 @@ jobs:
|
||||||
uses: astral-sh/setup-uv@v3
|
uses: astral-sh/setup-uv@v3
|
||||||
with:
|
with:
|
||||||
version: "latest"
|
version: "latest"
|
||||||
|
- name: Install redis-cli for FalkorDB health check
|
||||||
|
run: sudo apt-get update && sudo apt-get install -y redis-tools
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync --extra dev
|
run: uv sync --extra dev
|
||||||
- name: Run non-integration tests
|
- name: Run non-integration tests
|
||||||
|
|
@ -28,3 +36,13 @@ jobs:
|
||||||
PYTHONPATH: ${{ github.workspace }}
|
PYTHONPATH: ${{ github.workspace }}
|
||||||
run: |
|
run: |
|
||||||
uv run pytest -m "not integration"
|
uv run pytest -m "not integration"
|
||||||
|
- name: Wait for FalkorDB
|
||||||
|
run: |
|
||||||
|
timeout 60 bash -c 'until redis-cli -h localhost -p 6379 ping; do sleep 1; done'
|
||||||
|
- name: Run FalkorDB integration tests
|
||||||
|
env:
|
||||||
|
PYTHONPATH: ${{ github.workspace }}
|
||||||
|
FALKORDB_HOST: localhost
|
||||||
|
FALKORDB_PORT: 6379
|
||||||
|
run: |
|
||||||
|
uv run pytest tests/driver/test_falkordb_driver.py
|
||||||
|
|
|
||||||
|
|
@ -120,6 +120,12 @@ Optional:
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> The simplest way to install Neo4j is via [Neo4j Desktop](https://neo4j.com/download/). It provides a user-friendly
|
> The simplest way to install Neo4j is via [Neo4j Desktop](https://neo4j.com/download/). It provides a user-friendly
|
||||||
> interface to manage Neo4j instances and databases.
|
> interface to manage Neo4j instances and databases.
|
||||||
|
> Alternatively, you can use FalkorDB on-premises via Docker and instantly start with the quickstart example:
|
||||||
|
```bash
|
||||||
|
docker run -p 6379:6379 -p 3000:3000 -it --rm falkordb/falkordb:latest
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install graphiti-core
|
pip install graphiti-core
|
||||||
|
|
@ -156,7 +162,7 @@ pip install graphiti-core[anthropic,groq,google-genai]
|
||||||
|
|
||||||
For a complete working example, see the [Quickstart Example](./examples/quickstart/README.md) in the examples directory. The quickstart demonstrates:
|
For a complete working example, see the [Quickstart Example](./examples/quickstart/README.md) in the examples directory. The quickstart demonstrates:
|
||||||
|
|
||||||
1. Connecting to a Neo4j database
|
1. Connecting to a Neo4j or FalkorDB database
|
||||||
2. Initializing Graphiti indices and constraints
|
2. Initializing Graphiti indices and constraints
|
||||||
3. Adding episodes to the graph (both text and structured JSON)
|
3. Adding episodes to the graph (both text and structured JSON)
|
||||||
4. Searching for relationships (edges) using hybrid search
|
4. Searching for relationships (edges) using hybrid search
|
||||||
|
|
|
||||||
|
|
@ -46,14 +46,20 @@ logger = logging.getLogger(__name__)
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# FalkorDB connection parameters
|
# FalkorDB connection parameters
|
||||||
# Make sure FalkorDB on premises is running, see https://docs.falkordb.com/
|
# Make sure FalkorDB (on-premises) is running — see https://docs.falkordb.com/
|
||||||
falkor_uri = os.environ.get('FALKORDB_URI', 'falkor://localhost:6379')
|
# By default, FalkorDB does not require a username or password,
|
||||||
falkor_user = os.environ.get('FALKORDB_USER', 'falkor')
|
# but you can set them via environment variables for added security.
|
||||||
falkor_password = os.environ.get('FALKORDB_PASSWORD', '')
|
#
|
||||||
|
# If you're using FalkorDB Cloud, set the environment variables accordingly.
|
||||||
if not falkor_uri:
|
# For on-premises use, you can leave them as None or set them to your preferred values.
|
||||||
raise ValueError('FALKORDB_URI must be set')
|
#
|
||||||
|
# The default host and port are 'localhost' and '6379', respectively.
|
||||||
|
# You can override these values in your environment variables or directly in the code.
|
||||||
|
|
||||||
|
falkor_username = os.environ.get('FALKORDB_USERNAME', None)
|
||||||
|
falkor_password = os.environ.get('FALKORDB_PASSWORD', None)
|
||||||
|
falkor_host = os.environ.get('FALKORDB_HOST', 'localhost')
|
||||||
|
falkor_port = os.environ.get('FALKORDB_PORT', '6379')
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
#################################################
|
#################################################
|
||||||
|
|
@ -65,8 +71,8 @@ async def main():
|
||||||
#################################################
|
#################################################
|
||||||
|
|
||||||
# Initialize Graphiti with FalkorDB connection
|
# Initialize Graphiti with FalkorDB connection
|
||||||
falkor_driver = FalkorDriver(uri=falkor_uri, user=falkor_user, password=falkor_password)
|
falkor_driver = FalkorDriver(host=falkor_host, port=falkor_port, username=falkor_username, password=falkor_password)
|
||||||
graphiti = Graphiti(uri=falkor_uri, graph_driver=falkor_driver)
|
graphiti = Graphiti(graph_driver=falkor_driver)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize the graph database with graphiti's indices. This only needs to be done once.
|
# Initialize the graph database with graphiti's indices. This only needs to be done once.
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Coroutine
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
@ -66,22 +65,30 @@ class FalkorDriver(GraphDriver):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
uri: str,
|
host: str = 'localhost',
|
||||||
user: str,
|
port: str = '6379',
|
||||||
password: str,
|
username: str | None = None,
|
||||||
|
password: str | None = None,
|
||||||
|
falkor_db: FalkorDB | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
"""
|
||||||
uri_parts = uri.split('://', 1)
|
Initialize the FalkorDB driver.
|
||||||
uri = f'{uri_parts[0]}://{user}:{password}@{uri_parts[1]}'
|
|
||||||
|
|
||||||
self.client = FalkorDB(
|
FalkorDB is a multi-tenant graph database.
|
||||||
host='your-db.falkor.cloud', port=6380, password='your_password', ssl=True
|
To connect, provide the host and port.
|
||||||
)
|
The default parameters assume a local (on-premises) FalkorDB instance.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
if falkor_db is not None:
|
||||||
|
# If a FalkorDB instance is provided, use it directly
|
||||||
|
self.client = falkor_db
|
||||||
|
else:
|
||||||
|
self.client = FalkorDB(host=host, port=port, username=username, password=password)
|
||||||
|
|
||||||
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
|
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
|
||||||
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "DEFAULT_DATABASE"
|
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is DEFAULT_DATABASE
|
||||||
if graph_name is None:
|
if graph_name is None:
|
||||||
graph_name = 'DEFAULT_DATABASE'
|
graph_name = DEFAULT_DATABASE
|
||||||
return self.client.select_graph(graph_name)
|
return self.client.select_graph(graph_name)
|
||||||
|
|
||||||
async def execute_query(self, cypher_query_, **kwargs: Any):
|
async def execute_query(self, cypher_query_, **kwargs: Any):
|
||||||
|
|
@ -102,17 +109,36 @@ class FalkorDriver(GraphDriver):
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Convert the result header to a list of strings
|
# Convert the result header to a list of strings
|
||||||
header = [h[1].decode('utf-8') for h in result.header]
|
header = [h[1] for h in result.header]
|
||||||
return result.result_set, header, None
|
|
||||||
|
# Convert FalkorDB's result format (list of lists) to the format expected by Graphiti (list of dicts)
|
||||||
|
records = []
|
||||||
|
for row in result.result_set:
|
||||||
|
record = {}
|
||||||
|
for i, field_name in enumerate(header):
|
||||||
|
if i < len(row):
|
||||||
|
record[field_name] = row[i]
|
||||||
|
else:
|
||||||
|
# If there are more fields in header than values in row, set to None
|
||||||
|
record[field_name] = None
|
||||||
|
records.append(record)
|
||||||
|
|
||||||
|
return records, header, None
|
||||||
|
|
||||||
def session(self, database: str | None) -> GraphDriverSession:
|
def session(self, database: str | None) -> GraphDriverSession:
|
||||||
return FalkorDriverSession(self._get_graph(database))
|
return FalkorDriverSession(self._get_graph(database))
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
await self.client.connection.close()
|
"""Close the driver connection."""
|
||||||
|
if hasattr(self.client, 'aclose'):
|
||||||
|
await self.client.aclose()
|
||||||
|
elif hasattr(self.client.connection, 'aclose'):
|
||||||
|
await self.client.connection.aclose()
|
||||||
|
elif hasattr(self.client.connection, 'close'):
|
||||||
|
await self.client.connection.close()
|
||||||
|
|
||||||
async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
|
async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> None:
|
||||||
return self.execute_query(
|
await self.execute_query(
|
||||||
'CALL db.indexes() YIELD name DROP INDEX name',
|
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||||
database_=database_,
|
database_=database_,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -101,7 +101,7 @@ class AddEpisodeResults(BaseModel):
|
||||||
class Graphiti:
|
class Graphiti:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
uri: str,
|
uri: str | None = None,
|
||||||
user: str | None = None,
|
user: str | None = None,
|
||||||
password: str | None = None,
|
password: str | None = None,
|
||||||
llm_client: LLMClient | None = None,
|
llm_client: LLMClient | None = None,
|
||||||
|
|
@ -162,7 +162,12 @@ class Graphiti:
|
||||||
Graphiti if you're using the default OpenAIClient.
|
Graphiti if you're using the default OpenAIClient.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.driver = graph_driver if graph_driver else Neo4jDriver(uri, user, password)
|
if graph_driver:
|
||||||
|
self.driver = graph_driver
|
||||||
|
else:
|
||||||
|
if uri is None:
|
||||||
|
raise ValueError("uri must be provided when graph_driver is None")
|
||||||
|
self.driver = Neo4jDriver(uri, user, password)
|
||||||
|
|
||||||
self.database = DEFAULT_DATABASE
|
self.database = DEFAULT_DATABASE
|
||||||
self.store_raw_episode_content = store_raw_episode_content
|
self.store_raw_episode_content = store_raw_episode_content
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ from graphiti_core.errors import GroupIdValidationError
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'neo4j')
|
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'default_db')
|
||||||
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
|
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
|
||||||
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
|
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
|
||||||
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
|
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
|
||||||
|
|
|
||||||
|
|
@ -540,10 +540,18 @@ class CommunityNode(Node):
|
||||||
|
|
||||||
# Node helpers
|
# Node helpers
|
||||||
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
||||||
|
created_at = parse_db_date(record['created_at'])
|
||||||
|
valid_at = parse_db_date(record['valid_at'])
|
||||||
|
|
||||||
|
if created_at is None:
|
||||||
|
raise ValueError(f"created_at cannot be None for episode {record.get('uuid', 'unknown')}")
|
||||||
|
if valid_at is None:
|
||||||
|
raise ValueError(f"valid_at cannot be None for episode {record.get('uuid', 'unknown')}")
|
||||||
|
|
||||||
return EpisodicNode(
|
return EpisodicNode(
|
||||||
content=record['content'],
|
content=record['content'],
|
||||||
created_at=parse_db_date(record['created_at']), # type: ignore
|
created_at=created_at,
|
||||||
valid_at=parse_db_date(record['valid_at']), # type: ignore
|
valid_at=valid_at,
|
||||||
uuid=record['uuid'],
|
uuid=record['uuid'],
|
||||||
group_id=record['group_id'],
|
group_id=record['group_id'],
|
||||||
source=EpisodeType.from_str(record['source']),
|
source=EpisodeType.from_str(record['source']),
|
||||||
|
|
|
||||||
|
|
@ -278,9 +278,6 @@ async def edge_similarity_search(
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
||||||
if driver.provider == 'falkordb':
|
|
||||||
records = [dict(zip(header, row, strict=True)) for row in records]
|
|
||||||
|
|
||||||
edges = [get_entity_edge_from_record(record) for record in records]
|
edges = [get_entity_edge_from_record(record) for record in records]
|
||||||
|
|
||||||
return edges
|
return edges
|
||||||
|
|
@ -377,8 +374,6 @@ async def node_fulltext_search(
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
if driver.provider == 'falkordb':
|
|
||||||
records = [dict(zip(header, row, strict=True)) for row in records]
|
|
||||||
|
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
@ -433,8 +428,7 @@ async def node_similarity_search(
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
if driver.provider == 'falkordb':
|
|
||||||
records = [dict(zip(header, row, strict=True)) for row in records]
|
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
||||||
return nodes
|
return nodes
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver
|
from graphiti_core.driver.driver import GraphDriver
|
||||||
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
||||||
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date, semaphore_gather
|
||||||
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
||||||
|
|
||||||
EPISODE_WINDOW_LEN = 3
|
EPISODE_WINDOW_LEN = 3
|
||||||
|
|
@ -140,10 +140,8 @@ async def retrieve_episodes(
|
||||||
episodes = [
|
episodes = [
|
||||||
EpisodicNode(
|
EpisodicNode(
|
||||||
content=record['content'],
|
content=record['content'],
|
||||||
created_at=datetime.fromtimestamp(
|
created_at=parse_db_date(record['created_at']) or datetime.min.replace(tzinfo=timezone.utc),
|
||||||
record['created_at'].to_native().timestamp(), timezone.utc
|
valid_at=parse_db_date(record['valid_at']) or datetime.min.replace(tzinfo=timezone.utc),
|
||||||
),
|
|
||||||
valid_at=(record['valid_at'].to_native()),
|
|
||||||
uuid=record['uuid'],
|
uuid=record['uuid'],
|
||||||
group_id=record['group_id'],
|
group_id=record['group_id'],
|
||||||
source=EpisodeType.from_str(record['source']),
|
source=EpisodeType.from_str(record['source']),
|
||||||
|
|
|
||||||
2835
poetry.lock
generated
2835
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -18,6 +18,7 @@ dependencies = [
|
||||||
"tenacity>=9.0.0",
|
"tenacity>=9.0.0",
|
||||||
"numpy>=1.0.0",
|
"numpy>=1.0.0",
|
||||||
"python-dotenv>=1.0.1",
|
"python-dotenv>=1.0.1",
|
||||||
|
"falkordb>=1.1.2,<2.0.0",
|
||||||
"posthog>=3.0.0",
|
"posthog>=3.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
[pytest]
|
[pytest]
|
||||||
markers =
|
markers =
|
||||||
integration: marks tests as integration tests
|
integration: marks tests as integration tests
|
||||||
|
asyncio_default_fixture_loop_scope = function
|
||||||
1
tests/driver/__init__.py
Normal file
1
tests/driver/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
"""Tests for database drivers."""
|
||||||
386
tests/driver/test_falkordb_driver.py
Normal file
386
tests/driver/test_falkordb_driver.py
Normal file
|
|
@ -0,0 +1,386 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from graphiti_core.driver.falkordb_driver import FalkorDriver, FalkorDriverSession
|
||||||
|
from graphiti_core.helpers import DEFAULT_DATABASE
|
||||||
|
|
||||||
|
|
||||||
|
class TestFalkorDriver:
|
||||||
|
"""Comprehensive test suite for FalkorDB driver."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Set up test fixtures."""
|
||||||
|
self.mock_client = MagicMock()
|
||||||
|
with patch('graphiti_core.driver.falkordb_driver.FalkorDB'):
|
||||||
|
self.driver = FalkorDriver()
|
||||||
|
self.driver.client = self.mock_client
|
||||||
|
|
||||||
|
def test_init_with_connection_params(self):
|
||||||
|
"""Test initialization with connection parameters."""
|
||||||
|
with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db:
|
||||||
|
driver = FalkorDriver(
|
||||||
|
host='test-host',
|
||||||
|
port='1234',
|
||||||
|
username='test-user',
|
||||||
|
password='test-pass'
|
||||||
|
)
|
||||||
|
assert driver.provider == 'falkordb'
|
||||||
|
mock_falkor_db.assert_called_once_with(
|
||||||
|
host='test-host',
|
||||||
|
port='1234',
|
||||||
|
username='test-user',
|
||||||
|
password='test-pass'
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_init_with_falkor_db_instance(self):
|
||||||
|
"""Test initialization with a FalkorDB instance."""
|
||||||
|
with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db_class:
|
||||||
|
mock_falkor_db = MagicMock()
|
||||||
|
driver = FalkorDriver(falkor_db=mock_falkor_db)
|
||||||
|
assert driver.provider == 'falkordb'
|
||||||
|
assert driver.client is mock_falkor_db
|
||||||
|
mock_falkor_db_class.assert_not_called()
|
||||||
|
|
||||||
|
def test_provider(self):
|
||||||
|
"""Test driver provider identification."""
|
||||||
|
assert self.driver.provider == 'falkordb'
|
||||||
|
|
||||||
|
def test_get_graph_with_name(self):
|
||||||
|
"""Test _get_graph with specific graph name."""
|
||||||
|
mock_graph = MagicMock()
|
||||||
|
self.mock_client.select_graph.return_value = mock_graph
|
||||||
|
|
||||||
|
result = self.driver._get_graph('test_graph')
|
||||||
|
|
||||||
|
self.mock_client.select_graph.assert_called_once_with('test_graph')
|
||||||
|
assert result is mock_graph
|
||||||
|
|
||||||
|
def test_get_graph_with_none_defaults_to_default_database(self):
|
||||||
|
"""Test _get_graph with None defaults to DEFAULT_DATABASE."""
|
||||||
|
mock_graph = MagicMock()
|
||||||
|
self.mock_client.select_graph.return_value = mock_graph
|
||||||
|
|
||||||
|
result = self.driver._get_graph(None)
|
||||||
|
|
||||||
|
self.mock_client.select_graph.assert_called_once_with(DEFAULT_DATABASE)
|
||||||
|
assert result is mock_graph
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_query_success(self):
|
||||||
|
"""Test successful query execution."""
|
||||||
|
mock_graph = MagicMock()
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.header = [('col1', 'column1'), ('col2', 'column2')]
|
||||||
|
mock_result.result_set = [['row1col1', 'row1col2']]
|
||||||
|
mock_graph.query = AsyncMock(return_value=mock_result)
|
||||||
|
self.mock_client.select_graph.return_value = mock_graph
|
||||||
|
|
||||||
|
result = await self.driver.execute_query(
|
||||||
|
'MATCH (n) RETURN n',
|
||||||
|
param1='value1',
|
||||||
|
database_='test_db'
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mock_client.select_graph.assert_called_once_with('test_db')
|
||||||
|
mock_graph.query.assert_called_once_with(
|
||||||
|
'MATCH (n) RETURN n',
|
||||||
|
{'param1': 'value1'}
|
||||||
|
)
|
||||||
|
|
||||||
|
result_set, header, summary = result
|
||||||
|
assert result_set == [{'column1': 'row1col1', 'column2': 'row1col2'}]
|
||||||
|
assert header == ['column1', 'column2']
|
||||||
|
assert summary is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_query_handles_index_already_exists_error(self):
|
||||||
|
"""Test handling of 'already indexed' error."""
|
||||||
|
mock_graph = MagicMock()
|
||||||
|
mock_graph.query = AsyncMock(side_effect=Exception('Index already indexed'))
|
||||||
|
self.mock_client.select_graph.return_value = mock_graph
|
||||||
|
|
||||||
|
with patch('graphiti_core.driver.falkordb_driver.logger') as mock_logger:
|
||||||
|
result = await self.driver.execute_query('CREATE INDEX ...')
|
||||||
|
|
||||||
|
mock_logger.info.assert_called_once()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_query_propagates_other_exceptions(self):
|
||||||
|
"""Test that other exceptions are properly propagated."""
|
||||||
|
mock_graph = MagicMock()
|
||||||
|
mock_graph.query = AsyncMock(side_effect=Exception('Other error'))
|
||||||
|
self.mock_client.select_graph.return_value = mock_graph
|
||||||
|
|
||||||
|
with patch('graphiti_core.driver.falkordb_driver.logger') as mock_logger:
|
||||||
|
with pytest.raises(Exception, match='Other error'):
|
||||||
|
await self.driver.execute_query('INVALID QUERY')
|
||||||
|
|
||||||
|
mock_logger.error.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_query_converts_datetime_parameters(self):
|
||||||
|
"""Test that datetime objects in kwargs are converted to ISO strings."""
|
||||||
|
mock_graph = MagicMock()
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.header = []
|
||||||
|
mock_result.result_set = []
|
||||||
|
mock_graph.query = AsyncMock(return_value=mock_result)
|
||||||
|
self.mock_client.select_graph.return_value = mock_graph
|
||||||
|
|
||||||
|
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
await self.driver.execute_query(
|
||||||
|
'CREATE (n:Node) SET n.created_at = $created_at',
|
||||||
|
created_at=test_datetime
|
||||||
|
)
|
||||||
|
|
||||||
|
call_args = mock_graph.query.call_args[0]
|
||||||
|
assert call_args[1]['created_at'] == test_datetime.isoformat()
|
||||||
|
|
||||||
|
def test_session_creation(self):
|
||||||
|
"""Test session creation with specific database."""
|
||||||
|
mock_graph = MagicMock()
|
||||||
|
self.mock_client.select_graph.return_value = mock_graph
|
||||||
|
|
||||||
|
session = self.driver.session('test_db')
|
||||||
|
|
||||||
|
assert isinstance(session, FalkorDriverSession)
|
||||||
|
assert session.graph is mock_graph
|
||||||
|
self.mock_client.select_graph.assert_called_once_with('test_db')
|
||||||
|
|
||||||
|
def test_session_creation_with_none_uses_default_database(self):
|
||||||
|
"""Test session creation with None uses default database."""
|
||||||
|
mock_graph = MagicMock()
|
||||||
|
self.mock_client.select_graph.return_value = mock_graph
|
||||||
|
|
||||||
|
session = self.driver.session(None)
|
||||||
|
|
||||||
|
assert isinstance(session, FalkorDriverSession)
|
||||||
|
self.mock_client.select_graph.assert_called_once_with(DEFAULT_DATABASE)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_calls_connection_close(self):
|
||||||
|
"""Test driver close method calls connection close."""
|
||||||
|
mock_connection = MagicMock()
|
||||||
|
mock_connection.close = AsyncMock()
|
||||||
|
self.mock_client.connection = mock_connection
|
||||||
|
|
||||||
|
# Ensure hasattr checks work correctly
|
||||||
|
del self.mock_client.aclose # Remove aclose if it exists
|
||||||
|
|
||||||
|
with patch('builtins.hasattr') as mock_hasattr:
|
||||||
|
# hasattr(self.client, 'aclose') returns False
|
||||||
|
# hasattr(self.client.connection, 'aclose') returns False
|
||||||
|
# hasattr(self.client.connection, 'close') returns True
|
||||||
|
mock_hasattr.side_effect = lambda obj, attr: (
|
||||||
|
attr == 'close' and obj is mock_connection
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.driver.close()
|
||||||
|
|
||||||
|
mock_connection.close.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_all_indexes(self):
|
||||||
|
"""Test delete_all_indexes method."""
|
||||||
|
with patch.object(self.driver, 'execute_query', new_callable=AsyncMock) as mock_execute:
|
||||||
|
await self.driver.delete_all_indexes('test_db')
|
||||||
|
|
||||||
|
mock_execute.assert_called_once_with(
|
||||||
|
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||||
|
database_='test_db'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFalkorDriverSession:
|
||||||
|
"""Test FalkorDB driver session functionality."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Set up test fixtures."""
|
||||||
|
self.mock_graph = MagicMock()
|
||||||
|
self.session = FalkorDriverSession(self.mock_graph)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_async_context_manager(self):
|
||||||
|
"""Test session can be used as async context manager."""
|
||||||
|
async with self.session as s:
|
||||||
|
assert s is self.session
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_method(self):
|
||||||
|
"""Test session close method doesn't raise exceptions."""
|
||||||
|
await self.session.close() # Should not raise
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_write_passes_session_and_args(self):
|
||||||
|
"""Test execute_write method passes session and arguments correctly."""
|
||||||
|
async def test_func(session, *args, **kwargs):
|
||||||
|
assert session is self.session
|
||||||
|
assert args == ('arg1', 'arg2')
|
||||||
|
assert kwargs == {'key': 'value'}
|
||||||
|
return 'result'
|
||||||
|
|
||||||
|
result = await self.session.execute_write(
|
||||||
|
test_func, 'arg1', 'arg2', key='value'
|
||||||
|
)
|
||||||
|
assert result == 'result'
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_single_query_with_parameters(self):
|
||||||
|
"""Test running a single query with parameters."""
|
||||||
|
self.mock_graph.query = AsyncMock()
|
||||||
|
|
||||||
|
await self.session.run(
|
||||||
|
'MATCH (n) RETURN n',
|
||||||
|
param1='value1',
|
||||||
|
param2='value2'
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mock_graph.query.assert_called_once_with(
|
||||||
|
'MATCH (n) RETURN n',
|
||||||
|
{'param1': 'value1', 'param2': 'value2'}
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_multiple_queries_as_list(self):
|
||||||
|
"""Test running multiple queries passed as list."""
|
||||||
|
self.mock_graph.query = AsyncMock()
|
||||||
|
|
||||||
|
queries = [
|
||||||
|
('MATCH (n) RETURN n', {'param1': 'value1'}),
|
||||||
|
('CREATE (n:Node)', {'param2': 'value2'})
|
||||||
|
]
|
||||||
|
|
||||||
|
await self.session.run(queries)
|
||||||
|
|
||||||
|
assert self.mock_graph.query.call_count == 2
|
||||||
|
calls = self.mock_graph.query.call_args_list
|
||||||
|
assert calls[0][0] == ('MATCH (n) RETURN n', {'param1': 'value1'})
|
||||||
|
assert calls[1][0] == ('CREATE (n:Node)', {'param2': 'value2'})
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_converts_datetime_objects_to_iso_strings(self):
|
||||||
|
"""Test that datetime objects are converted to ISO strings."""
|
||||||
|
self.mock_graph.query = AsyncMock()
|
||||||
|
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
await self.session.run(
|
||||||
|
'CREATE (n:Node) SET n.created_at = $created_at',
|
||||||
|
created_at=test_datetime
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mock_graph.query.assert_called_once()
|
||||||
|
call_args = self.mock_graph.query.call_args[0]
|
||||||
|
assert call_args[1]['created_at'] == test_datetime.isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatetimeConversion:
    """Test datetime conversion utility function."""

    def test_convert_datetime_dict(self):
        """Test datetime conversion in nested dictionary."""
        from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings

        moment = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        payload = {
            'string_val': 'test',
            'datetime_val': moment,
            'nested_dict': {
                'nested_datetime': moment,
                'nested_string': 'nested_test',
            },
        }

        converted = convert_datetimes_to_strings(payload)
        expected_iso = moment.isoformat()

        assert converted['string_val'] == 'test'
        assert converted['datetime_val'] == expected_iso
        assert converted['nested_dict']['nested_datetime'] == expected_iso
        assert converted['nested_dict']['nested_string'] == 'nested_test'

    def test_convert_datetime_list_and_tuple(self):
        """Test datetime conversion in lists and tuples."""
        from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings

        moment = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        expected_iso = moment.isoformat()

        # Lists are converted recursively, including nested lists.
        converted_list = convert_datetimes_to_strings(['test', moment, ['nested', moment]])
        assert converted_list[0] == 'test'
        assert converted_list[1] == expected_iso
        assert converted_list[2][1] == expected_iso

        # Tuples keep their type while their items are converted.
        converted_tuple = convert_datetimes_to_strings(('test', moment))
        assert isinstance(converted_tuple, tuple)
        assert converted_tuple[0] == 'test'
        assert converted_tuple[1] == expected_iso

    def test_convert_single_datetime(self):
        """Test datetime conversion for single datetime object."""
        from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings

        moment = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        assert convert_datetimes_to_strings(moment) == moment.isoformat()

    def test_convert_other_types_unchanged(self):
        """Test that non-datetime types are returned unchanged."""
        from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings

        assert convert_datetimes_to_strings('string') == 'string'
        assert convert_datetimes_to_strings(123) == 123
        assert convert_datetimes_to_strings(None) is None
        assert convert_datetimes_to_strings(True) is True
|
# Simple integration test
class TestFalkorDriverIntegration:
    """Simple integration test for FalkorDB driver."""

    @pytest.mark.integration
    @pytest.mark.asyncio
    async def test_basic_integration_with_real_falkordb(self):
        """Basic integration test with real FalkorDB instance.

        Skips when FalkorDB is unreachable, but real assertion failures
        propagate: the previous version caught `Exception` around the
        assertions, which silently turned test failures into skips.
        """
        pytest.importorskip('falkordb')

        falkor_host = os.getenv('FALKORDB_HOST', 'localhost')
        falkor_port = os.getenv('FALKORDB_PORT', '6379')

        # Only connection/setup problems should cause a skip.
        try:
            driver = FalkorDriver(host=falkor_host, port=falkor_port)
            result = await driver.execute_query('RETURN 1 as test')
        except Exception as e:
            pytest.skip(f'FalkorDB not available for integration test: {e}')

        # Assertions run outside the skip guard so failures are reported,
        # and the driver is closed even when an assertion fails.
        try:
            assert result is not None

            result_set, header, summary = result
            assert header == ['test']
            assert result_set == [{'test': 1}]
        finally:
            await driver.close()
||||||
|
|
@ -24,7 +24,7 @@ def test_lucene_sanitize():
|
||||||
queries = [
|
queries = [
|
||||||
(
|
(
|
||||||
'This has every escape character + - && || ! ( ) { } [ ] ^ " ~ * ? : \\ /',
|
'This has every escape character + - && || ! ( ) { } [ ] ^ " ~ * ? : \\ /',
|
||||||
'\This has every escape character \+ \- \&\& \|\| \! \( \) \{ \} \[ \] \^ \\" \~ \* \? \: \\\ \/',
|
'\\This has every escape character \\+ \\- \\&\\& \\|\\| \\! \\( \\) \\{ \\} \\[ \\] \\^ \\" \\~ \\* \\? \\: \\\\ \\/',
|
||||||
),
|
),
|
||||||
('this has no escape characters', 'this has no escape characters'),
|
('this has no escape characters', 'this has no escape characters'),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
160
tests/test_graphiti_falkordb_int.py
Normal file
160
tests/test_graphiti_falkordb_int.py
Normal file
|
|
@ -0,0 +1,160 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from graphiti_core.driver.falkordb_driver import FalkorDriver
|
||||||
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||||
|
from graphiti_core.graphiti import Graphiti
|
||||||
|
from graphiti_core.helpers import semaphore_gather
|
||||||
|
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||||
|
from graphiti_core.search.search_helpers import search_results_to_context_string
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
pytest_plugins = ('pytest_asyncio',)
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost')
|
||||||
|
FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
|
||||||
|
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
|
||||||
|
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging():
    """Configure the root logger to emit INFO-level records to stdout.

    Returns the configured root logger so tests can log progress output.
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # Route INFO and above to stdout with a timestamped format.
    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.INFO)
    handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    root.addHandler(handler)

    return root
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_graphiti_falkordb_init():
    """Smoke-test Graphiti bootstrapping and search against a live FalkorDB."""
    logger = setup_logging()

    falkor_driver = FalkorDriver(
        host=FALKORDB_HOST,
        port=FALKORDB_PORT,
        username=FALKORDB_USER,
        password=FALKORDB_PASSWORD,
    )

    graphiti = Graphiti(graph_driver=falkor_driver)

    # Ensure the connection is released even when the search fails;
    # previously a failing search leaked the driver connection.
    try:
        results = await graphiti.search_(query='Who is the user?')

        pretty_results = search_results_to_context_string(results)
        logger.info(pretty_results)
    finally:
        await graphiti.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_graph_falkordb_integration():
    """End-to-end round-trip on FalkorDB: save nodes/edges, fetch them, delete them.

    Cleanup (delete + close) runs in a ``finally`` block so that a failed
    assertion does not leave stale graph data or an open connection behind
    for subsequent runs against the same database.
    """
    falkor_driver = FalkorDriver(
        host=FALKORDB_HOST,
        port=FALKORDB_PORT,
        username=FALKORDB_USER,
        password=FALKORDB_PASSWORD,
    )

    client = Graphiti(graph_driver=falkor_driver)
    embedder = client.embedder
    driver = client.driver

    now = datetime.now(timezone.utc)
    episode = EpisodicNode(
        name='test_episode',
        labels=[],
        created_at=now,
        valid_at=now,
        source='message',
        source_description='conversation message',
        content='Alice likes Bob',
        entity_edges=[],
    )

    alice_node = EntityNode(
        name='Alice',
        labels=[],
        created_at=now,
        summary='Alice summary',
    )

    bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary')

    episodic_edge_1 = EpisodicEdge(
        source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now
    )

    episodic_edge_2 = EpisodicEdge(
        source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now
    )

    entity_edge = EntityEdge(
        source_node_uuid=alice_node.uuid,
        target_node_uuid=bob_node.uuid,
        created_at=now,
        name='likes',
        fact='Alice likes Bob',
        episodes=[],
        expired_at=now,
        valid_at=now,
        invalid_at=now,
    )

    await entity_edge.generate_embedding(embedder)

    nodes = [episode, alice_node, bob_node]
    edges = [episodic_edge_1, episodic_edge_2, entity_edge]

    try:
        # test save
        await semaphore_gather(*[node.save(driver) for node in nodes])
        await semaphore_gather(*[edge.save(driver) for edge in edges])

        # test get
        assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None
        assert await EntityNode.get_by_uuid(driver, alice_node.uuid) is not None
        assert await EpisodicEdge.get_by_uuid(driver, episodic_edge_1.uuid) is not None
        assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None
    finally:
        # test delete — always clean up the graph and the connection.
        await semaphore_gather(*[node.delete(driver) for node in nodes])
        await semaphore_gather(*[edge.delete(driver) for edge in edges])

        await client.close()
|
||||||
137
tests/test_node_falkordb_int.py
Normal file
137
tests/test_node_falkordb_int.py
Normal file
|
|
@ -0,0 +1,137 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from graphiti_core.driver.falkordb_driver import FalkorDriver
|
||||||
|
from graphiti_core.nodes import (
|
||||||
|
CommunityNode,
|
||||||
|
EntityNode,
|
||||||
|
EpisodeType,
|
||||||
|
EpisodicNode,
|
||||||
|
)
|
||||||
|
|
||||||
|
FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost')
|
||||||
|
FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
|
||||||
|
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
|
||||||
|
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sample_entity_node():
    """Entity node with a fresh uuid and a fixed-size embedding for round-trip tests."""
    attrs = {
        'uuid': str(uuid4()),
        'name': 'Test Entity',
        'group_id': 'test_group',
        'labels': ['Entity'],
        'name_embedding': [0.5] * 1024,
        'summary': 'Entity Summary',
    }
    return EntityNode(**attrs)
||||||
|
|
||||||
|
@pytest.fixture
def sample_episodic_node():
    """Text-sourced episodic node with a fresh uuid, valid as of now."""
    attrs = {
        'uuid': str(uuid4()),
        'name': 'Episode 1',
        'group_id': 'test_group',
        'source': EpisodeType.text,
        'source_description': 'Test source',
        'content': 'Some content here',
        'valid_at': datetime.now(timezone.utc),
    }
    return EpisodicNode(**attrs)
||||||
|
|
||||||
|
@pytest.fixture
def sample_community_node():
    """Community node with a fresh uuid and a fixed-size embedding for round-trip tests."""
    attrs = {
        'uuid': str(uuid4()),
        'name': 'Community A',
        'name_embedding': [0.5] * 1024,
        'group_id': 'test_group',
        'summary': 'Community summary',
    }
    return CommunityNode(**attrs)
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.integration
async def test_entity_node_save_get_and_delete(sample_entity_node):
    """Round-trip an EntityNode through FalkorDB: save, fetch by uuid, delete.

    Cleanup runs in ``finally`` so a failed assertion neither leaks the
    driver connection nor leaves the test node behind in the database.
    """
    falkor_driver = FalkorDriver(
        host=FALKORDB_HOST,
        port=FALKORDB_PORT,
        username=FALKORDB_USER,
        password=FALKORDB_PASSWORD,
    )

    try:
        await sample_entity_node.save(falkor_driver)

        retrieved = await EntityNode.get_by_uuid(falkor_driver, sample_entity_node.uuid)
        assert retrieved.uuid == sample_entity_node.uuid
        assert retrieved.name == 'Test Entity'
        assert retrieved.group_id == 'test_group'
    finally:
        await sample_entity_node.delete(falkor_driver)
        await falkor_driver.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.integration
async def test_community_node_save_get_and_delete(sample_community_node):
    """Round-trip a CommunityNode through FalkorDB: save, fetch by uuid, delete.

    Cleanup runs in ``finally`` so a failed assertion neither leaks the
    driver connection nor leaves the test node behind in the database.
    """
    falkor_driver = FalkorDriver(
        host=FALKORDB_HOST,
        port=FALKORDB_PORT,
        username=FALKORDB_USER,
        password=FALKORDB_PASSWORD,
    )

    try:
        await sample_community_node.save(falkor_driver)

        retrieved = await CommunityNode.get_by_uuid(falkor_driver, sample_community_node.uuid)
        assert retrieved.uuid == sample_community_node.uuid
        assert retrieved.name == 'Community A'
        assert retrieved.group_id == 'test_group'
        assert retrieved.summary == 'Community summary'
    finally:
        await sample_community_node.delete(falkor_driver)
        await falkor_driver.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.integration
async def test_episodic_node_save_get_and_delete(sample_episodic_node):
    """Round-trip an EpisodicNode through FalkorDB: save, fetch by uuid, delete.

    Cleanup runs in ``finally`` so a failed assertion neither leaks the
    driver connection nor leaves the test node behind in the database.
    """
    falkor_driver = FalkorDriver(
        host=FALKORDB_HOST,
        port=FALKORDB_PORT,
        username=FALKORDB_USER,
        password=FALKORDB_PASSWORD,
    )

    try:
        await sample_episodic_node.save(falkor_driver)

        retrieved = await EpisodicNode.get_by_uuid(falkor_driver, sample_episodic_node.uuid)
        assert retrieved.uuid == sample_episodic_node.uuid
        assert retrieved.name == 'Episode 1'
        assert retrieved.group_id == 'test_group'
        assert retrieved.source == EpisodeType.text
        assert retrieved.source_description == 'Test source'
        assert retrieved.content == 'Some content here'
    finally:
        await sample_episodic_node.delete(falkor_driver)
        await falkor_driver.close()
|
||||||
Loading…
Add table
Reference in a new issue