FalkorDB Integration: Bug Fixes and Unit Tests (#607)
* fixes-and-tests * update-workflow * lint-fixes * mypy-fixes * fix-falkor-tests * Update poetry.lock after pyproject.toml changes * update-yml * fix-tests * comp-tests * typo * fix-tests --------- Co-authored-by: Guy Korland <gkorland@gmail.com>
This commit is contained in:
parent
19772aa5a1
commit
6e6115c134
17 changed files with 2301 additions and 1373 deletions
18
.github/workflows/unit_tests.yml
vendored
18
.github/workflows/unit_tests.yml
vendored
|
|
@ -11,6 +11,12 @@ jobs:
|
|||
runs-on: depot-ubuntu-22.04
|
||||
environment:
|
||||
name: development
|
||||
services:
|
||||
falkordb:
|
||||
image: falkordb/falkordb:latest
|
||||
ports:
|
||||
- 6379:6379
|
||||
options: --health-cmd "redis-cli ping" --health-interval 10s --health-timeout 5s --health-retries 5
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
|
|
@ -21,6 +27,8 @@ jobs:
|
|||
uses: astral-sh/setup-uv@v3
|
||||
with:
|
||||
version: "latest"
|
||||
- name: Install redis-cli for FalkorDB health check
|
||||
run: sudo apt-get update && sudo apt-get install -y redis-tools
|
||||
- name: Install dependencies
|
||||
run: uv sync --extra dev
|
||||
- name: Run non-integration tests
|
||||
|
|
@ -28,3 +36,13 @@ jobs:
|
|||
PYTHONPATH: ${{ github.workspace }}
|
||||
run: |
|
||||
uv run pytest -m "not integration"
|
||||
- name: Wait for FalkorDB
|
||||
run: |
|
||||
timeout 60 bash -c 'until redis-cli -h localhost -p 6379 ping; do sleep 1; done'
|
||||
- name: Run FalkorDB integration tests
|
||||
env:
|
||||
PYTHONPATH: ${{ github.workspace }}
|
||||
FALKORDB_HOST: localhost
|
||||
FALKORDB_PORT: 6379
|
||||
run: |
|
||||
uv run pytest tests/driver/test_falkordb_driver.py
|
||||
|
|
|
|||
|
|
@ -120,6 +120,12 @@ Optional:
|
|||
> [!TIP]
|
||||
> The simplest way to install Neo4j is via [Neo4j Desktop](https://neo4j.com/download/). It provides a user-friendly
|
||||
> interface to manage Neo4j instances and databases.
|
||||
> Alternatively, you can use FalkorDB on-premises via Docker and instantly start with the quickstart example:
|
||||
```bash
|
||||
docker run -p 6379:6379 -p 3000:3000 -it --rm falkordb/falkordb:latest
|
||||
|
||||
```
|
||||
|
||||
|
||||
```bash
|
||||
pip install graphiti-core
|
||||
|
|
@ -156,7 +162,7 @@ pip install graphiti-core[anthropic,groq,google-genai]
|
|||
|
||||
For a complete working example, see the [Quickstart Example](./examples/quickstart/README.md) in the examples directory. The quickstart demonstrates:
|
||||
|
||||
1. Connecting to a Neo4j database
|
||||
1. Connecting to a Neo4j or FalkorDB database
|
||||
2. Initializing Graphiti indices and constraints
|
||||
3. Adding episodes to the graph (both text and structured JSON)
|
||||
4. Searching for relationships (edges) using hybrid search
|
||||
|
|
|
|||
|
|
@ -46,14 +46,20 @@ logger = logging.getLogger(__name__)
|
|||
load_dotenv()
|
||||
|
||||
# FalkorDB connection parameters
|
||||
# Make sure FalkorDB on premises is running, see https://docs.falkordb.com/
|
||||
falkor_uri = os.environ.get('FALKORDB_URI', 'falkor://localhost:6379')
|
||||
falkor_user = os.environ.get('FALKORDB_USER', 'falkor')
|
||||
falkor_password = os.environ.get('FALKORDB_PASSWORD', '')
|
||||
|
||||
if not falkor_uri:
|
||||
raise ValueError('FALKORDB_URI must be set')
|
||||
# Make sure FalkorDB (on-premises) is running — see https://docs.falkordb.com/
|
||||
# By default, FalkorDB does not require a username or password,
|
||||
# but you can set them via environment variables for added security.
|
||||
#
|
||||
# If you're using FalkorDB Cloud, set the environment variables accordingly.
|
||||
# For on-premises use, you can leave them as None or set them to your preferred values.
|
||||
#
|
||||
# The default host and port are 'localhost' and '6379', respectively.
|
||||
# You can override these values in your environment variables or directly in the code.
|
||||
|
||||
falkor_username = os.environ.get('FALKORDB_USERNAME', None)
|
||||
falkor_password = os.environ.get('FALKORDB_PASSWORD', None)
|
||||
falkor_host = os.environ.get('FALKORDB_HOST', 'localhost')
|
||||
falkor_port = os.environ.get('FALKORDB_PORT', '6379')
|
||||
|
||||
async def main():
|
||||
#################################################
|
||||
|
|
@ -65,8 +71,8 @@ async def main():
|
|||
#################################################
|
||||
|
||||
# Initialize Graphiti with FalkorDB connection
|
||||
falkor_driver = FalkorDriver(uri=falkor_uri, user=falkor_user, password=falkor_password)
|
||||
graphiti = Graphiti(uri=falkor_uri, graph_driver=falkor_driver)
|
||||
falkor_driver = FalkorDriver(host=falkor_host, port=falkor_port, username=falkor_username, password=falkor_password)
|
||||
graphiti = Graphiti(graph_driver=falkor_driver)
|
||||
|
||||
try:
|
||||
# Initialize the graph database with graphiti's indices. This only needs to be done once.
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ limitations under the License.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Coroutine
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -66,22 +65,30 @@ class FalkorDriver(GraphDriver):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
user: str,
|
||||
password: str,
|
||||
host: str = 'localhost',
|
||||
port: str = '6379',
|
||||
username: str | None = None,
|
||||
password: str | None = None,
|
||||
falkor_db: FalkorDB | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
uri_parts = uri.split('://', 1)
|
||||
uri = f'{uri_parts[0]}://{user}:{password}@{uri_parts[1]}'
|
||||
"""
|
||||
Initialize the FalkorDB driver.
|
||||
|
||||
self.client = FalkorDB(
|
||||
host='your-db.falkor.cloud', port=6380, password='your_password', ssl=True
|
||||
)
|
||||
FalkorDB is a multi-tenant graph database.
|
||||
To connect, provide the host and port.
|
||||
The default parameters assume a local (on-premises) FalkorDB instance.
|
||||
"""
|
||||
super().__init__()
|
||||
if falkor_db is not None:
|
||||
# If a FalkorDB instance is provided, use it directly
|
||||
self.client = falkor_db
|
||||
else:
|
||||
self.client = FalkorDB(host=host, port=port, username=username, password=password)
|
||||
|
||||
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
|
||||
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "DEFAULT_DATABASE"
|
||||
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is DEFAULT_DATABASE
|
||||
if graph_name is None:
|
||||
graph_name = 'DEFAULT_DATABASE'
|
||||
graph_name = DEFAULT_DATABASE
|
||||
return self.client.select_graph(graph_name)
|
||||
|
||||
async def execute_query(self, cypher_query_, **kwargs: Any):
|
||||
|
|
@ -102,17 +109,36 @@ class FalkorDriver(GraphDriver):
|
|||
raise
|
||||
|
||||
# Convert the result header to a list of strings
|
||||
header = [h[1].decode('utf-8') for h in result.header]
|
||||
return result.result_set, header, None
|
||||
header = [h[1] for h in result.header]
|
||||
|
||||
# Convert FalkorDB's result format (list of lists) to the format expected by Graphiti (list of dicts)
|
||||
records = []
|
||||
for row in result.result_set:
|
||||
record = {}
|
||||
for i, field_name in enumerate(header):
|
||||
if i < len(row):
|
||||
record[field_name] = row[i]
|
||||
else:
|
||||
# If there are more fields in header than values in row, set to None
|
||||
record[field_name] = None
|
||||
records.append(record)
|
||||
|
||||
return records, header, None
|
||||
|
||||
def session(self, database: str | None) -> GraphDriverSession:
|
||||
return FalkorDriverSession(self._get_graph(database))
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.client.connection.close()
|
||||
"""Close the driver connection."""
|
||||
if hasattr(self.client, 'aclose'):
|
||||
await self.client.aclose()
|
||||
elif hasattr(self.client.connection, 'aclose'):
|
||||
await self.client.connection.aclose()
|
||||
elif hasattr(self.client.connection, 'close'):
|
||||
await self.client.connection.close()
|
||||
|
||||
async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
|
||||
return self.execute_query(
|
||||
async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> None:
|
||||
await self.execute_query(
|
||||
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||
database_=database_,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ class AddEpisodeResults(BaseModel):
|
|||
class Graphiti:
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
uri: str | None = None,
|
||||
user: str | None = None,
|
||||
password: str | None = None,
|
||||
llm_client: LLMClient | None = None,
|
||||
|
|
@ -162,7 +162,12 @@ class Graphiti:
|
|||
Graphiti if you're using the default OpenAIClient.
|
||||
"""
|
||||
|
||||
self.driver = graph_driver if graph_driver else Neo4jDriver(uri, user, password)
|
||||
if graph_driver:
|
||||
self.driver = graph_driver
|
||||
else:
|
||||
if uri is None:
|
||||
raise ValueError("uri must be provided when graph_driver is None")
|
||||
self.driver = Neo4jDriver(uri, user, password)
|
||||
|
||||
self.database = DEFAULT_DATABASE
|
||||
self.store_raw_episode_content = store_raw_episode_content
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ from graphiti_core.errors import GroupIdValidationError
|
|||
|
||||
load_dotenv()
|
||||
|
||||
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'neo4j')
|
||||
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'default_db')
|
||||
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
|
||||
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
|
||||
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
|
||||
|
|
|
|||
|
|
@ -540,10 +540,18 @@ class CommunityNode(Node):
|
|||
|
||||
# Node helpers
|
||||
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
||||
created_at = parse_db_date(record['created_at'])
|
||||
valid_at = parse_db_date(record['valid_at'])
|
||||
|
||||
if created_at is None:
|
||||
raise ValueError(f"created_at cannot be None for episode {record.get('uuid', 'unknown')}")
|
||||
if valid_at is None:
|
||||
raise ValueError(f"valid_at cannot be None for episode {record.get('uuid', 'unknown')}")
|
||||
|
||||
return EpisodicNode(
|
||||
content=record['content'],
|
||||
created_at=parse_db_date(record['created_at']), # type: ignore
|
||||
valid_at=parse_db_date(record['valid_at']), # type: ignore
|
||||
created_at=created_at,
|
||||
valid_at=valid_at,
|
||||
uuid=record['uuid'],
|
||||
group_id=record['group_id'],
|
||||
source=EpisodeType.from_str(record['source']),
|
||||
|
|
|
|||
|
|
@ -278,9 +278,6 @@ async def edge_similarity_search(
|
|||
routing_='r',
|
||||
)
|
||||
|
||||
if driver.provider == 'falkordb':
|
||||
records = [dict(zip(header, row, strict=True)) for row in records]
|
||||
|
||||
edges = [get_entity_edge_from_record(record) for record in records]
|
||||
|
||||
return edges
|
||||
|
|
@ -377,8 +374,6 @@ async def node_fulltext_search(
|
|||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
if driver.provider == 'falkordb':
|
||||
records = [dict(zip(header, row, strict=True)) for row in records]
|
||||
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
||||
|
|
@ -433,8 +428,7 @@ async def node_similarity_search(
|
|||
database_=DEFAULT_DATABASE,
|
||||
routing_='r',
|
||||
)
|
||||
if driver.provider == 'falkordb':
|
||||
records = [dict(zip(header, row, strict=True)) for row in records]
|
||||
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
||||
return nodes
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from typing_extensions import LiteralString
|
|||
|
||||
from graphiti_core.driver.driver import GraphDriver
|
||||
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date, semaphore_gather
|
||||
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
||||
|
||||
EPISODE_WINDOW_LEN = 3
|
||||
|
|
@ -140,10 +140,8 @@ async def retrieve_episodes(
|
|||
episodes = [
|
||||
EpisodicNode(
|
||||
content=record['content'],
|
||||
created_at=datetime.fromtimestamp(
|
||||
record['created_at'].to_native().timestamp(), timezone.utc
|
||||
),
|
||||
valid_at=(record['valid_at'].to_native()),
|
||||
created_at=parse_db_date(record['created_at']) or datetime.min.replace(tzinfo=timezone.utc),
|
||||
valid_at=parse_db_date(record['valid_at']) or datetime.min.replace(tzinfo=timezone.utc),
|
||||
uuid=record['uuid'],
|
||||
group_id=record['group_id'],
|
||||
source=EpisodeType.from_str(record['source']),
|
||||
|
|
|
|||
2835
poetry.lock
generated
2835
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -18,6 +18,7 @@ dependencies = [
|
|||
"tenacity>=9.0.0",
|
||||
"numpy>=1.0.0",
|
||||
"python-dotenv>=1.0.1",
|
||||
"falkordb>=1.1.2,<2.0.0",
|
||||
"posthog>=3.0.0",
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
[pytest]
|
||||
markers =
|
||||
integration: marks tests as integration tests
|
||||
integration: marks tests as integration tests
|
||||
asyncio_default_fixture_loop_scope = function
|
||||
1
tests/driver/__init__.py
Normal file
1
tests/driver/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Tests for database drivers."""
|
||||
386
tests/driver/test_falkordb_driver.py
Normal file
386
tests/driver/test_falkordb_driver.py
Normal file
|
|
@ -0,0 +1,386 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from graphiti_core.driver.falkordb_driver import FalkorDriver, FalkorDriverSession
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE
|
||||
|
||||
|
||||
class TestFalkorDriver:
|
||||
"""Comprehensive test suite for FalkorDB driver."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.mock_client = MagicMock()
|
||||
with patch('graphiti_core.driver.falkordb_driver.FalkorDB'):
|
||||
self.driver = FalkorDriver()
|
||||
self.driver.client = self.mock_client
|
||||
|
||||
def test_init_with_connection_params(self):
|
||||
"""Test initialization with connection parameters."""
|
||||
with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db:
|
||||
driver = FalkorDriver(
|
||||
host='test-host',
|
||||
port='1234',
|
||||
username='test-user',
|
||||
password='test-pass'
|
||||
)
|
||||
assert driver.provider == 'falkordb'
|
||||
mock_falkor_db.assert_called_once_with(
|
||||
host='test-host',
|
||||
port='1234',
|
||||
username='test-user',
|
||||
password='test-pass'
|
||||
)
|
||||
|
||||
def test_init_with_falkor_db_instance(self):
|
||||
"""Test initialization with a FalkorDB instance."""
|
||||
with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db_class:
|
||||
mock_falkor_db = MagicMock()
|
||||
driver = FalkorDriver(falkor_db=mock_falkor_db)
|
||||
assert driver.provider == 'falkordb'
|
||||
assert driver.client is mock_falkor_db
|
||||
mock_falkor_db_class.assert_not_called()
|
||||
|
||||
def test_provider(self):
|
||||
"""Test driver provider identification."""
|
||||
assert self.driver.provider == 'falkordb'
|
||||
|
||||
def test_get_graph_with_name(self):
|
||||
"""Test _get_graph with specific graph name."""
|
||||
mock_graph = MagicMock()
|
||||
self.mock_client.select_graph.return_value = mock_graph
|
||||
|
||||
result = self.driver._get_graph('test_graph')
|
||||
|
||||
self.mock_client.select_graph.assert_called_once_with('test_graph')
|
||||
assert result is mock_graph
|
||||
|
||||
def test_get_graph_with_none_defaults_to_default_database(self):
|
||||
"""Test _get_graph with None defaults to DEFAULT_DATABASE."""
|
||||
mock_graph = MagicMock()
|
||||
self.mock_client.select_graph.return_value = mock_graph
|
||||
|
||||
result = self.driver._get_graph(None)
|
||||
|
||||
self.mock_client.select_graph.assert_called_once_with(DEFAULT_DATABASE)
|
||||
assert result is mock_graph
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_query_success(self):
|
||||
"""Test successful query execution."""
|
||||
mock_graph = MagicMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.header = [('col1', 'column1'), ('col2', 'column2')]
|
||||
mock_result.result_set = [['row1col1', 'row1col2']]
|
||||
mock_graph.query = AsyncMock(return_value=mock_result)
|
||||
self.mock_client.select_graph.return_value = mock_graph
|
||||
|
||||
result = await self.driver.execute_query(
|
||||
'MATCH (n) RETURN n',
|
||||
param1='value1',
|
||||
database_='test_db'
|
||||
)
|
||||
|
||||
self.mock_client.select_graph.assert_called_once_with('test_db')
|
||||
mock_graph.query.assert_called_once_with(
|
||||
'MATCH (n) RETURN n',
|
||||
{'param1': 'value1'}
|
||||
)
|
||||
|
||||
result_set, header, summary = result
|
||||
assert result_set == [{'column1': 'row1col1', 'column2': 'row1col2'}]
|
||||
assert header == ['column1', 'column2']
|
||||
assert summary is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_query_handles_index_already_exists_error(self):
|
||||
"""Test handling of 'already indexed' error."""
|
||||
mock_graph = MagicMock()
|
||||
mock_graph.query = AsyncMock(side_effect=Exception('Index already indexed'))
|
||||
self.mock_client.select_graph.return_value = mock_graph
|
||||
|
||||
with patch('graphiti_core.driver.falkordb_driver.logger') as mock_logger:
|
||||
result = await self.driver.execute_query('CREATE INDEX ...')
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_query_propagates_other_exceptions(self):
|
||||
"""Test that other exceptions are properly propagated."""
|
||||
mock_graph = MagicMock()
|
||||
mock_graph.query = AsyncMock(side_effect=Exception('Other error'))
|
||||
self.mock_client.select_graph.return_value = mock_graph
|
||||
|
||||
with patch('graphiti_core.driver.falkordb_driver.logger') as mock_logger:
|
||||
with pytest.raises(Exception, match='Other error'):
|
||||
await self.driver.execute_query('INVALID QUERY')
|
||||
|
||||
mock_logger.error.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_query_converts_datetime_parameters(self):
|
||||
"""Test that datetime objects in kwargs are converted to ISO strings."""
|
||||
mock_graph = MagicMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.header = []
|
||||
mock_result.result_set = []
|
||||
mock_graph.query = AsyncMock(return_value=mock_result)
|
||||
self.mock_client.select_graph.return_value = mock_graph
|
||||
|
||||
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
await self.driver.execute_query(
|
||||
'CREATE (n:Node) SET n.created_at = $created_at',
|
||||
created_at=test_datetime
|
||||
)
|
||||
|
||||
call_args = mock_graph.query.call_args[0]
|
||||
assert call_args[1]['created_at'] == test_datetime.isoformat()
|
||||
|
||||
def test_session_creation(self):
|
||||
"""Test session creation with specific database."""
|
||||
mock_graph = MagicMock()
|
||||
self.mock_client.select_graph.return_value = mock_graph
|
||||
|
||||
session = self.driver.session('test_db')
|
||||
|
||||
assert isinstance(session, FalkorDriverSession)
|
||||
assert session.graph is mock_graph
|
||||
self.mock_client.select_graph.assert_called_once_with('test_db')
|
||||
|
||||
def test_session_creation_with_none_uses_default_database(self):
|
||||
"""Test session creation with None uses default database."""
|
||||
mock_graph = MagicMock()
|
||||
self.mock_client.select_graph.return_value = mock_graph
|
||||
|
||||
session = self.driver.session(None)
|
||||
|
||||
assert isinstance(session, FalkorDriverSession)
|
||||
self.mock_client.select_graph.assert_called_once_with(DEFAULT_DATABASE)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_calls_connection_close(self):
|
||||
"""Test driver close method calls connection close."""
|
||||
mock_connection = MagicMock()
|
||||
mock_connection.close = AsyncMock()
|
||||
self.mock_client.connection = mock_connection
|
||||
|
||||
# Ensure hasattr checks work correctly
|
||||
del self.mock_client.aclose # Remove aclose if it exists
|
||||
|
||||
with patch('builtins.hasattr') as mock_hasattr:
|
||||
# hasattr(self.client, 'aclose') returns False
|
||||
# hasattr(self.client.connection, 'aclose') returns False
|
||||
# hasattr(self.client.connection, 'close') returns True
|
||||
mock_hasattr.side_effect = lambda obj, attr: (
|
||||
attr == 'close' and obj is mock_connection
|
||||
)
|
||||
|
||||
await self.driver.close()
|
||||
|
||||
mock_connection.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_all_indexes(self):
|
||||
"""Test delete_all_indexes method."""
|
||||
with patch.object(self.driver, 'execute_query', new_callable=AsyncMock) as mock_execute:
|
||||
await self.driver.delete_all_indexes('test_db')
|
||||
|
||||
mock_execute.assert_called_once_with(
|
||||
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||
database_='test_db'
|
||||
)
|
||||
|
||||
|
||||
class TestFalkorDriverSession:
|
||||
"""Test FalkorDB driver session functionality."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.mock_graph = MagicMock()
|
||||
self.session = FalkorDriverSession(self.mock_graph)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_async_context_manager(self):
|
||||
"""Test session can be used as async context manager."""
|
||||
async with self.session as s:
|
||||
assert s is self.session
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_method(self):
|
||||
"""Test session close method doesn't raise exceptions."""
|
||||
await self.session.close() # Should not raise
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_write_passes_session_and_args(self):
|
||||
"""Test execute_write method passes session and arguments correctly."""
|
||||
async def test_func(session, *args, **kwargs):
|
||||
assert session is self.session
|
||||
assert args == ('arg1', 'arg2')
|
||||
assert kwargs == {'key': 'value'}
|
||||
return 'result'
|
||||
|
||||
result = await self.session.execute_write(
|
||||
test_func, 'arg1', 'arg2', key='value'
|
||||
)
|
||||
assert result == 'result'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_single_query_with_parameters(self):
|
||||
"""Test running a single query with parameters."""
|
||||
self.mock_graph.query = AsyncMock()
|
||||
|
||||
await self.session.run(
|
||||
'MATCH (n) RETURN n',
|
||||
param1='value1',
|
||||
param2='value2'
|
||||
)
|
||||
|
||||
self.mock_graph.query.assert_called_once_with(
|
||||
'MATCH (n) RETURN n',
|
||||
{'param1': 'value1', 'param2': 'value2'}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_multiple_queries_as_list(self):
|
||||
"""Test running multiple queries passed as list."""
|
||||
self.mock_graph.query = AsyncMock()
|
||||
|
||||
queries = [
|
||||
('MATCH (n) RETURN n', {'param1': 'value1'}),
|
||||
('CREATE (n:Node)', {'param2': 'value2'})
|
||||
]
|
||||
|
||||
await self.session.run(queries)
|
||||
|
||||
assert self.mock_graph.query.call_count == 2
|
||||
calls = self.mock_graph.query.call_args_list
|
||||
assert calls[0][0] == ('MATCH (n) RETURN n', {'param1': 'value1'})
|
||||
assert calls[1][0] == ('CREATE (n:Node)', {'param2': 'value2'})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_converts_datetime_objects_to_iso_strings(self):
|
||||
"""Test that datetime objects are converted to ISO strings."""
|
||||
self.mock_graph.query = AsyncMock()
|
||||
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
await self.session.run(
|
||||
'CREATE (n:Node) SET n.created_at = $created_at',
|
||||
created_at=test_datetime
|
||||
)
|
||||
|
||||
self.mock_graph.query.assert_called_once()
|
||||
call_args = self.mock_graph.query.call_args[0]
|
||||
assert call_args[1]['created_at'] == test_datetime.isoformat()
|
||||
|
||||
|
||||
class TestDatetimeConversion:
|
||||
"""Test datetime conversion utility function."""
|
||||
|
||||
def test_convert_datetime_dict(self):
|
||||
"""Test datetime conversion in nested dictionary."""
|
||||
from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings
|
||||
|
||||
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
input_dict = {
|
||||
'string_val': 'test',
|
||||
'datetime_val': test_datetime,
|
||||
'nested_dict': {
|
||||
'nested_datetime': test_datetime,
|
||||
'nested_string': 'nested_test'
|
||||
}
|
||||
}
|
||||
|
||||
result = convert_datetimes_to_strings(input_dict)
|
||||
|
||||
assert result['string_val'] == 'test'
|
||||
assert result['datetime_val'] == test_datetime.isoformat()
|
||||
assert result['nested_dict']['nested_datetime'] == test_datetime.isoformat()
|
||||
assert result['nested_dict']['nested_string'] == 'nested_test'
|
||||
|
||||
def test_convert_datetime_list_and_tuple(self):
|
||||
"""Test datetime conversion in lists and tuples."""
|
||||
from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings
|
||||
|
||||
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
# Test list
|
||||
input_list = ['test', test_datetime, ['nested', test_datetime]]
|
||||
result_list = convert_datetimes_to_strings(input_list)
|
||||
assert result_list[0] == 'test'
|
||||
assert result_list[1] == test_datetime.isoformat()
|
||||
assert result_list[2][1] == test_datetime.isoformat()
|
||||
|
||||
# Test tuple
|
||||
input_tuple = ('test', test_datetime)
|
||||
result_tuple = convert_datetimes_to_strings(input_tuple)
|
||||
assert isinstance(result_tuple, tuple)
|
||||
assert result_tuple[0] == 'test'
|
||||
assert result_tuple[1] == test_datetime.isoformat()
|
||||
|
||||
def test_convert_single_datetime(self):
|
||||
"""Test datetime conversion for single datetime object."""
|
||||
from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings
|
||||
|
||||
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
result = convert_datetimes_to_strings(test_datetime)
|
||||
assert result == test_datetime.isoformat()
|
||||
|
||||
def test_convert_other_types_unchanged(self):
|
||||
"""Test that non-datetime types are returned unchanged."""
|
||||
from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings
|
||||
|
||||
assert convert_datetimes_to_strings('string') == 'string'
|
||||
assert convert_datetimes_to_strings(123) == 123
|
||||
assert convert_datetimes_to_strings(None) is None
|
||||
assert convert_datetimes_to_strings(True) is True
|
||||
|
||||
|
||||
# Simple integration test
|
||||
class TestFalkorDriverIntegration:
|
||||
"""Simple integration test for FalkorDB driver."""
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_integration_with_real_falkordb(self):
|
||||
"""Basic integration test with real FalkorDB instance."""
|
||||
pytest.importorskip('falkordb')
|
||||
|
||||
falkor_host = os.getenv('FALKORDB_HOST', 'localhost')
|
||||
falkor_port = os.getenv('FALKORDB_PORT', '6379')
|
||||
|
||||
try:
|
||||
driver = FalkorDriver(host=falkor_host, port=falkor_port)
|
||||
|
||||
# Test basic query execution
|
||||
result = await driver.execute_query('RETURN 1 as test')
|
||||
assert result is not None
|
||||
|
||||
result_set, header, summary = result
|
||||
assert header == ['test']
|
||||
assert result_set == [{'test': 1}]
|
||||
|
||||
await driver.close()
|
||||
|
||||
except Exception as e:
|
||||
pytest.skip(f"FalkorDB not available for integration test: {e}")
|
||||
|
|
@ -24,7 +24,7 @@ def test_lucene_sanitize():
|
|||
queries = [
|
||||
(
|
||||
'This has every escape character + - && || ! ( ) { } [ ] ^ " ~ * ? : \\ /',
|
||||
'\This has every escape character \+ \- \&\& \|\| \! \( \) \{ \} \[ \] \^ \\" \~ \* \? \: \\\ \/',
|
||||
'\\This has every escape character \\+ \\- \\&\\& \\|\\| \\! \\( \\) \\{ \\} \\[ \\] \\^ \\" \\~ \\* \\? \\: \\\\ \\/',
|
||||
),
|
||||
('this has no escape characters', 'this has no escape characters'),
|
||||
]
|
||||
|
|
|
|||
160
tests/test_graphiti_falkordb_int.py
Normal file
160
tests/test_graphiti_falkordb_int.py
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from graphiti_core.driver.falkordb_driver import FalkorDriver
|
||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||
from graphiti_core.graphiti import Graphiti
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||
from graphiti_core.search.search_helpers import search_results_to_context_string
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
pytest_plugins = ('pytest_asyncio',)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost')
|
||||
FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
|
||||
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
|
||||
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
|
||||
|
||||
|
||||
def setup_logging():
|
||||
# Create a logger
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.INFO) # Set the logging level to INFO
|
||||
|
||||
# Create console handler and set level to INFO
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setLevel(logging.INFO)
|
||||
|
||||
# Create formatter
|
||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
|
||||
# Add formatter to console handler
|
||||
console_handler.setFormatter(formatter)
|
||||
|
||||
# Add console handler to logger
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graphiti_falkordb_init():
|
||||
logger = setup_logging()
|
||||
|
||||
falkor_driver = FalkorDriver(
|
||||
host=FALKORDB_HOST,
|
||||
port=FALKORDB_PORT,
|
||||
username=FALKORDB_USER,
|
||||
password=FALKORDB_PASSWORD
|
||||
)
|
||||
|
||||
graphiti = Graphiti(graph_driver=falkor_driver)
|
||||
|
||||
results = await graphiti.search_(query='Who is the user?')
|
||||
|
||||
pretty_results = search_results_to_context_string(results)
|
||||
|
||||
logger.info(pretty_results)
|
||||
|
||||
await graphiti.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_falkordb_integration():
|
||||
falkor_driver = FalkorDriver(
|
||||
host=FALKORDB_HOST,
|
||||
port=FALKORDB_PORT,
|
||||
username=FALKORDB_USER,
|
||||
password=FALKORDB_PASSWORD
|
||||
)
|
||||
|
||||
client = Graphiti(graph_driver=falkor_driver)
|
||||
embedder = client.embedder
|
||||
driver = client.driver
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
episode = EpisodicNode(
|
||||
name='test_episode',
|
||||
labels=[],
|
||||
created_at=now,
|
||||
valid_at=now,
|
||||
source='message',
|
||||
source_description='conversation message',
|
||||
content='Alice likes Bob',
|
||||
entity_edges=[],
|
||||
)
|
||||
|
||||
alice_node = EntityNode(
|
||||
name='Alice',
|
||||
labels=[],
|
||||
created_at=now,
|
||||
summary='Alice summary',
|
||||
)
|
||||
|
||||
bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary')
|
||||
|
||||
episodic_edge_1 = EpisodicEdge(
|
||||
source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now
|
||||
)
|
||||
|
||||
episodic_edge_2 = EpisodicEdge(
|
||||
source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now
|
||||
)
|
||||
|
||||
entity_edge = EntityEdge(
|
||||
source_node_uuid=alice_node.uuid,
|
||||
target_node_uuid=bob_node.uuid,
|
||||
created_at=now,
|
||||
name='likes',
|
||||
fact='Alice likes Bob',
|
||||
episodes=[],
|
||||
expired_at=now,
|
||||
valid_at=now,
|
||||
invalid_at=now,
|
||||
)
|
||||
|
||||
await entity_edge.generate_embedding(embedder)
|
||||
|
||||
nodes = [episode, alice_node, bob_node]
|
||||
edges = [episodic_edge_1, episodic_edge_2, entity_edge]
|
||||
|
||||
# test save
|
||||
await semaphore_gather(*[node.save(driver) for node in nodes])
|
||||
await semaphore_gather(*[edge.save(driver) for edge in edges])
|
||||
|
||||
# test get
|
||||
assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None
|
||||
assert await EntityNode.get_by_uuid(driver, alice_node.uuid) is not None
|
||||
assert await EpisodicEdge.get_by_uuid(driver, episodic_edge_1.uuid) is not None
|
||||
assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None
|
||||
|
||||
# test delete
|
||||
await semaphore_gather(*[node.delete(driver) for node in nodes])
|
||||
await semaphore_gather(*[edge.delete(driver) for edge in edges])
|
||||
|
||||
await client.close()
|
||||
137
tests/test_node_falkordb_int.py
Normal file
137
tests/test_node_falkordb_int.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from graphiti_core.driver.falkordb_driver import FalkorDriver
|
||||
from graphiti_core.nodes import (
|
||||
CommunityNode,
|
||||
EntityNode,
|
||||
EpisodeType,
|
||||
EpisodicNode,
|
||||
)
|
||||
|
||||
FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost')
|
||||
FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
|
||||
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
|
||||
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_entity_node():
|
||||
return EntityNode(
|
||||
uuid=str(uuid4()),
|
||||
name='Test Entity',
|
||||
group_id='test_group',
|
||||
labels=['Entity'],
|
||||
name_embedding=[0.5] * 1024,
|
||||
summary='Entity Summary',
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_episodic_node():
|
||||
return EpisodicNode(
|
||||
uuid=str(uuid4()),
|
||||
name='Episode 1',
|
||||
group_id='test_group',
|
||||
source=EpisodeType.text,
|
||||
source_description='Test source',
|
||||
content='Some content here',
|
||||
valid_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_community_node():
|
||||
return CommunityNode(
|
||||
uuid=str(uuid4()),
|
||||
name='Community A',
|
||||
name_embedding=[0.5] * 1024,
|
||||
group_id='test_group',
|
||||
summary='Community summary',
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
async def test_entity_node_save_get_and_delete(sample_entity_node):
|
||||
falkor_driver = FalkorDriver(
|
||||
host=FALKORDB_HOST,
|
||||
port=FALKORDB_PORT,
|
||||
username=FALKORDB_USER,
|
||||
password=FALKORDB_PASSWORD
|
||||
)
|
||||
|
||||
await sample_entity_node.save(falkor_driver)
|
||||
|
||||
retrieved = await EntityNode.get_by_uuid(falkor_driver, sample_entity_node.uuid)
|
||||
assert retrieved.uuid == sample_entity_node.uuid
|
||||
assert retrieved.name == 'Test Entity'
|
||||
assert retrieved.group_id == 'test_group'
|
||||
|
||||
await sample_entity_node.delete(falkor_driver)
|
||||
await falkor_driver.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
async def test_community_node_save_get_and_delete(sample_community_node):
|
||||
falkor_driver = FalkorDriver(
|
||||
host=FALKORDB_HOST,
|
||||
port=FALKORDB_PORT,
|
||||
username=FALKORDB_USER,
|
||||
password=FALKORDB_PASSWORD
|
||||
)
|
||||
|
||||
await sample_community_node.save(falkor_driver)
|
||||
|
||||
retrieved = await CommunityNode.get_by_uuid(falkor_driver, sample_community_node.uuid)
|
||||
assert retrieved.uuid == sample_community_node.uuid
|
||||
assert retrieved.name == 'Community A'
|
||||
assert retrieved.group_id == 'test_group'
|
||||
assert retrieved.summary == 'Community summary'
|
||||
|
||||
await sample_community_node.delete(falkor_driver)
|
||||
await falkor_driver.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
async def test_episodic_node_save_get_and_delete(sample_episodic_node):
|
||||
falkor_driver = FalkorDriver(
|
||||
host=FALKORDB_HOST,
|
||||
port=FALKORDB_PORT,
|
||||
username=FALKORDB_USER,
|
||||
password=FALKORDB_PASSWORD
|
||||
)
|
||||
|
||||
await sample_episodic_node.save(falkor_driver)
|
||||
|
||||
retrieved = await EpisodicNode.get_by_uuid(falkor_driver, sample_episodic_node.uuid)
|
||||
assert retrieved.uuid == sample_episodic_node.uuid
|
||||
assert retrieved.name == 'Episode 1'
|
||||
assert retrieved.group_id == 'test_group'
|
||||
assert retrieved.source == EpisodeType.text
|
||||
assert retrieved.source_description == 'Test source'
|
||||
assert retrieved.content == 'Some content here'
|
||||
|
||||
await sample_episodic_node.delete(falkor_driver)
|
||||
await falkor_driver.close()
|
||||
Loading…
Add table
Reference in a new issue