feat: remote kuzu adapter (#781)


## Description
Adds a `RemoteKuzuAdapter` and a `kuzu-remote` graph database provider so cognee can use a remote Kuzu instance via a REST API.
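
A minimal smoke-test sketch of the new adapter (the endpoint URL and credentials are placeholders; the constructor, `query()`, and `close()` calls are the ones introduced in this PR):

```python
import asyncio

from cognee.infrastructure.databases.graph.kuzu.remote_kuzu_adapter import RemoteKuzuAdapter


async def smoke_test():
    # Placeholder endpoint and credentials -- point these at your Kuzu REST API server.
    adapter = RemoteKuzuAdapter(
        api_url="http://localhost:8000",
        username="kuzu",
        password="kuzu",
    )
    try:
        # Any Cypher supported by the remote server can be sent through query().
        print(await adapter.query("MATCH (n) RETURN COUNT(n)"))
    finally:
        await adapter.close()


if __name__ == "__main__":
    asyncio.run(smoke_test())
```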

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Commit 4eb71ccaf4 (parent 3da893c131) by Daniel Molnar, 2025-06-09 15:27:16 +02:00, committed by GitHub
6 changed files with 518 additions and 0 deletions


@@ -111,6 +111,18 @@ def create_graph_engine(
        return KuzuAdapter(db_path=graph_file_path)

    elif graph_database_provider == "kuzu-remote":
        if not graph_database_url:
            raise EnvironmentError("Missing required Kuzu remote URL.")

        from .kuzu.remote_kuzu_adapter import RemoteKuzuAdapter

        return RemoteKuzuAdapter(
            api_url=graph_database_url,
            username=graph_database_username,
            password=graph_database_password,
        )

    elif graph_database_provider == "memgraph":
        if not (graph_database_url and graph_database_username and graph_database_password):
            raise EnvironmentError("Missing required Memgraph credentials.")


@@ -0,0 +1,197 @@
"""Adapter for remote Kuzu graph database via REST API."""

from cognee.shared.logging_utils import get_logger
import json
from typing import Dict, Any, List, Optional, Tuple
import aiohttp
from uuid import UUID
from cognee.infrastructure.databases.graph.kuzu.adapter import KuzuAdapter

logger = get_logger()


class UUIDEncoder(json.JSONEncoder):
    """Custom JSON encoder that handles UUID objects."""

    def default(self, obj):
        if isinstance(obj, UUID):
            return str(obj)
        return super().default(obj)


class RemoteKuzuAdapter(KuzuAdapter):
    """Adapter for remote Kuzu graph database operations via REST API."""

    def __init__(self, api_url: str, username: str, password: str):
        """Initialize remote Kuzu database connection.

        Args:
            api_url: URL of the Kuzu REST API
            username: Optional username for API authentication
            password: Optional password for API authentication
        """
        # Initialize parent with a dummy path since we're using REST API
        super().__init__("/tmp/kuzu_remote")
        self.api_url = api_url
        self.username = username
        self.password = password
        self._session = None
        self._schema_initialized = False

    async def _get_session(self) -> aiohttp.ClientSession:
        """Get or create an aiohttp session."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    async def close(self):
        """Close the adapter and its session."""
        if self._session and not self._session.closed:
            await self._session.close()
            self._session = None

    async def _make_request(self, endpoint: str, data: dict) -> dict:
        """Make a request to the Kuzu API."""
        url = f"{self.api_url}{endpoint}"
        session = await self._get_session()
        try:
            # Use custom encoder for UUID serialization
            json_data = json.dumps(data, cls=UUIDEncoder)
            async with session.post(
                url, data=json_data, headers={"Content-Type": "application/json"}
            ) as response:
                if response.status != 200:
                    error_detail = await response.text()
                    logger.error(
                        f"API request failed with status {response.status}: {error_detail}\n"
                        f"Request data: {data}"
                    )
                    raise aiohttp.ClientResponseError(
                        response.request_info,
                        response.history,
                        status=response.status,
                        message=error_detail,
                    )
                return await response.json()
        except aiohttp.ClientError as e:
            logger.error(f"API request failed: {str(e)}")
            logger.error(f"Request data: {data}")
            raise

    async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
        """Execute a Kuzu query via the REST API."""
        try:
            # Initialize schema if needed
            if not self._schema_initialized:
                await self._initialize_schema()

            response = await self._make_request(
                "/query", {"query": query, "parameters": params or {}}
            )

            # Convert response to list of tuples
            results = []
            if "data" in response:
                for row in response["data"]:
                    processed_row = []
                    for val in row:
                        if isinstance(val, dict) and "properties" in val:
                            try:
                                props = json.loads(val["properties"])
                                val.update(props)
                                del val["properties"]
                            except json.JSONDecodeError:
                                pass
                        processed_row.append(val)
                    results.append(tuple(processed_row))
            return results
        except Exception as e:
            logger.error(f"Query execution failed: {str(e)}")
            logger.error(f"Query: {query}")
            logger.error(f"Parameters: {params}")
            raise

    async def _check_schema_exists(self) -> bool:
        """Check if the required schema exists without causing recursion."""
        try:
            # Make a direct request to check schema using Cypher
            response = await self._make_request(
                "/query",
                {"query": "MATCH (n:Node) RETURN COUNT(n) > 0", "parameters": {}},
            )
            return bool(response.get("data") and response["data"][0][0])
        except Exception as e:
            logger.error(f"Failed to check schema: {e}")
            return False

    async def _create_schema(self):
        """Create the required schema tables."""
        try:
            # Create Node table if it doesn't exist
            try:
                await self._make_request(
                    "/query",
                    {
                        "query": """
                            CREATE NODE TABLE IF NOT EXISTS Node (
                                id STRING,
                                name STRING,
                                type STRING,
                                properties STRING,
                                created_at TIMESTAMP,
                                updated_at TIMESTAMP,
                                PRIMARY KEY (id)
                            )
                        """,
                        "parameters": {},
                    },
                )
            except aiohttp.ClientResponseError as e:
                if "already exists" not in str(e):
                    raise

            # Create EDGE table if it doesn't exist
            try:
                await self._make_request(
                    "/query",
                    {
                        "query": """
                            CREATE REL TABLE IF NOT EXISTS EDGE (
                                FROM Node TO Node,
                                relationship_name STRING,
                                properties STRING,
                                created_at TIMESTAMP,
                                updated_at TIMESTAMP
                            )
                        """,
                        "parameters": {},
                    },
                )
            except aiohttp.ClientResponseError as e:
                if "already exists" not in str(e):
                    raise

            self._schema_initialized = True
            logger.info("Schema initialized successfully")
        except Exception as e:
            logger.error(f"Failed to create schema: {e}")
            raise

    async def _initialize_schema(self):
        """Initialize the database schema if it doesn't exist."""
        if self._schema_initialized:
            return
        try:
            if not await self._check_schema_exists():
                await self._create_schema()
            else:
                self._schema_initialized = True
                logger.info("Schema already exists")
        except Exception as e:
            logger.error(f"Failed to initialize schema: {e}")
            raise
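
For reference, a sketch of how the `Node`/`EDGE` schema created above can be exercised through `query()`. The table layout and the JSON-string `properties` column come from `_create_schema`; the `$`-parameter binding assumes the remote `/query` endpoint forwards the `parameters` field to Kuzu, and the endpoint URL and credentials are placeholders:

```python
import asyncio
import json

from cognee.infrastructure.databases.graph.kuzu.remote_kuzu_adapter import RemoteKuzuAdapter


async def node_roundtrip(adapter: RemoteKuzuAdapter):
    # Extra attributes are packed into the "properties" column as a JSON string,
    # matching the schema created by _create_schema().
    node = {
        "id": "example-1",  # illustrative values
        "name": "Example node",
        "type": "Document",
        "properties": json.dumps({"lang": "en"}),
    }
    await adapter.query(
        "CREATE (n:Node {id: $id, name: $name, type: $type, properties: $properties})",
        params=node,
    )

    # On read-back, query() json.loads any "properties" field found in returned
    # node dicts and merges those keys into the node.
    return await adapter.query(
        "MATCH (n:Node) WHERE n.id = $id RETURN n", params={"id": "example-1"}
    )


if __name__ == "__main__":

    async def _run():
        adapter = RemoteKuzuAdapter("http://localhost:8000", "kuzu", "kuzu")  # placeholders
        try:
            print(await node_roundtrip(adapter))
        finally:
            await adapter.close()

    asyncio.run(_run())
```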


@@ -0,0 +1,33 @@
import asyncio

from cognee.infrastructure.databases.graph.kuzu.remote_kuzu_adapter import RemoteKuzuAdapter
from cognee.infrastructure.databases.graph.config import get_graph_config


async def main():
    config = get_graph_config()
    adapter = RemoteKuzuAdapter(
        config.graph_database_url, config.graph_database_username, config.graph_database_password
    )
    try:
        print("Node Count:")
        result = await adapter.query("MATCH (n) RETURN COUNT(n) as count")
        print(result)

        print("\nEdge Count:")
        result = await adapter.query("MATCH ()-[r]->() RETURN COUNT(r) as count")
        print(result)

        print("\nSample Nodes with Properties:")
        result = await adapter.query("MATCH (n) RETURN n LIMIT 5")
        print(result)

        print("\nSample Relationships with Properties:")
        result = await adapter.query("MATCH (n1)-[r]->(n2) RETURN n1, r, n2 LIMIT 5")
        print(result)
    finally:
        await adapter.close()


if __name__ == "__main__":
    asyncio.run(main())


@@ -453,6 +453,8 @@ class SQLAlchemyAdapter:
            from cognee.infrastructure.files.storage import LocalStorage

            await self.engine.dispose(close=True)
            db_directory = path.dirname(self.db_path)
            LocalStorage.ensure_directory_exists(db_directory)
            with open(self.db_path, "w") as file:
                file.write("")
        else:


@@ -0,0 +1,115 @@
import os
import shutil
import cognee
import pathlib
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.infrastructure.databases.graph.config import get_graph_config

logger = get_logger()


async def main():
    # Clean up test directories before starting
    data_directory_path = str(
        pathlib.Path(
            os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_remote_kuzu")
        ).resolve()
    )
    cognee_directory_path = str(
        pathlib.Path(
            os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_remote_kuzu")
        ).resolve()
    )

    try:
        # Set Kuzu as the graph database provider
        cognee.config.set_graph_database_provider("kuzu")
        cognee.config.data_root_directory(data_directory_path)
        cognee.config.system_root_directory(cognee_directory_path)

        # Configure remote Kuzu database using environment variables
        os.environ["KUZU_HOST"] = os.getenv("KUZU_HOST", "localhost")
        os.environ["KUZU_PORT"] = os.getenv("KUZU_PORT", "8000")
        os.environ["KUZU_USERNAME"] = os.getenv("KUZU_USERNAME", "kuzu")
        os.environ["KUZU_PASSWORD"] = os.getenv("KUZU_PASSWORD", "kuzu")
        os.environ["KUZU_DATABASE"] = os.getenv("KUZU_DATABASE", "cognee_test")

        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

        dataset_name = "cs_explanations"

        explanation_file_path = os.path.join(
            pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt"
        )
        await cognee.add([explanation_file_path], dataset_name)

        text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena.
At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the preparation and manipulation of quantum states.
Classical physics cannot explain the operation of these quantum devices, and a scalable quantum computer could perform some calculations exponentially faster (with respect to input size scaling) than any modern "classical" computer. In particular, a large-scale quantum computer could break widely used encryption schemes and aid physicists in performing physical simulations; however, the current state of the technology is largely experimental and impractical, with several obstacles to useful applications. Moreover, scalable quantum computers do not hold promise for many practical tasks, and for many important tasks quantum speedups are proven impossible.
The basic unit of information in quantum computing is the qubit, similar to the bit in traditional digital electronics. Unlike a classical bit, a qubit can exist in a superposition of its two "basis" states. When measuring a qubit, the result is a probabilistic output of a classical bit, therefore making quantum computers nondeterministic in general. If a quantum computer manipulates the qubit in a particular way, wave interference effects can amplify the desired measurement results. The design of quantum algorithms involves creating procedures that allow a quantum computer to perform calculations efficiently and quickly.
Physically engineering high-quality qubits has proven challenging. If a physical qubit is not sufficiently isolated from its environment, it suffers from quantum decoherence, introducing noise into calculations. Paradoxically, perfectly isolating qubits is also undesirable because quantum computations typically need to initialize qubits, perform controlled qubit interactions, and measure the resulting quantum states. Each of those operations introduces errors and suffers from noise, and such inaccuracies accumulate.
In principle, a non-quantum (classical) computer can solve the same computational problems as a quantum computer, given enough time. Quantum advantage comes in the form of time complexity rather than computability, and quantum complexity theory shows that some quantum algorithms for carefully selected tasks require exponentially fewer computational steps than the best known non-quantum algorithms. Such tasks can in theory be solved on a large-scale quantum computer whereas classical computers would not finish computations in any reasonable amount of time. However, quantum speedup is not universal or even typical across computational tasks, since basic tasks such as sorting are proven to not allow any asymptotic quantum speedup. Claims of quantum supremacy have drawn significant attention to the discipline, but are demonstrated on contrived tasks, while near-term practical use cases remain limited.
"""
await cognee.add([text], dataset_name)
await cognee.cognify([dataset_name])
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(
query_type=SearchType.INSIGHTS, query_text=random_node_name
)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search(
query_type=SearchType.CHUNKS, query_text=random_node_name
)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted chunks are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search(
query_type=SearchType.SUMMARIES, query_text=random_node_name
)
assert len(search_results) != 0, "Query related summaries don't exist."
print("\nExtracted summaries are:\n")
for result in search_results:
print(f"{result}\n")
user = await get_default_user()
history = await get_history(user.id)
assert len(history) == 6, "Search history is not correct."
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)
from cognee.infrastructure.databases.graph import get_graph_engine
graph_engine = await get_graph_engine()
nodes, edges = await graph_engine.get_graph_data()
assert len(nodes) == 0 and len(edges) == 0, "Remote Kuzu graph database is not empty"
finally:
# Ensure cleanup even if tests fail
for path in [data_directory_path, cognee_directory_path]:
if os.path.exists(path):
shutil.rmtree(path)
if __name__ == "__main__":
import asyncio
asyncio.run(main())


@@ -0,0 +1,159 @@
import asyncio
import random
import time

from cognee.infrastructure.databases.graph.kuzu.remote_kuzu_adapter import RemoteKuzuAdapter
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.shared.logging_utils import get_logger

# Test configuration
BATCH_SIZE = 5000
NUM_BATCHES = 10
TOTAL_NODES = BATCH_SIZE * NUM_BATCHES
TOTAL_RELATIONSHIPS = TOTAL_NODES - 1

logger = get_logger()


async def create_node(adapter, node):
    query = (
        "CREATE (n:TestNode {"
        f"id: '{node['id']}', "
        f"name: '{node['name']}', "
        f"value: {node['value']}"
        "})"
    )
    await adapter.query(query)


async def create_relationship(adapter, source_id, target_id):
    query = (
        "MATCH (n1:TestNode {id: '" + str(source_id) + "'}), "
        "(n2:TestNode {id: '" + str(target_id) + "'}) "
        "CREATE (n1)-[r:CONNECTS_TO {weight: " + str(random.random()) + "}]->(n2)"
    )
    await adapter.query(query)


async def process_batch(adapter, start_id, batch_size):
    batch_start = time.time()
    batch_nodes = []

    # Prepare batch data
    logger.info(f"Preparing batch {start_id // batch_size + 1}/{NUM_BATCHES}...")
    for j in range(batch_size):
        node_id = start_id + j
        properties = {
            "id": str(node_id),
            "name": f"TestNode_{node_id}",
            "value": random.randint(1, 1000),
        }
        batch_nodes.append(properties)

    # Create nodes concurrently
    logger.info(
        f"Creating {batch_size} nodes for batch {start_id // batch_size + 1}/{NUM_BATCHES}..."
    )
    nodes_start = time.time()
    node_tasks = [create_node(adapter, node) for node in batch_nodes]
    await asyncio.gather(*node_tasks)
    nodes_time = time.time() - nodes_start

    # Create relationships concurrently
    logger.info(f"Creating relationships for batch {start_id // batch_size + 1}/{NUM_BATCHES}...")
    rels_start = time.time()
    rel_tasks = [
        create_relationship(adapter, batch_nodes[j]["id"], batch_nodes[j + 1]["id"])
        for j in range(len(batch_nodes) - 1)
    ]
    await asyncio.gather(*rel_tasks)
    rels_time = time.time() - rels_start

    batch_time = time.time() - batch_start
    logger.info(f"Batch {start_id // batch_size + 1}/{NUM_BATCHES} completed in {batch_time:.2f}s")
    logger.info(f" - Nodes creation: {nodes_time:.2f}s")
    logger.info(f" - Relationships creation: {rels_time:.2f}s")
    return batch_time


async def create_test_data(adapter, batch_size=BATCH_SIZE):
    tasks = []

    # Create tasks for each batch
    for i in range(0, TOTAL_NODES, batch_size):
        task = asyncio.create_task(process_batch(adapter, i, batch_size))
        tasks.append(task)

    # Wait for all batches to complete
    batch_times = await asyncio.gather(*tasks)
    return sum(batch_times)


async def main():
    config = get_graph_config()
    adapter = RemoteKuzuAdapter(
        config.graph_database_url, config.graph_database_username, config.graph_database_password
    )
    try:
        logger.info("=== Starting Kuzu Stress Test ===")
        logger.info(f"Configuration: {NUM_BATCHES} batches of {BATCH_SIZE} nodes each")
        logger.info(f"Total nodes to create: {TOTAL_NODES}")
        logger.info(f"Total relationships to create: {TOTAL_RELATIONSHIPS}")

        start_time = time.time()

        # Drop existing tables in correct order (relationships first, then nodes)
        logger.info("[1/5] Dropping existing tables...")
        await adapter.query("DROP TABLE IF EXISTS CONNECTS_TO")
        await adapter.query("DROP TABLE IF EXISTS TestNode")

        # Create node table
        logger.info("[2/5] Creating node table structure...")
        await adapter.query("""
            CREATE NODE TABLE TestNode (
                id STRING,
                name STRING,
                value INT64,
                PRIMARY KEY (id)
            )
        """)

        # Create relationship table
        logger.info("[3/5] Creating relationship table structure...")
        await adapter.query("""
            CREATE REL TABLE CONNECTS_TO (
                FROM TestNode TO TestNode,
                weight DOUBLE
            )
        """)

        # Clear existing test data
        logger.info("[4/5] Clearing existing test data...")
        await adapter.query("MATCH (n:TestNode) DETACH DELETE n")

        # Create new test data
        logger.info(
            f"[5/5] Creating test data ({NUM_BATCHES} concurrent batches of {BATCH_SIZE} nodes each)..."
        )
        total_batch_time = await create_test_data(adapter)

        end_time = time.time()
        total_duration = end_time - start_time

        # Verify the data
        logger.info("Verifying data...")
        result = await adapter.query("MATCH (n:TestNode) RETURN COUNT(n) as count")
        logger.info(f"Total nodes created: {result}")
        result = await adapter.query("MATCH ()-[r:CONNECTS_TO]->() RETURN COUNT(r) as count")
        logger.info(f"Total relationships created: {result}")

        logger.info("=== Test Summary ===")
        logger.info(f"Total batch processing time: {total_batch_time:.2f} seconds")
        logger.info(f"Total execution time: {total_duration:.2f} seconds")
    finally:
        await adapter.close()


if __name__ == "__main__":
    asyncio.run(main())