FalkorDB Integration: Bug Fixes and Unit Tests (#607)

* fixes-and-tests

* update-workflow

* lint-fixes

* mypy-fixes

* fix-falkor-tests

* Update poetry.lock after pyproject.toml changes

* update-yml

* fix-tests

* comp-tests

* typo

* fix-tests

---------

Co-authored-by: Guy Korland <gkorland@gmail.com>
This commit is contained in:
Gal Shubeli 2025-06-30 18:01:44 +03:00 committed by GitHub
parent 19772aa5a1
commit 6e6115c134
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 2301 additions and 1373 deletions

View file

@ -11,6 +11,12 @@ jobs:
runs-on: depot-ubuntu-22.04
environment:
name: development
services:
falkordb:
image: falkordb/falkordb:latest
ports:
- 6379:6379
options: --health-cmd "redis-cli ping" --health-interval 10s --health-timeout 5s --health-retries 5
steps:
- uses: actions/checkout@v4
- name: Set up Python
@ -21,6 +27,8 @@ jobs:
uses: astral-sh/setup-uv@v3
with:
version: "latest"
- name: Install redis-cli for FalkorDB health check
run: sudo apt-get update && sudo apt-get install -y redis-tools
- name: Install dependencies
run: uv sync --extra dev
- name: Run non-integration tests
@ -28,3 +36,13 @@ jobs:
PYTHONPATH: ${{ github.workspace }}
run: |
uv run pytest -m "not integration"
- name: Wait for FalkorDB
run: |
timeout 60 bash -c 'until redis-cli -h localhost -p 6379 ping; do sleep 1; done'
- name: Run FalkorDB integration tests
env:
PYTHONPATH: ${{ github.workspace }}
FALKORDB_HOST: localhost
FALKORDB_PORT: 6379
run: |
uv run pytest tests/driver/test_falkordb_driver.py

View file

@ -120,6 +120,12 @@ Optional:
> [!TIP]
> The simplest way to install Neo4j is via [Neo4j Desktop](https://neo4j.com/download/). It provides a user-friendly
> interface to manage Neo4j instances and databases.
> Alternatively, you can use FalkorDB on-premises via Docker and instantly start with the quickstart example:
```bash
docker run -p 6379:6379 -p 3000:3000 -it --rm falkordb/falkordb:latest
```
```bash
pip install graphiti-core
@ -156,7 +162,7 @@ pip install graphiti-core[anthropic,groq,google-genai]
For a complete working example, see the [Quickstart Example](./examples/quickstart/README.md) in the examples directory. The quickstart demonstrates:
1. Connecting to a Neo4j database
1. Connecting to a Neo4j or FalkorDB database
2. Initializing Graphiti indices and constraints
3. Adding episodes to the graph (both text and structured JSON)
4. Searching for relationships (edges) using hybrid search

View file

@ -46,14 +46,20 @@ logger = logging.getLogger(__name__)
load_dotenv()
# FalkorDB connection parameters
# Make sure FalkorDB on premises is running, see https://docs.falkordb.com/
falkor_uri = os.environ.get('FALKORDB_URI', 'falkor://localhost:6379')
falkor_user = os.environ.get('FALKORDB_USER', 'falkor')
falkor_password = os.environ.get('FALKORDB_PASSWORD', '')
if not falkor_uri:
raise ValueError('FALKORDB_URI must be set')
# Make sure FalkorDB (on-premises) is running — see https://docs.falkordb.com/
# By default, FalkorDB does not require a username or password,
# but you can set them via environment variables for added security.
#
# If you're using FalkorDB Cloud, set the environment variables accordingly.
# For on-premises use, you can leave them as None or set them to your preferred values.
#
# The default host and port are 'localhost' and '6379', respectively.
# You can override these values in your environment variables or directly in the code.
falkor_username = os.environ.get('FALKORDB_USERNAME', None)
falkor_password = os.environ.get('FALKORDB_PASSWORD', None)
falkor_host = os.environ.get('FALKORDB_HOST', 'localhost')
falkor_port = os.environ.get('FALKORDB_PORT', '6379')
async def main():
#################################################
@ -65,8 +71,8 @@ async def main():
#################################################
# Initialize Graphiti with FalkorDB connection
falkor_driver = FalkorDriver(uri=falkor_uri, user=falkor_user, password=falkor_password)
graphiti = Graphiti(uri=falkor_uri, graph_driver=falkor_driver)
falkor_driver = FalkorDriver(host=falkor_host, port=falkor_port, username=falkor_username, password=falkor_password)
graphiti = Graphiti(graph_driver=falkor_driver)
try:
# Initialize the graph database with graphiti's indices. This only needs to be done once.

View file

@ -15,7 +15,6 @@ limitations under the License.
"""
import logging
from collections.abc import Coroutine
from datetime import datetime
from typing import Any
@ -66,22 +65,30 @@ class FalkorDriver(GraphDriver):
def __init__(
self,
uri: str,
user: str,
password: str,
host: str = 'localhost',
port: str = '6379',
username: str | None = None,
password: str | None = None,
falkor_db: FalkorDB | None = None,
):
super().__init__()
uri_parts = uri.split('://', 1)
uri = f'{uri_parts[0]}://{user}:{password}@{uri_parts[1]}'
"""
Initialize the FalkorDB driver.
self.client = FalkorDB(
host='your-db.falkor.cloud', port=6380, password='your_password', ssl=True
)
FalkorDB is a multi-tenant graph database.
To connect, provide the host and port.
The default parameters assume a local (on-premises) FalkorDB instance.
"""
super().__init__()
if falkor_db is not None:
# If a FalkorDB instance is provided, use it directly
self.client = falkor_db
else:
self.client = FalkorDB(host=host, port=port, username=username, password=password)
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "DEFAULT_DATABASE"
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is DEFAULT_DATABASE
if graph_name is None:
graph_name = 'DEFAULT_DATABASE'
graph_name = DEFAULT_DATABASE
return self.client.select_graph(graph_name)
async def execute_query(self, cypher_query_, **kwargs: Any):
@ -102,17 +109,36 @@ class FalkorDriver(GraphDriver):
raise
# Convert the result header to a list of strings
header = [h[1].decode('utf-8') for h in result.header]
return result.result_set, header, None
header = [h[1] for h in result.header]
# Convert FalkorDB's result format (list of lists) to the format expected by Graphiti (list of dicts)
records = []
for row in result.result_set:
record = {}
for i, field_name in enumerate(header):
if i < len(row):
record[field_name] = row[i]
else:
# If there are more fields in header than values in row, set to None
record[field_name] = None
records.append(record)
return records, header, None
def session(self, database: str | None) -> GraphDriverSession:
return FalkorDriverSession(self._get_graph(database))
async def close(self) -> None:
await self.client.connection.close()
"""Close the driver connection."""
if hasattr(self.client, 'aclose'):
await self.client.aclose()
elif hasattr(self.client.connection, 'aclose'):
await self.client.connection.aclose()
elif hasattr(self.client.connection, 'close'):
await self.client.connection.close()
async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
return self.execute_query(
async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> None:
await self.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name',
database_=database_,
)

View file

@ -101,7 +101,7 @@ class AddEpisodeResults(BaseModel):
class Graphiti:
def __init__(
self,
uri: str,
uri: str | None = None,
user: str | None = None,
password: str | None = None,
llm_client: LLMClient | None = None,
@ -162,7 +162,12 @@ class Graphiti:
Graphiti if you're using the default OpenAIClient.
"""
self.driver = graph_driver if graph_driver else Neo4jDriver(uri, user, password)
if graph_driver:
self.driver = graph_driver
else:
if uri is None:
raise ValueError("uri must be provided when graph_driver is None")
self.driver = Neo4jDriver(uri, user, password)
self.database = DEFAULT_DATABASE
self.store_raw_episode_content = store_raw_episode_content

View file

@ -31,7 +31,7 @@ from graphiti_core.errors import GroupIdValidationError
load_dotenv()
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'neo4j')
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', 'default_db')
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))

View file

@ -540,10 +540,18 @@ class CommunityNode(Node):
# Node helpers
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
created_at = parse_db_date(record['created_at'])
valid_at = parse_db_date(record['valid_at'])
if created_at is None:
raise ValueError(f"created_at cannot be None for episode {record.get('uuid', 'unknown')}")
if valid_at is None:
raise ValueError(f"valid_at cannot be None for episode {record.get('uuid', 'unknown')}")
return EpisodicNode(
content=record['content'],
created_at=parse_db_date(record['created_at']), # type: ignore
valid_at=parse_db_date(record['valid_at']), # type: ignore
created_at=created_at,
valid_at=valid_at,
uuid=record['uuid'],
group_id=record['group_id'],
source=EpisodeType.from_str(record['source']),

View file

@ -278,9 +278,6 @@ async def edge_similarity_search(
routing_='r',
)
if driver.provider == 'falkordb':
records = [dict(zip(header, row, strict=True)) for row in records]
edges = [get_entity_edge_from_record(record) for record in records]
return edges
@ -377,8 +374,6 @@ async def node_fulltext_search(
database_=DEFAULT_DATABASE,
routing_='r',
)
if driver.provider == 'falkordb':
records = [dict(zip(header, row, strict=True)) for row in records]
nodes = [get_entity_node_from_record(record) for record in records]
@ -433,8 +428,7 @@ async def node_similarity_search(
database_=DEFAULT_DATABASE,
routing_='r',
)
if driver.provider == 'falkordb':
records = [dict(zip(header, row, strict=True)) for row in records]
nodes = [get_entity_node_from_record(record) for record in records]
return nodes

View file

@ -21,7 +21,7 @@ from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date, semaphore_gather
from graphiti_core.nodes import EpisodeType, EpisodicNode
EPISODE_WINDOW_LEN = 3
@ -140,10 +140,8 @@ async def retrieve_episodes(
episodes = [
EpisodicNode(
content=record['content'],
created_at=datetime.fromtimestamp(
record['created_at'].to_native().timestamp(), timezone.utc
),
valid_at=(record['valid_at'].to_native()),
created_at=parse_db_date(record['created_at']) or datetime.min.replace(tzinfo=timezone.utc),
valid_at=parse_db_date(record['valid_at']) or datetime.min.replace(tzinfo=timezone.utc),
uuid=record['uuid'],
group_id=record['group_id'],
source=EpisodeType.from_str(record['source']),

2835
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -18,6 +18,7 @@ dependencies = [
"tenacity>=9.0.0",
"numpy>=1.0.0",
"python-dotenv>=1.0.1",
"falkordb>=1.1.2,<2.0.0",
"posthog>=3.0.0",
]

View file

@ -1,3 +1,4 @@
[pytest]
markers =
integration: marks tests as integration tests
integration: marks tests as integration tests
asyncio_default_fixture_loop_scope = function

1
tests/driver/__init__.py Normal file
View file

@ -0,0 +1 @@
"""Tests for database drivers."""

View file

@ -0,0 +1,386 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from graphiti_core.driver.falkordb_driver import FalkorDriver, FalkorDriverSession
from graphiti_core.helpers import DEFAULT_DATABASE
class TestFalkorDriver:
"""Comprehensive test suite for FalkorDB driver."""
def setup_method(self):
"""Set up test fixtures."""
self.mock_client = MagicMock()
with patch('graphiti_core.driver.falkordb_driver.FalkorDB'):
self.driver = FalkorDriver()
self.driver.client = self.mock_client
def test_init_with_connection_params(self):
"""Test initialization with connection parameters."""
with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db:
driver = FalkorDriver(
host='test-host',
port='1234',
username='test-user',
password='test-pass'
)
assert driver.provider == 'falkordb'
mock_falkor_db.assert_called_once_with(
host='test-host',
port='1234',
username='test-user',
password='test-pass'
)
def test_init_with_falkor_db_instance(self):
"""Test initialization with a FalkorDB instance."""
with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db_class:
mock_falkor_db = MagicMock()
driver = FalkorDriver(falkor_db=mock_falkor_db)
assert driver.provider == 'falkordb'
assert driver.client is mock_falkor_db
mock_falkor_db_class.assert_not_called()
def test_provider(self):
"""Test driver provider identification."""
assert self.driver.provider == 'falkordb'
def test_get_graph_with_name(self):
"""Test _get_graph with specific graph name."""
mock_graph = MagicMock()
self.mock_client.select_graph.return_value = mock_graph
result = self.driver._get_graph('test_graph')
self.mock_client.select_graph.assert_called_once_with('test_graph')
assert result is mock_graph
def test_get_graph_with_none_defaults_to_default_database(self):
"""Test _get_graph with None defaults to DEFAULT_DATABASE."""
mock_graph = MagicMock()
self.mock_client.select_graph.return_value = mock_graph
result = self.driver._get_graph(None)
self.mock_client.select_graph.assert_called_once_with(DEFAULT_DATABASE)
assert result is mock_graph
@pytest.mark.asyncio
async def test_execute_query_success(self):
"""Test successful query execution."""
mock_graph = MagicMock()
mock_result = MagicMock()
mock_result.header = [('col1', 'column1'), ('col2', 'column2')]
mock_result.result_set = [['row1col1', 'row1col2']]
mock_graph.query = AsyncMock(return_value=mock_result)
self.mock_client.select_graph.return_value = mock_graph
result = await self.driver.execute_query(
'MATCH (n) RETURN n',
param1='value1',
database_='test_db'
)
self.mock_client.select_graph.assert_called_once_with('test_db')
mock_graph.query.assert_called_once_with(
'MATCH (n) RETURN n',
{'param1': 'value1'}
)
result_set, header, summary = result
assert result_set == [{'column1': 'row1col1', 'column2': 'row1col2'}]
assert header == ['column1', 'column2']
assert summary is None
@pytest.mark.asyncio
async def test_execute_query_handles_index_already_exists_error(self):
"""Test handling of 'already indexed' error."""
mock_graph = MagicMock()
mock_graph.query = AsyncMock(side_effect=Exception('Index already indexed'))
self.mock_client.select_graph.return_value = mock_graph
with patch('graphiti_core.driver.falkordb_driver.logger') as mock_logger:
result = await self.driver.execute_query('CREATE INDEX ...')
mock_logger.info.assert_called_once()
assert result is None
@pytest.mark.asyncio
async def test_execute_query_propagates_other_exceptions(self):
"""Test that other exceptions are properly propagated."""
mock_graph = MagicMock()
mock_graph.query = AsyncMock(side_effect=Exception('Other error'))
self.mock_client.select_graph.return_value = mock_graph
with patch('graphiti_core.driver.falkordb_driver.logger') as mock_logger:
with pytest.raises(Exception, match='Other error'):
await self.driver.execute_query('INVALID QUERY')
mock_logger.error.assert_called_once()
@pytest.mark.asyncio
async def test_execute_query_converts_datetime_parameters(self):
"""Test that datetime objects in kwargs are converted to ISO strings."""
mock_graph = MagicMock()
mock_result = MagicMock()
mock_result.header = []
mock_result.result_set = []
mock_graph.query = AsyncMock(return_value=mock_result)
self.mock_client.select_graph.return_value = mock_graph
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
await self.driver.execute_query(
'CREATE (n:Node) SET n.created_at = $created_at',
created_at=test_datetime
)
call_args = mock_graph.query.call_args[0]
assert call_args[1]['created_at'] == test_datetime.isoformat()
def test_session_creation(self):
"""Test session creation with specific database."""
mock_graph = MagicMock()
self.mock_client.select_graph.return_value = mock_graph
session = self.driver.session('test_db')
assert isinstance(session, FalkorDriverSession)
assert session.graph is mock_graph
self.mock_client.select_graph.assert_called_once_with('test_db')
def test_session_creation_with_none_uses_default_database(self):
"""Test session creation with None uses default database."""
mock_graph = MagicMock()
self.mock_client.select_graph.return_value = mock_graph
session = self.driver.session(None)
assert isinstance(session, FalkorDriverSession)
self.mock_client.select_graph.assert_called_once_with(DEFAULT_DATABASE)
@pytest.mark.asyncio
async def test_close_calls_connection_close(self):
"""Test driver close method calls connection close."""
mock_connection = MagicMock()
mock_connection.close = AsyncMock()
self.mock_client.connection = mock_connection
# Ensure hasattr checks work correctly
del self.mock_client.aclose # Remove aclose if it exists
with patch('builtins.hasattr') as mock_hasattr:
# hasattr(self.client, 'aclose') returns False
# hasattr(self.client.connection, 'aclose') returns False
# hasattr(self.client.connection, 'close') returns True
mock_hasattr.side_effect = lambda obj, attr: (
attr == 'close' and obj is mock_connection
)
await self.driver.close()
mock_connection.close.assert_called_once()
@pytest.mark.asyncio
async def test_delete_all_indexes(self):
"""Test delete_all_indexes method."""
with patch.object(self.driver, 'execute_query', new_callable=AsyncMock) as mock_execute:
await self.driver.delete_all_indexes('test_db')
mock_execute.assert_called_once_with(
'CALL db.indexes() YIELD name DROP INDEX name',
database_='test_db'
)
class TestFalkorDriverSession:
"""Test FalkorDB driver session functionality."""
def setup_method(self):
"""Set up test fixtures."""
self.mock_graph = MagicMock()
self.session = FalkorDriverSession(self.mock_graph)
@pytest.mark.asyncio
async def test_session_async_context_manager(self):
"""Test session can be used as async context manager."""
async with self.session as s:
assert s is self.session
@pytest.mark.asyncio
async def test_close_method(self):
"""Test session close method doesn't raise exceptions."""
await self.session.close() # Should not raise
@pytest.mark.asyncio
async def test_execute_write_passes_session_and_args(self):
"""Test execute_write method passes session and arguments correctly."""
async def test_func(session, *args, **kwargs):
assert session is self.session
assert args == ('arg1', 'arg2')
assert kwargs == {'key': 'value'}
return 'result'
result = await self.session.execute_write(
test_func, 'arg1', 'arg2', key='value'
)
assert result == 'result'
@pytest.mark.asyncio
async def test_run_single_query_with_parameters(self):
"""Test running a single query with parameters."""
self.mock_graph.query = AsyncMock()
await self.session.run(
'MATCH (n) RETURN n',
param1='value1',
param2='value2'
)
self.mock_graph.query.assert_called_once_with(
'MATCH (n) RETURN n',
{'param1': 'value1', 'param2': 'value2'}
)
@pytest.mark.asyncio
async def test_run_multiple_queries_as_list(self):
"""Test running multiple queries passed as list."""
self.mock_graph.query = AsyncMock()
queries = [
('MATCH (n) RETURN n', {'param1': 'value1'}),
('CREATE (n:Node)', {'param2': 'value2'})
]
await self.session.run(queries)
assert self.mock_graph.query.call_count == 2
calls = self.mock_graph.query.call_args_list
assert calls[0][0] == ('MATCH (n) RETURN n', {'param1': 'value1'})
assert calls[1][0] == ('CREATE (n:Node)', {'param2': 'value2'})
@pytest.mark.asyncio
async def test_run_converts_datetime_objects_to_iso_strings(self):
"""Test that datetime objects are converted to ISO strings."""
self.mock_graph.query = AsyncMock()
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
await self.session.run(
'CREATE (n:Node) SET n.created_at = $created_at',
created_at=test_datetime
)
self.mock_graph.query.assert_called_once()
call_args = self.mock_graph.query.call_args[0]
assert call_args[1]['created_at'] == test_datetime.isoformat()
class TestDatetimeConversion:
"""Test datetime conversion utility function."""
def test_convert_datetime_dict(self):
"""Test datetime conversion in nested dictionary."""
from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
input_dict = {
'string_val': 'test',
'datetime_val': test_datetime,
'nested_dict': {
'nested_datetime': test_datetime,
'nested_string': 'nested_test'
}
}
result = convert_datetimes_to_strings(input_dict)
assert result['string_val'] == 'test'
assert result['datetime_val'] == test_datetime.isoformat()
assert result['nested_dict']['nested_datetime'] == test_datetime.isoformat()
assert result['nested_dict']['nested_string'] == 'nested_test'
def test_convert_datetime_list_and_tuple(self):
"""Test datetime conversion in lists and tuples."""
from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
# Test list
input_list = ['test', test_datetime, ['nested', test_datetime]]
result_list = convert_datetimes_to_strings(input_list)
assert result_list[0] == 'test'
assert result_list[1] == test_datetime.isoformat()
assert result_list[2][1] == test_datetime.isoformat()
# Test tuple
input_tuple = ('test', test_datetime)
result_tuple = convert_datetimes_to_strings(input_tuple)
assert isinstance(result_tuple, tuple)
assert result_tuple[0] == 'test'
assert result_tuple[1] == test_datetime.isoformat()
def test_convert_single_datetime(self):
"""Test datetime conversion for single datetime object."""
from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
result = convert_datetimes_to_strings(test_datetime)
assert result == test_datetime.isoformat()
def test_convert_other_types_unchanged(self):
"""Test that non-datetime types are returned unchanged."""
from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings
assert convert_datetimes_to_strings('string') == 'string'
assert convert_datetimes_to_strings(123) == 123
assert convert_datetimes_to_strings(None) is None
assert convert_datetimes_to_strings(True) is True
# Simple integration test
class TestFalkorDriverIntegration:
"""Simple integration test for FalkorDB driver."""
@pytest.mark.integration
@pytest.mark.asyncio
async def test_basic_integration_with_real_falkordb(self):
"""Basic integration test with real FalkorDB instance."""
pytest.importorskip('falkordb')
falkor_host = os.getenv('FALKORDB_HOST', 'localhost')
falkor_port = os.getenv('FALKORDB_PORT', '6379')
try:
driver = FalkorDriver(host=falkor_host, port=falkor_port)
# Test basic query execution
result = await driver.execute_query('RETURN 1 as test')
assert result is not None
result_set, header, summary = result
assert header == ['test']
assert result_set == [{'test': 1}]
await driver.close()
except Exception as e:
pytest.skip(f"FalkorDB not available for integration test: {e}")

View file

@ -24,7 +24,7 @@ def test_lucene_sanitize():
queries = [
(
'This has every escape character + - && || ! ( ) { } [ ] ^ " ~ * ? : \\ /',
'\This has every escape character \+ \- \&\& \|\| \! \( \) \{ \} \[ \] \^ \\" \~ \* \? \: \\\ \/',
'\\This has every escape character \\+ \\- \\&\\& \\|\\| \\! \\( \\) \\{ \\} \\[ \\] \\^ \\" \\~ \\* \\? \\: \\\\ \\/',
),
('this has no escape characters', 'this has no escape characters'),
]

View file

@ -0,0 +1,160 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
import os
import sys
from datetime import datetime, timezone
import pytest
from dotenv import load_dotenv
from graphiti_core.driver.falkordb_driver import FalkorDriver
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.graphiti import Graphiti
from graphiti_core.helpers import semaphore_gather
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.search.search_helpers import search_results_to_context_string
pytestmark = pytest.mark.integration
pytest_plugins = ('pytest_asyncio',)
load_dotenv()
FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost')
FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
def setup_logging():
# Create a logger
logger = logging.getLogger()
logger.setLevel(logging.INFO) # Set the logging level to INFO
# Create console handler and set level to INFO
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
# Create formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Add formatter to console handler
console_handler.setFormatter(formatter)
# Add console handler to logger
logger.addHandler(console_handler)
return logger
@pytest.mark.asyncio
async def test_graphiti_falkordb_init():
logger = setup_logging()
falkor_driver = FalkorDriver(
host=FALKORDB_HOST,
port=FALKORDB_PORT,
username=FALKORDB_USER,
password=FALKORDB_PASSWORD
)
graphiti = Graphiti(graph_driver=falkor_driver)
results = await graphiti.search_(query='Who is the user?')
pretty_results = search_results_to_context_string(results)
logger.info(pretty_results)
await graphiti.close()
@pytest.mark.asyncio
async def test_graph_falkordb_integration():
falkor_driver = FalkorDriver(
host=FALKORDB_HOST,
port=FALKORDB_PORT,
username=FALKORDB_USER,
password=FALKORDB_PASSWORD
)
client = Graphiti(graph_driver=falkor_driver)
embedder = client.embedder
driver = client.driver
now = datetime.now(timezone.utc)
episode = EpisodicNode(
name='test_episode',
labels=[],
created_at=now,
valid_at=now,
source='message',
source_description='conversation message',
content='Alice likes Bob',
entity_edges=[],
)
alice_node = EntityNode(
name='Alice',
labels=[],
created_at=now,
summary='Alice summary',
)
bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary')
episodic_edge_1 = EpisodicEdge(
source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now
)
episodic_edge_2 = EpisodicEdge(
source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now
)
entity_edge = EntityEdge(
source_node_uuid=alice_node.uuid,
target_node_uuid=bob_node.uuid,
created_at=now,
name='likes',
fact='Alice likes Bob',
episodes=[],
expired_at=now,
valid_at=now,
invalid_at=now,
)
await entity_edge.generate_embedding(embedder)
nodes = [episode, alice_node, bob_node]
edges = [episodic_edge_1, episodic_edge_2, entity_edge]
# test save
await semaphore_gather(*[node.save(driver) for node in nodes])
await semaphore_gather(*[edge.save(driver) for edge in edges])
# test get
assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None
assert await EntityNode.get_by_uuid(driver, alice_node.uuid) is not None
assert await EpisodicEdge.get_by_uuid(driver, episodic_edge_1.uuid) is not None
assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None
# test delete
await semaphore_gather(*[node.delete(driver) for node in nodes])
await semaphore_gather(*[edge.delete(driver) for edge in edges])
await client.close()

View file

@ -0,0 +1,137 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from datetime import datetime, timezone
from uuid import uuid4
import pytest
from graphiti_core.driver.falkordb_driver import FalkorDriver
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
EpisodeType,
EpisodicNode,
)
FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost')
FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
@pytest.fixture
def sample_entity_node():
return EntityNode(
uuid=str(uuid4()),
name='Test Entity',
group_id='test_group',
labels=['Entity'],
name_embedding=[0.5] * 1024,
summary='Entity Summary',
)
@pytest.fixture
def sample_episodic_node():
return EpisodicNode(
uuid=str(uuid4()),
name='Episode 1',
group_id='test_group',
source=EpisodeType.text,
source_description='Test source',
content='Some content here',
valid_at=datetime.now(timezone.utc),
)
@pytest.fixture
def sample_community_node():
return CommunityNode(
uuid=str(uuid4()),
name='Community A',
name_embedding=[0.5] * 1024,
group_id='test_group',
summary='Community summary',
)
@pytest.mark.asyncio
@pytest.mark.integration
async def test_entity_node_save_get_and_delete(sample_entity_node):
falkor_driver = FalkorDriver(
host=FALKORDB_HOST,
port=FALKORDB_PORT,
username=FALKORDB_USER,
password=FALKORDB_PASSWORD
)
await sample_entity_node.save(falkor_driver)
retrieved = await EntityNode.get_by_uuid(falkor_driver, sample_entity_node.uuid)
assert retrieved.uuid == sample_entity_node.uuid
assert retrieved.name == 'Test Entity'
assert retrieved.group_id == 'test_group'
await sample_entity_node.delete(falkor_driver)
await falkor_driver.close()
@pytest.mark.asyncio
@pytest.mark.integration
async def test_community_node_save_get_and_delete(sample_community_node):
falkor_driver = FalkorDriver(
host=FALKORDB_HOST,
port=FALKORDB_PORT,
username=FALKORDB_USER,
password=FALKORDB_PASSWORD
)
await sample_community_node.save(falkor_driver)
retrieved = await CommunityNode.get_by_uuid(falkor_driver, sample_community_node.uuid)
assert retrieved.uuid == sample_community_node.uuid
assert retrieved.name == 'Community A'
assert retrieved.group_id == 'test_group'
assert retrieved.summary == 'Community summary'
await sample_community_node.delete(falkor_driver)
await falkor_driver.close()
@pytest.mark.asyncio
@pytest.mark.integration
async def test_episodic_node_save_get_and_delete(sample_episodic_node):
falkor_driver = FalkorDriver(
host=FALKORDB_HOST,
port=FALKORDB_PORT,
username=FALKORDB_USER,
password=FALKORDB_PASSWORD
)
await sample_episodic_node.save(falkor_driver)
retrieved = await EpisodicNode.get_by_uuid(falkor_driver, sample_episodic_node.uuid)
assert retrieved.uuid == sample_episodic_node.uuid
assert retrieved.name == 'Episode 1'
assert retrieved.group_id == 'test_group'
assert retrieved.source == EpisodeType.text
assert retrieved.source_description == 'Test source'
assert retrieved.content == 'Some content here'
await sample_episodic_node.delete(falkor_driver)
await falkor_driver.close()