diff --git a/cognee/infrastructure/databases/graph/get_graph_engine.py b/cognee/infrastructure/databases/graph/get_graph_engine.py index 08117f408..bb6544fd9 100644 --- a/cognee/infrastructure/databases/graph/get_graph_engine.py +++ b/cognee/infrastructure/databases/graph/get_graph_engine.py @@ -111,6 +111,18 @@ def create_graph_engine( return KuzuAdapter(db_path=graph_file_path) + elif graph_database_provider == "kuzu-remote": + if not graph_database_url: + raise EnvironmentError("Missing required Kuzu remote URL.") + + from .kuzu.remote_kuzu_adapter import RemoteKuzuAdapter + + return RemoteKuzuAdapter( + api_url=graph_database_url, + username=graph_database_username, + password=graph_database_password, + ) + elif graph_database_provider == "memgraph": if not (graph_database_url and graph_database_username and graph_database_password): raise EnvironmentError("Missing required Memgraph credentials.") diff --git a/cognee/infrastructure/databases/graph/kuzu/remote_kuzu_adapter.py b/cognee/infrastructure/databases/graph/kuzu/remote_kuzu_adapter.py new file mode 100644 index 000000000..c75b70f75 --- /dev/null +++ b/cognee/infrastructure/databases/graph/kuzu/remote_kuzu_adapter.py @@ -0,0 +1,197 @@ +"""Adapter for remote Kuzu graph database via REST API.""" + +from cognee.shared.logging_utils import get_logger +import json +from typing import Dict, Any, List, Optional, Tuple +import aiohttp +from uuid import UUID + +from cognee.infrastructure.databases.graph.kuzu.adapter import KuzuAdapter + +logger = get_logger() + + +class UUIDEncoder(json.JSONEncoder): + """Custom JSON encoder that handles UUID objects.""" + + def default(self, obj): + if isinstance(obj, UUID): + return str(obj) + return super().default(obj) + + +class RemoteKuzuAdapter(KuzuAdapter): + """Adapter for remote Kuzu graph database operations via REST API.""" + + def __init__(self, api_url: str, username: str, password: str): + """Initialize remote Kuzu database connection. + + Args: + api_url: URL of the Kuzu REST API + username: Optional username for API authentication + password: Optional password for API authentication + """ + # Initialize parent with a dummy path since we're using REST API + super().__init__("/tmp/kuzu_remote") + self.api_url = api_url + self.username = username + self.password = password + self._session = None + self._schema_initialized = False + + async def _get_session(self) -> aiohttp.ClientSession: + """Get or create an aiohttp session.""" + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession() + return self._session + + async def close(self): + """Close the adapter and its session.""" + if self._session and not self._session.closed: + await self._session.close() + self._session = None + + async def _make_request(self, endpoint: str, data: dict) -> dict: + """Make a request to the Kuzu API.""" + url = f"{self.api_url}{endpoint}" + session = await self._get_session() + try: + # Use custom encoder for UUID serialization + json_data = json.dumps(data, cls=UUIDEncoder) + async with session.post( + url, data=json_data, headers={"Content-Type": "application/json"} + ) as response: + if response.status != 200: + error_detail = await response.text() + logger.error( + f"API request failed with status {response.status}: {error_detail}\n" + f"Request data: {data}" + ) + raise aiohttp.ClientResponseError( + response.request_info, + response.history, + status=response.status, + message=error_detail, + ) + return await response.json() + except aiohttp.ClientError as e: + logger.error(f"API request failed: {str(e)}") + logger.error(f"Request data: {data}") + raise + + async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]: + """Execute a Kuzu query via the REST API.""" + try: + # Initialize schema if needed + if not self._schema_initialized: + await self._initialize_schema() + + response = await self._make_request( + "/query", {"query": query, "parameters": params or {}} + ) + + # Convert response to list of tuples + results = [] + if "data" in response: + for row in response["data"]: + processed_row = [] + for val in row: + if isinstance(val, dict) and "properties" in val: + try: + props = json.loads(val["properties"]) + val.update(props) + del val["properties"] + except json.JSONDecodeError: + pass + processed_row.append(val) + results.append(tuple(processed_row)) + + return results + except Exception as e: + logger.error(f"Query execution failed: {str(e)}") + logger.error(f"Query: {query}") + logger.error(f"Parameters: {params}") + raise + + async def _check_schema_exists(self) -> bool: + """Check if the required schema exists without causing recursion.""" + try: + # Make a direct request to check schema using Cypher + response = await self._make_request( + "/query", + {"query": "MATCH (n:Node) RETURN COUNT(n) > 0", "parameters": {}}, + ) + return bool(response.get("data") and response["data"][0][0]) + except Exception as e: + logger.error(f"Failed to check schema: {e}") + return False + + async def _create_schema(self): + """Create the required schema tables.""" + try: + # Create Node table if it doesn't exist + try: + await self._make_request( + "/query", + { + "query": """ + CREATE NODE TABLE IF NOT EXISTS Node ( + id STRING, + name STRING, + type STRING, + properties STRING, + created_at TIMESTAMP, + updated_at TIMESTAMP, + PRIMARY KEY (id) + ) + """, + "parameters": {}, + }, + ) + except aiohttp.ClientResponseError as e: + if "already exists" not in str(e): + raise + + # Create EDGE table if it doesn't exist + try: + await self._make_request( + "/query", + { + "query": """ + CREATE REL TABLE IF NOT EXISTS EDGE ( + FROM Node TO Node, + relationship_name STRING, + properties STRING, + created_at TIMESTAMP, + updated_at TIMESTAMP + ) + """, + "parameters": {}, + }, + ) + except aiohttp.ClientResponseError as e: + if "already exists" not in str(e): + raise + + self._schema_initialized = True + logger.info("Schema initialized successfully") + + except Exception as e: + logger.error(f"Failed to create schema: {e}") + raise + + async def _initialize_schema(self): + """Initialize the database schema if it doesn't exist.""" + if self._schema_initialized: + return + + try: + if not await self._check_schema_exists(): + await self._create_schema() + else: + self._schema_initialized = True + logger.info("Schema already exists") + + except Exception as e: + logger.error(f"Failed to initialize schema: {e}") + raise diff --git a/cognee/infrastructure/databases/graph/kuzu/show_remote_kuzu_stats.py b/cognee/infrastructure/databases/graph/kuzu/show_remote_kuzu_stats.py new file mode 100644 index 000000000..d6f175be7 --- /dev/null +++ b/cognee/infrastructure/databases/graph/kuzu/show_remote_kuzu_stats.py @@ -0,0 +1,33 @@ +import asyncio +from cognee.infrastructure.databases.graph.kuzu.remote_kuzu_adapter import RemoteKuzuAdapter +from cognee.infrastructure.databases.graph.config import get_graph_config + + +async def main(): + config = get_graph_config() + adapter = RemoteKuzuAdapter( + config.graph_database_url, config.graph_database_username, config.graph_database_password + ) + try: + print("Node Count:") + result = await adapter.query("MATCH (n) RETURN COUNT(n) as count") + print(result) + + print("\nEdge Count:") + result = await adapter.query("MATCH ()-[r]->() RETURN COUNT(r) as count") + print(result) + + print("\nSample Nodes with Properties:") + result = await adapter.query("MATCH (n) RETURN n LIMIT 5") + print(result) + + print("\nSample Relationships with Properties:") + result = await adapter.query("MATCH (n1)-[r]->(n2) RETURN n1, r, n2 LIMIT 5") + print(result) + + finally: + await adapter.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index 0419eea72..a7345f102 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -453,6 +453,8 @@ class SQLAlchemyAdapter: from cognee.infrastructure.files.storage import LocalStorage await self.engine.dispose(close=True) + db_directory = path.dirname(self.db_path) + LocalStorage.ensure_directory_exists(db_directory) with open(self.db_path, "w") as file: file.write("") else: diff --git a/cognee/tests/test_remote_kuzu.py b/cognee/tests/test_remote_kuzu.py new file mode 100644 index 000000000..ff089f27e --- /dev/null +++ b/cognee/tests/test_remote_kuzu.py @@ -0,0 +1,115 @@ +import os +import shutil +import cognee +import pathlib +from cognee.shared.logging_utils import get_logger +from cognee.modules.search.types import SearchType +from cognee.modules.search.operations import get_history +from cognee.modules.users.methods import get_default_user +from cognee.infrastructure.databases.graph.config import get_graph_config + +logger = get_logger() + + +async def main(): + # Clean up test directories before starting + data_directory_path = str( + pathlib.Path( + os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_remote_kuzu") + ).resolve() + ) + cognee_directory_path = str( + pathlib.Path( + os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_remote_kuzu") + ).resolve() + ) + + try: + # Set Kuzu as the graph database provider + cognee.config.set_graph_database_provider("kuzu") + cognee.config.data_root_directory(data_directory_path) + cognee.config.system_root_directory(cognee_directory_path) + + # Configure remote Kuzu database using environment variables + os.environ["KUZU_HOST"] = os.getenv("KUZU_HOST", "localhost") + os.environ["KUZU_PORT"] = os.getenv("KUZU_PORT", "8000") + os.environ["KUZU_USERNAME"] = os.getenv("KUZU_USERNAME", "kuzu") + os.environ["KUZU_PASSWORD"] = os.getenv("KUZU_PASSWORD", "kuzu") + os.environ["KUZU_DATABASE"] = os.getenv("KUZU_DATABASE", "cognee_test") + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + dataset_name = "cs_explanations" + + explanation_file_path = os.path.join( + pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt" + ) + await cognee.add([explanation_file_path], dataset_name) + + text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena. + At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the preparation and manipulation of quantum states. + Classical physics cannot explain the operation of these quantum devices, and a scalable quantum computer could perform some calculations exponentially faster (with respect to input size scaling) than any modern "classical" computer. In particular, a large-scale quantum computer could break widely used encryption schemes and aid physicists in performing physical simulations; however, the current state of the technology is largely experimental and impractical, with several obstacles to useful applications. Moreover, scalable quantum computers do not hold promise for many practical tasks, and for many important tasks quantum speedups are proven impossible. + The basic unit of information in quantum computing is the qubit, similar to the bit in traditional digital electronics. Unlike a classical bit, a qubit can exist in a superposition of its two "basis" states. When measuring a qubit, the result is a probabilistic output of a classical bit, therefore making quantum computers nondeterministic in general. If a quantum computer manipulates the qubit in a particular way, wave interference effects can amplify the desired measurement results. The design of quantum algorithms involves creating procedures that allow a quantum computer to perform calculations efficiently and quickly. + Physically engineering high-quality qubits has proven challenging. If a physical qubit is not sufficiently isolated from its environment, it suffers from quantum decoherence, introducing noise into calculations. Paradoxically, perfectly isolating qubits is also undesirable because quantum computations typically need to initialize qubits, perform controlled qubit interactions, and measure the resulting quantum states. Each of those operations introduces errors and suffers from noise, and such inaccuracies accumulate. + In principle, a non-quantum (classical) computer can solve the same computational problems as a quantum computer, given enough time. Quantum advantage comes in the form of time complexity rather than computability, and quantum complexity theory shows that some quantum algorithms for carefully selected tasks require exponentially fewer computational steps than the best known non-quantum algorithms. Such tasks can in theory be solved on a large-scale quantum computer whereas classical computers would not finish computations in any reasonable amount of time. However, quantum speedup is not universal or even typical across computational tasks, since basic tasks such as sorting are proven to not allow any asymptotic quantum speedup. Claims of quantum supremacy have drawn significant attention to the discipline, but are demonstrated on contrived tasks, while near-term practical use cases remain limited. + """ + await cognee.add([text], dataset_name) + + await cognee.cognify([dataset_name]) + + from cognee.infrastructure.databases.vector import get_vector_engine + + vector_engine = get_vector_engine() + random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] + random_node_name = random_node.payload["text"] + + search_results = await cognee.search( + query_type=SearchType.INSIGHTS, query_text=random_node_name + ) + assert len(search_results) != 0, "The search results list is empty." + print("\n\nExtracted sentences are:\n") + for result in search_results: + print(f"{result}\n") + + search_results = await cognee.search( + query_type=SearchType.CHUNKS, query_text=random_node_name + ) + assert len(search_results) != 0, "The search results list is empty." + print("\n\nExtracted chunks are:\n") + for result in search_results: + print(f"{result}\n") + + search_results = await cognee.search( + query_type=SearchType.SUMMARIES, query_text=random_node_name + ) + assert len(search_results) != 0, "Query related summaries don't exist." + print("\nExtracted summaries are:\n") + for result in search_results: + print(f"{result}\n") + + user = await get_default_user() + history = await get_history(user.id) + assert len(history) == 6, "Search history is not correct." + + await cognee.prune.prune_data() + assert not os.path.isdir(data_directory_path), "Local data files are not deleted" + + await cognee.prune.prune_system(metadata=True) + from cognee.infrastructure.databases.graph import get_graph_engine + + graph_engine = await get_graph_engine() + nodes, edges = await graph_engine.get_graph_data() + assert len(nodes) == 0 and len(edges) == 0, "Remote Kuzu graph database is not empty" + + finally: + # Ensure cleanup even if tests fail + for path in [data_directory_path, cognee_directory_path]: + if os.path.exists(path): + shutil.rmtree(path) + + +if __name__ == "__main__": + import asyncio + + asyncio.run(main()) diff --git a/cognee/tests/test_remote_kuzu_stress.py b/cognee/tests/test_remote_kuzu_stress.py new file mode 100644 index 000000000..49b550afb --- /dev/null +++ b/cognee/tests/test_remote_kuzu_stress.py @@ -0,0 +1,159 @@ +import asyncio +import random +import time +from cognee.infrastructure.databases.graph.kuzu.remote_kuzu_adapter import RemoteKuzuAdapter +from cognee.infrastructure.databases.graph.config import get_graph_config +from cognee.shared.logging_utils import get_logger + +# Test configuration +BATCH_SIZE = 5000 +NUM_BATCHES = 10 +TOTAL_NODES = BATCH_SIZE * NUM_BATCHES +TOTAL_RELATIONSHIPS = TOTAL_NODES - 1 + +logger = get_logger() + + +async def create_node(adapter, node): + query = ( + "CREATE (n:TestNode {" + f"id: '{node['id']}', " + f"name: '{node['name']}', " + f"value: {node['value']}" + "})" + ) + await adapter.query(query) + + +async def create_relationship(adapter, source_id, target_id): + query = ( + "MATCH (n1:TestNode {id: '" + str(source_id) + "'}), " + "(n2:TestNode {id: '" + str(target_id) + "'}) " + "CREATE (n1)-[r:CONNECTS_TO {weight: " + str(random.random()) + "}]->(n2)" + ) + await adapter.query(query) + + +async def process_batch(adapter, start_id, batch_size): + batch_start = time.time() + batch_nodes = [] + + # Prepare batch data + logger.info(f"Preparing batch {start_id // batch_size + 1}/{NUM_BATCHES}...") + for j in range(batch_size): + node_id = start_id + j + properties = { + "id": str(node_id), + "name": f"TestNode_{node_id}", + "value": random.randint(1, 1000), + } + batch_nodes.append(properties) + + # Create nodes concurrently + logger.info( + f"Creating {batch_size} nodes for batch {start_id // batch_size + 1}/{NUM_BATCHES}..." + ) + nodes_start = time.time() + node_tasks = [create_node(adapter, node) for node in batch_nodes] + await asyncio.gather(*node_tasks) + nodes_time = time.time() - nodes_start + + # Create relationships concurrently + logger.info(f"Creating relationships for batch {start_id // batch_size + 1}/{NUM_BATCHES}...") + rels_start = time.time() + rel_tasks = [ + create_relationship(adapter, batch_nodes[j]["id"], batch_nodes[j + 1]["id"]) + for j in range(len(batch_nodes) - 1) + ] + await asyncio.gather(*rel_tasks) + rels_time = time.time() - rels_start + + batch_time = time.time() - batch_start + logger.info(f"Batch {start_id // batch_size + 1}/{NUM_BATCHES} completed in {batch_time:.2f}s") + logger.info(f" - Nodes creation: {nodes_time:.2f}s") + logger.info(f" - Relationships creation: {rels_time:.2f}s") + return batch_time + + +async def create_test_data(adapter, batch_size=BATCH_SIZE): + tasks = [] + + # Create tasks for each batch + for i in range(0, TOTAL_NODES, batch_size): + task = asyncio.create_task(process_batch(adapter, i, batch_size)) + tasks.append(task) + + # Wait for all batches to complete + batch_times = await asyncio.gather(*tasks) + return sum(batch_times) + + +async def main(): + config = get_graph_config() + adapter = RemoteKuzuAdapter( + config.graph_database_url, config.graph_database_username, config.graph_database_password + ) + + try: + logger.info("=== Starting Kuzu Stress Test ===") + logger.info(f"Configuration: {NUM_BATCHES} batches of {BATCH_SIZE} nodes each") + logger.info(f"Total nodes to create: {TOTAL_NODES}") + logger.info(f"Total relationships to create: {TOTAL_RELATIONSHIPS}") + start_time = time.time() + + # Drop existing tables in correct order (relationships first, then nodes) + logger.info("[1/5] Dropping existing tables...") + await adapter.query("DROP TABLE IF EXISTS CONNECTS_TO") + await adapter.query("DROP TABLE IF EXISTS TestNode") + + # Create node table + logger.info("[2/5] Creating node table structure...") + await adapter.query(""" + CREATE NODE TABLE TestNode ( + id STRING, + name STRING, + value INT64, + PRIMARY KEY (id) + ) + """) + + # Create relationship table + logger.info("[3/5] Creating relationship table structure...") + await adapter.query(""" + CREATE REL TABLE CONNECTS_TO ( + FROM TestNode TO TestNode, + weight DOUBLE + ) + """) + + # Clear existing test data + logger.info("[4/5] Clearing existing test data...") + await adapter.query("MATCH (n:TestNode) DETACH DELETE n") + + # Create new test data + logger.info( + f"[5/5] Creating test data ({NUM_BATCHES} concurrent batches of {BATCH_SIZE} nodes each)..." + ) + total_batch_time = await create_test_data(adapter) + + end_time = time.time() + total_duration = end_time - start_time + + # Verify the data + logger.info("Verifying data...") + result = await adapter.query("MATCH (n:TestNode) RETURN COUNT(n) as count") + logger.info(f"Total nodes created: {result}") + + result = await adapter.query("MATCH ()-[r:CONNECTS_TO]->() RETURN COUNT(r) as count") + logger.info(f"Total relationships created: {result}") + + logger.info("=== Test Summary ===") + logger.info(f"Total batch processing time: {total_batch_time:.2f} seconds") + logger.info(f"Total execution time: {total_duration:.2f} seconds") + + finally: + await adapter.close() + + +if __name__ == "__main__": + asyncio.run(main())