feat: Kuzu integration (#628)
<!-- .github/pull_request_template.md --> ## Description Let's scope it out. ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced support for the Kuzu graph database provider, enhancing graph operations and data management capabilities. - Added a comprehensive adapter for Kuzu, facilitating various graph database operations. - Expanded the enumeration of graph database types to include Kuzu. - **Tests** - Launched comprehensive asynchronous tests to validate the new Kuzu graph integration’s performance and reliability. - **Chores** - Updated dependency settings and continuous integration workflows to include the Kuzu provider, ensuring smoother deployments and improved system quality. - Enhanced configuration documentation to clarify Kuzu database requirements. - Modified Dockerfile to include Kuzu in the installation extras. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com>
This commit is contained in:
parent
e147fa5bde
commit
69950a04dd
11 changed files with 1646 additions and 282 deletions
|
|
@ -23,9 +23,9 @@ EMBEDDING_API_VERSION=""
|
|||
EMBEDDING_DIMENSIONS=3072
|
||||
EMBEDDING_MAX_TOKENS=8191
|
||||
|
||||
# "neo4j" or "networkx"
|
||||
# "neo4j", "networkx" or "kuzu"
|
||||
GRAPH_DATABASE_PROVIDER="networkx"
|
||||
# Not needed if using networkx
|
||||
# Only needed if using neo4j
|
||||
GRAPH_DATABASE_URL=
|
||||
GRAPH_DATABASE_USERNAME=
|
||||
GRAPH_DATABASE_PASSWORD=
|
||||
|
|
|
|||
54
.github/workflows/test_kuzu.yml
vendored
Normal file
54
.github/workflows/test_kuzu.yml
vendored
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
name: test | kuzu
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
types: [labeled, synchronize]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
RUNTIME__LOG_LEVEL: ERROR
|
||||
|
||||
jobs:
|
||||
run_kuzu_integration_test:
|
||||
name: test
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- name: Check out
|
||||
uses: actions/checkout@master
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
|
||||
- name: Install Poetry
|
||||
uses: snok/install-poetry@v1.4.1
|
||||
with:
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
installer-parallel: true
|
||||
|
||||
- name: Install dependencies
|
||||
run: poetry install -E kuzu --no-interaction
|
||||
|
||||
- name: Run Kuzu tests
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: poetry run python ./cognee/tests/test_kuzu.py
|
||||
|
|
@ -3,7 +3,7 @@ FROM python:3.11-slim
|
|||
# Define Poetry extras to install
|
||||
ARG POETRY_EXTRAS="\
|
||||
# Storage & Databases \
|
||||
filesystem postgres weaviate qdrant neo4j falkordb milvus \
|
||||
filesystem postgres weaviate qdrant neo4j falkordb milvus kuzu \
|
||||
# Notebooks & Interactive Environments \
|
||||
notebook \
|
||||
# LLM & AI Frameworks \
|
||||
|
|
|
|||
|
|
@ -59,6 +59,14 @@ def create_graph_engine(
|
|||
embedding_engine=embedding_engine,
|
||||
)
|
||||
|
||||
elif graph_database_provider == "kuzu":
|
||||
if not graph_file_path:
|
||||
raise EnvironmentError("Missing required Kuzu database path.")
|
||||
|
||||
from .kuzu.adapter import KuzuAdapter
|
||||
|
||||
return KuzuAdapter(db_path=graph_file_path)
|
||||
|
||||
from .networkx.adapter import NetworkXAdapter
|
||||
|
||||
graph_client = NetworkXAdapter(filename=graph_file_path)
|
||||
|
|
|
|||
0
cognee/infrastructure/databases/graph/kuzu/__init__.py
Normal file
0
cognee/infrastructure/databases/graph/kuzu/__init__.py
Normal file
925
cognee/infrastructure/databases/graph/kuzu/adapter.py
Normal file
925
cognee/infrastructure/databases/graph/kuzu/adapter.py
Normal file
|
|
@ -0,0 +1,925 @@
|
|||
"""Adapter for Kuzu graph database."""
|
||||
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Union, Optional, Tuple
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID
|
||||
from contextlib import asynccontextmanager
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import kuzu
|
||||
from kuzu.database import Database
|
||||
from kuzu import Connection
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
import aiofiles
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KuzuAdapter(GraphDBInterface):
|
||||
"""Adapter for Kuzu graph database operations with improved consistency and async support."""
|
||||
|
||||
def __init__(self, db_path: str):
|
||||
"""Initialize Kuzu database connection and schema."""
|
||||
self.db_path = db_path # Path for the database directory
|
||||
self.db: Optional[Database] = None
|
||||
self.connection: Optional[Connection] = None
|
||||
self.executor = ThreadPoolExecutor()
|
||||
self._initialize_connection()
|
||||
|
||||
def _initialize_connection(self) -> None:
|
||||
"""Initialize the Kuzu database connection and schema."""
|
||||
try:
|
||||
os.makedirs(self.db_path, exist_ok=True)
|
||||
self.db = Database(self.db_path)
|
||||
self.db.init_database()
|
||||
self.connection = Connection(self.db)
|
||||
# Create node table with essential fields and timestamp
|
||||
self.connection.execute("""
|
||||
CREATE NODE TABLE IF NOT EXISTS Node(
|
||||
id STRING PRIMARY KEY,
|
||||
text STRING,
|
||||
type STRING,
|
||||
created_at TIMESTAMP,
|
||||
updated_at TIMESTAMP,
|
||||
properties STRING
|
||||
)
|
||||
""")
|
||||
# Create relationship table with timestamp
|
||||
self.connection.execute("""
|
||||
CREATE REL TABLE IF NOT EXISTS EDGE(
|
||||
FROM Node TO Node,
|
||||
relationship_name STRING,
|
||||
created_at TIMESTAMP,
|
||||
updated_at TIMESTAMP,
|
||||
properties STRING
|
||||
)
|
||||
""")
|
||||
logger.debug("Kuzu database initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Kuzu database: {e}")
|
||||
raise
|
||||
|
||||
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
|
||||
"""Execute a Kuzu query asynchronously with automatic reconnection."""
|
||||
loop = asyncio.get_running_loop()
|
||||
params = params or {}
|
||||
|
||||
def blocking_query():
|
||||
try:
|
||||
if not self.connection:
|
||||
logger.debug("Reconnecting to Kuzu database...")
|
||||
self._initialize_connection()
|
||||
|
||||
result = self.connection.execute(query, params)
|
||||
rows = []
|
||||
|
||||
while result.has_next():
|
||||
row = result.get_next()
|
||||
processed_rows = []
|
||||
for val in row:
|
||||
if hasattr(val, "as_py"):
|
||||
val = val.as_py()
|
||||
processed_rows.append(val)
|
||||
rows.append(tuple(processed_rows))
|
||||
return rows
|
||||
except Exception as e:
|
||||
logger.error(f"Query execution failed: {str(e)}")
|
||||
raise
|
||||
|
||||
return await loop.run_in_executor(self.executor, blocking_query)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_session(self):
|
||||
"""Get a database session.
|
||||
|
||||
Kuzu doesn't have session management like Neo4j, so this provides API compatibility.
|
||||
"""
|
||||
try:
|
||||
yield self.connection
|
||||
finally:
|
||||
pass
|
||||
|
||||
def _parse_node(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Convert a raw node result (with JSON properties) into a dictionary."""
|
||||
if data.get("properties"):
|
||||
try:
|
||||
props = json.loads(data["properties"])
|
||||
# Remove the JSON field and merge its contents
|
||||
data.pop("properties")
|
||||
data.update(props)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse properties JSON for node {data.get('id')}")
|
||||
return data
|
||||
|
||||
def _parse_node_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
try:
|
||||
if isinstance(data, dict) and "properties" in data and data["properties"]:
|
||||
props = json.loads(data["properties"])
|
||||
data.update(props)
|
||||
del data["properties"]
|
||||
return data
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse properties JSON for node {data.get('id')}")
|
||||
return data
|
||||
|
||||
# Helper method for building edge queries
|
||||
|
||||
def _edge_query_and_params(
|
||||
self, from_node: str, to_node: str, relationship_name: str, properties: Dict[str, Any]
|
||||
) -> Tuple[str, dict]:
|
||||
"""Build the edge creation query and parameters."""
|
||||
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||
query = """
|
||||
MATCH (from:Node), (to:Node)
|
||||
WHERE from.id = $from_id AND to.id = $to_id
|
||||
CREATE (from)-[r:EDGE {
|
||||
relationship_name: $relationship_name,
|
||||
created_at: timestamp($created_at),
|
||||
updated_at: timestamp($updated_at),
|
||||
properties: $properties
|
||||
}]->(to)
|
||||
"""
|
||||
params = {
|
||||
"from_id": from_node,
|
||||
"to_id": to_node,
|
||||
"relationship_name": relationship_name,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"properties": json.dumps(properties, cls=JSONEncoder),
|
||||
}
|
||||
return query, params
|
||||
|
||||
# Node Operations
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
"""Check if a node exists."""
|
||||
query_str = "MATCH (n:Node) WHERE n.id = $id RETURN COUNT(n) > 0"
|
||||
result = await self.query(query_str, {"id": node_id})
|
||||
return result[0][0] if result else False
|
||||
|
||||
async def add_node(self, node: DataPoint) -> None:
|
||||
"""Add a single node to the graph if it doesn't exist."""
|
||||
try:
|
||||
properties = node.model_dump() if hasattr(node, "model_dump") else vars(node)
|
||||
|
||||
# Extract core fields with defaults if not present
|
||||
core_properties = {
|
||||
"id": str(properties.get("id", "")),
|
||||
"text": str(properties.get("text", "")),
|
||||
"type": str(properties.get("type", "")),
|
||||
}
|
||||
|
||||
# Remove core fields from other properties
|
||||
for key in core_properties:
|
||||
properties.pop(key, None)
|
||||
|
||||
core_properties["properties"] = json.dumps(properties, cls=JSONEncoder)
|
||||
|
||||
# Check if node exists
|
||||
exists = await self.has_node(core_properties["id"])
|
||||
|
||||
if not exists:
|
||||
# Add timestamps for new node
|
||||
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||
fields = []
|
||||
params = {}
|
||||
for key, value in core_properties.items():
|
||||
if value is not None:
|
||||
param_name = f"param_{key}"
|
||||
fields.append(f"{key}: ${param_name}")
|
||||
params[param_name] = value
|
||||
|
||||
# Add timestamp fields
|
||||
fields.extend(
|
||||
["created_at: timestamp($created_at)", "updated_at: timestamp($updated_at)"]
|
||||
)
|
||||
params.update({"created_at": now, "updated_at": now})
|
||||
|
||||
create_query = f"""
|
||||
CREATE (n:Node {{{", ".join(fields)}}})
|
||||
"""
|
||||
await self.query(create_query, params)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add node: {e}")
|
||||
raise
|
||||
|
||||
async def add_nodes(self, nodes: List[DataPoint]) -> None:
|
||||
"""Add multiple nodes to the graph in a batch operation."""
|
||||
if not nodes:
|
||||
return
|
||||
|
||||
try:
|
||||
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||
|
||||
# Prepare all nodes data first
|
||||
node_params = []
|
||||
for node in nodes:
|
||||
properties = node.model_dump() if hasattr(node, "model_dump") else vars(node)
|
||||
|
||||
# Extract core fields
|
||||
core_properties = {
|
||||
"id": str(properties.get("id", "")),
|
||||
"text": str(properties.get("text", "")),
|
||||
"type": str(properties.get("type", "")),
|
||||
}
|
||||
|
||||
# Remove core fields from other properties
|
||||
for key in core_properties:
|
||||
properties.pop(key, None)
|
||||
|
||||
node_params.append(
|
||||
{
|
||||
**core_properties,
|
||||
"properties": json.dumps(properties, cls=JSONEncoder),
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
)
|
||||
|
||||
if node_params:
|
||||
# First check which nodes don't exist yet
|
||||
check_query = """
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Node)
|
||||
WHERE n.id = node.id
|
||||
RETURN n.id
|
||||
"""
|
||||
existing_nodes = await self.query(check_query, {"nodes": node_params})
|
||||
existing_ids = {str(row[0]) for row in existing_nodes}
|
||||
|
||||
# Filter out existing nodes
|
||||
new_nodes = [node for node in node_params if node["id"] not in existing_ids]
|
||||
|
||||
if new_nodes:
|
||||
# Batch create new nodes
|
||||
create_query = """
|
||||
UNWIND $nodes AS node
|
||||
CREATE (n:Node {
|
||||
id: node.id,
|
||||
text: node.text,
|
||||
type: node.type,
|
||||
properties: node.properties,
|
||||
created_at: timestamp(node.created_at),
|
||||
updated_at: timestamp(node.updated_at)
|
||||
})
|
||||
"""
|
||||
await self.query(create_query, {"nodes": new_nodes})
|
||||
logger.debug(f"Added {len(new_nodes)} new nodes in batch")
|
||||
else:
|
||||
logger.debug("No new nodes to add - all nodes already exist")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add nodes in batch: {e}")
|
||||
raise
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
"""Delete a node and its relationships."""
|
||||
query_str = "MATCH (n:Node) WHERE n.id = $id DETACH DELETE n"
|
||||
await self.query(query_str, {"id": node_id})
|
||||
|
||||
async def delete_nodes(self, node_ids: List[str]) -> None:
|
||||
"""Delete multiple nodes at once."""
|
||||
query_str = "MATCH (n:Node) WHERE n.id IN $ids DETACH DELETE n"
|
||||
await self.query(query_str, {"ids": node_ids})
|
||||
|
||||
async def extract_node(self, node_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Extract a node by its ID."""
|
||||
query_str = """
|
||||
MATCH (n:Node)
|
||||
WHERE n.id = $id
|
||||
RETURN {
|
||||
id: n.id,
|
||||
text: n.text,
|
||||
type: n.type,
|
||||
properties: n.properties
|
||||
}
|
||||
"""
|
||||
try:
|
||||
result = await self.query(query_str, {"id": node_id})
|
||||
if result and result[0]:
|
||||
node_data = self._parse_node(result[0][0])
|
||||
return node_data
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to extract node {node_id}: {e}")
|
||||
return None
|
||||
|
||||
async def extract_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
|
||||
"""Extract multiple nodes by their IDs."""
|
||||
query_str = """
|
||||
MATCH (n:Node)
|
||||
WHERE n.id IN $node_ids
|
||||
RETURN {
|
||||
id: n.id,
|
||||
text: n.text,
|
||||
type: n.type,
|
||||
properties: n.properties
|
||||
}
|
||||
"""
|
||||
try:
|
||||
results = await self.query(query_str, {"node_ids": node_ids})
|
||||
# Parse each node using the same helper function
|
||||
nodes = [self._parse_node(row[0]) for row in results if row[0]]
|
||||
return nodes
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to extract nodes: {e}")
|
||||
return []
|
||||
|
||||
# Edge Operations
|
||||
|
||||
async def has_edge(self, from_node: str, to_node: str, edge_label: str) -> bool:
|
||||
"""Check if an edge exists between nodes with the given relationship name."""
|
||||
query_str = """
|
||||
MATCH (from:Node)-[r:EDGE]->(to:Node)
|
||||
WHERE from.id = $from_id AND to.id = $to_id AND r.relationship_name = $edge_label
|
||||
RETURN COUNT(r) > 0
|
||||
"""
|
||||
result = await self.query(
|
||||
query_str, {"from_id": from_node, "to_id": to_node, "edge_label": edge_label}
|
||||
)
|
||||
return result[0][0] if result else False
|
||||
|
||||
async def has_edges(self, edges: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
|
||||
"""Check if multiple edges exist in a batch operation."""
|
||||
if not edges:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Transform edges into format needed for batch query
|
||||
edge_params = [
|
||||
{
|
||||
"from_id": str(from_node), # Ensure string type
|
||||
"to_id": str(to_node), # Ensure string type
|
||||
"relationship_name": str(edge_label), # Ensure string type
|
||||
}
|
||||
for from_node, to_node, edge_label in edges
|
||||
]
|
||||
|
||||
# Batch check query with direct string comparison
|
||||
query = """
|
||||
UNWIND $edges AS edge
|
||||
MATCH (from:Node)-[r:EDGE]->(to:Node)
|
||||
WHERE from.id = edge.from_id
|
||||
AND to.id = edge.to_id
|
||||
AND r.relationship_name = edge.relationship_name
|
||||
RETURN from.id, to.id, r.relationship_name
|
||||
"""
|
||||
|
||||
results = await self.query(query, {"edges": edge_params})
|
||||
|
||||
# Convert results back to tuples and ensure string types
|
||||
existing_edges = [(str(row[0]), str(row[1]), str(row[2])) for row in results]
|
||||
|
||||
logger.debug(f"Found {len(existing_edges)} existing edges out of {len(edges)} checked")
|
||||
return existing_edges
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check edges in batch: {e}")
|
||||
return []
|
||||
|
||||
async def add_edge(
|
||||
self,
|
||||
from_node: str,
|
||||
to_node: str,
|
||||
relationship_name: str,
|
||||
edge_properties: Dict[str, Any] = {},
|
||||
) -> None:
|
||||
"""Add an edge between two nodes."""
|
||||
try:
|
||||
query, params = self._edge_query_and_params(
|
||||
from_node, to_node, relationship_name, edge_properties
|
||||
)
|
||||
await self.query(query, params)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add edge: {e}")
|
||||
raise
|
||||
|
||||
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
|
||||
"""Add multiple edges in a batch operation."""
|
||||
if not edges:
|
||||
return
|
||||
|
||||
try:
|
||||
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||
|
||||
# Transform edges into the format needed for batch insertion
|
||||
edge_params = [
|
||||
{
|
||||
"from_id": from_node,
|
||||
"to_id": to_node,
|
||||
"relationship_name": relationship_name,
|
||||
"properties": json.dumps(properties, cls=JSONEncoder),
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
for from_node, to_node, relationship_name, properties in edges
|
||||
]
|
||||
|
||||
# Batch create query
|
||||
query = """
|
||||
UNWIND $edges AS edge
|
||||
MATCH (from:Node), (to:Node)
|
||||
WHERE from.id = edge.from_id AND to.id = edge.to_id
|
||||
CREATE (from)-[r:EDGE {
|
||||
relationship_name: edge.relationship_name,
|
||||
created_at: timestamp(edge.created_at),
|
||||
updated_at: timestamp(edge.updated_at),
|
||||
properties: edge.properties
|
||||
}]->(to)
|
||||
"""
|
||||
|
||||
await self.query(query, {"edges": edge_params})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add edges in batch: {e}")
|
||||
raise
|
||||
|
||||
async def get_edges(self, node_id: str) -> List[Tuple[Dict[str, Any], str, Dict[str, Any]]]:
|
||||
"""Get all edges connected to a node.
|
||||
|
||||
Returns:
|
||||
List of tuples containing (source_node, relationship_name, target_node)
|
||||
where source_node and target_node are dictionaries with node properties,
|
||||
and relationship_name is a string.
|
||||
"""
|
||||
query_str = """
|
||||
MATCH (n:Node)-[r]-(m:Node)
|
||||
WHERE n.id = $node_id
|
||||
RETURN {
|
||||
id: n.id,
|
||||
text: n.text,
|
||||
type: n.type,
|
||||
properties: n.properties
|
||||
},
|
||||
r.relationship_name,
|
||||
{
|
||||
id: m.id,
|
||||
text: m.text,
|
||||
type: m.type,
|
||||
properties: m.properties
|
||||
}
|
||||
"""
|
||||
try:
|
||||
results = await self.query(query_str, {"node_id": node_id})
|
||||
edges = []
|
||||
for row in results:
|
||||
if row and len(row) == 3:
|
||||
source_node = self._parse_node_properties(row[0])
|
||||
target_node = self._parse_node_properties(row[2])
|
||||
edges.append((source_node, row[1], target_node))
|
||||
return edges
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get edges for node {node_id}: {e}")
|
||||
return []
|
||||
|
||||
# Neighbor Operations
|
||||
|
||||
async def get_neighbours(self, node_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get all neighbouring nodes."""
|
||||
query_str = """
|
||||
MATCH (n)-[r]-(m)
|
||||
WHERE n.id = $id
|
||||
RETURN DISTINCT properties(m)
|
||||
"""
|
||||
try:
|
||||
result = await self.query(query_str, {"id": node_id})
|
||||
return [row[0] for row in result] if result else []
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get neighbours for node {node_id}: {e}")
|
||||
return []
|
||||
|
||||
async def get_predecessors(
|
||||
self, node_id: Union[str, UUID], edge_label: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get all predecessor nodes."""
|
||||
try:
|
||||
if edge_label:
|
||||
query_str = """
|
||||
MATCH (n)<-[r:EDGE]-(m)
|
||||
WHERE n.id = $id AND r.relationship_name = $edge_label
|
||||
RETURN properties(m)
|
||||
"""
|
||||
params = {"id": str(node_id), "edge_label": edge_label}
|
||||
else:
|
||||
query_str = """
|
||||
MATCH (n)<-[r:EDGE]-(m)
|
||||
WHERE n.id = $id
|
||||
RETURN properties(m)
|
||||
"""
|
||||
params = {"id": str(node_id)}
|
||||
result = await self.query(query_str, params)
|
||||
return [row[0] for row in result] if result else []
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get predecessors for node {node_id}: {e}")
|
||||
return []
|
||||
|
||||
async def get_successors(
|
||||
self, node_id: Union[str, UUID], edge_label: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get all successor nodes."""
|
||||
try:
|
||||
if edge_label:
|
||||
query_str = """
|
||||
MATCH (n)-[r:EDGE]->(m)
|
||||
WHERE n.id = $id AND r.relationship_name = $edge_label
|
||||
RETURN properties(m)
|
||||
"""
|
||||
params = {"id": str(node_id), "edge_label": edge_label}
|
||||
else:
|
||||
query_str = """
|
||||
MATCH (n)-[r:EDGE]->(m)
|
||||
WHERE n.id = $id
|
||||
RETURN properties(m)
|
||||
"""
|
||||
params = {"id": str(node_id)}
|
||||
result = await self.query(query_str, params)
|
||||
return [row[0] for row in result] if result else []
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get successors for node {node_id}: {e}")
|
||||
return []
|
||||
|
||||
async def get_connections(
|
||||
self, node_id: str
|
||||
) -> List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]:
|
||||
"""Get all nodes connected to a given node."""
|
||||
query_str = """
|
||||
MATCH (n:Node)-[r:EDGE]-(m:Node)
|
||||
WHERE n.id = $node_id
|
||||
RETURN {
|
||||
id: n.id,
|
||||
text: n.text,
|
||||
type: n.type,
|
||||
properties: n.properties
|
||||
},
|
||||
{
|
||||
relationship_name: r.relationship_name,
|
||||
properties: r.properties
|
||||
},
|
||||
{
|
||||
id: m.id,
|
||||
text: m.text,
|
||||
type: m.type,
|
||||
properties: m.properties
|
||||
}
|
||||
"""
|
||||
try:
|
||||
results = await self.query(query_str, {"node_id": node_id})
|
||||
edges = []
|
||||
for row in results:
|
||||
if row and len(row) == 3:
|
||||
processed_rows = []
|
||||
for i, item in enumerate(row):
|
||||
if isinstance(item, dict):
|
||||
if "properties" in item and item["properties"]:
|
||||
try:
|
||||
props = json.loads(item["properties"])
|
||||
item.update(props)
|
||||
del item["properties"]
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse JSON properties for node/edge {i}"
|
||||
)
|
||||
processed_rows.append(item)
|
||||
edges.append(tuple(processed_rows))
|
||||
return edges if edges else [] # Always return a list, even if empty
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get connections for node {node_id}: {e}")
|
||||
return [] # Return empty list on error
|
||||
|
||||
async def remove_connection_to_predecessors_of(
|
||||
self, node_ids: List[str], edge_label: str
|
||||
) -> None:
|
||||
"""Remove all incoming edges of specified type for given nodes."""
|
||||
query_str = """
|
||||
MATCH (n)<-[r:EDGE]-(m)
|
||||
WHERE n.id IN $node_ids AND r.relationship_name = $edge_label
|
||||
DELETE r
|
||||
"""
|
||||
await self.query(query_str, {"node_ids": node_ids, "edge_label": edge_label})
|
||||
|
||||
async def remove_connection_to_successors_of(
|
||||
self, node_ids: List[str], edge_label: str
|
||||
) -> None:
|
||||
"""Remove all outgoing edges of specified type for given nodes."""
|
||||
query_str = """
|
||||
MATCH (n)-[r:EDGE]->(m)
|
||||
WHERE n.id IN $node_ids AND r.relationship_name = $edge_label
|
||||
DELETE r
|
||||
"""
|
||||
await self.query(query_str, {"node_ids": node_ids, "edge_label": edge_label})
|
||||
|
||||
# Graph-wide Operations
|
||||
|
||||
async def get_graph_data(
|
||||
self,
|
||||
) -> Tuple[List[Tuple[str, Dict[str, Any]]], List[Tuple[str, str, str, Dict[str, Any]]]]:
|
||||
"""Get all nodes and edges in the graph."""
|
||||
try:
|
||||
nodes_query = """
|
||||
MATCH (n:Node)
|
||||
RETURN n.id, {
|
||||
text: n.text,
|
||||
type: n.type,
|
||||
properties: n.properties
|
||||
}
|
||||
"""
|
||||
nodes = await self.query(nodes_query)
|
||||
formatted_nodes = []
|
||||
for n in nodes:
|
||||
if n[0]:
|
||||
node_id = str(n[0])
|
||||
props = n[1]
|
||||
if props.get("properties"):
|
||||
try:
|
||||
additional_props = json.loads(props["properties"])
|
||||
props.update(additional_props)
|
||||
del props["properties"]
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse properties JSON for node {node_id}")
|
||||
formatted_nodes.append((node_id, props))
|
||||
if not formatted_nodes:
|
||||
logger.warning("No nodes found in the database")
|
||||
return [], []
|
||||
|
||||
edges_query = """
|
||||
MATCH (n:Node)-[r:EDGE]->(m:Node)
|
||||
RETURN n.id, m.id, r.relationship_name, r.properties
|
||||
"""
|
||||
edges = await self.query(edges_query)
|
||||
formatted_edges = []
|
||||
for e in edges:
|
||||
if e and len(e) >= 3:
|
||||
source_id = str(e[0])
|
||||
target_id = str(e[1])
|
||||
rel_type = str(e[2])
|
||||
props = {}
|
||||
if len(e) > 3 and e[3]:
|
||||
try:
|
||||
props = json.loads(e[3])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning(
|
||||
f"Failed to parse edge properties for {source_id}->{target_id}"
|
||||
)
|
||||
formatted_edges.append((source_id, target_id, rel_type, props))
|
||||
|
||||
if formatted_nodes and not formatted_edges:
|
||||
logger.debug("No edges found, creating self-referential edges for nodes")
|
||||
for node_id, _ in formatted_nodes:
|
||||
formatted_edges.append(
|
||||
(
|
||||
node_id,
|
||||
node_id,
|
||||
"SELF",
|
||||
{
|
||||
"relationship_name": "SELF",
|
||||
"relationship_type": "SELF",
|
||||
"vector_distance": 0.0,
|
||||
},
|
||||
)
|
||||
)
|
||||
return formatted_nodes, formatted_edges
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get graph data: {e}")
|
||||
raise
|
||||
|
||||
async def get_filtered_graph_data(
|
||||
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
|
||||
):
|
||||
"""Get filtered nodes and relationships based on attributes."""
|
||||
|
||||
where_clauses = []
|
||||
params = {}
|
||||
|
||||
for i, filter_dict in enumerate(attribute_filters):
|
||||
for attr, values in filter_dict.items():
|
||||
param_name = f"values_{i}_{attr}"
|
||||
where_clauses.append(f"n.{attr} IN ${param_name}")
|
||||
params[param_name] = values
|
||||
|
||||
where_clause = " AND ".join(where_clauses)
|
||||
nodes_query = f"MATCH (n:Node) WHERE {where_clause} RETURN properties(n)"
|
||||
edges_query = f"""
|
||||
MATCH (n1:Node)-[r:EDGE]->(n2:Node)
|
||||
WHERE {where_clause.replace("n.", "n1.")} AND {where_clause.replace("n.", "n2.")}
|
||||
RETURN properties(r)
|
||||
"""
|
||||
nodes, edges = await asyncio.gather(
|
||||
self.query(nodes_query, params), self.query(edges_query, params)
|
||||
)
|
||||
return ([n[0] for n in nodes], [e[0] for e in edges])
|
||||
|
||||
async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
|
||||
try:
|
||||
# Basic metrics
|
||||
node_count = await self.query("MATCH (n:Node) RETURN COUNT(n)")
|
||||
edge_count = await self.query("MATCH ()-[r:EDGE]->() RETURN COUNT(r)")
|
||||
num_nodes = node_count[0][0] if node_count else 0
|
||||
num_edges = edge_count[0][0] if edge_count else 0
|
||||
|
||||
# Calculate mandatory metrics
|
||||
mandatory_metrics = {
|
||||
"num_nodes": num_nodes,
|
||||
"num_edges": num_edges,
|
||||
"mean_degree": (2 * num_edges) / num_nodes if num_nodes > 0 else 0,
|
||||
"edge_density": (num_edges) / (num_nodes * (num_nodes - 1)) if num_nodes > 1 else 0,
|
||||
}
|
||||
|
||||
# Calculate connected components
|
||||
components_query = """
|
||||
MATCH (n:Node)
|
||||
WITH n.id AS node_id
|
||||
MATCH path = (n)-[:EDGE*0..]-()
|
||||
WITH COLLECT(DISTINCT node_id) AS component
|
||||
RETURN COLLECT(component) AS components
|
||||
"""
|
||||
components_result = await self.query(components_query)
|
||||
component_sizes = (
|
||||
[len(comp) for comp in components_result[0][0]] if components_result else []
|
||||
)
|
||||
|
||||
mandatory_metrics.update(
|
||||
{
|
||||
"num_connected_components": len(component_sizes),
|
||||
"sizes_of_connected_components": component_sizes,
|
||||
}
|
||||
)
|
||||
|
||||
if include_optional:
|
||||
# Self-loops
|
||||
self_loops_query = """
|
||||
MATCH (n:Node)-[r:EDGE]->(n)
|
||||
RETURN COUNT(r)
|
||||
"""
|
||||
self_loops = await self.query(self_loops_query)
|
||||
num_selfloops = self_loops[0][0] if self_loops else 0
|
||||
|
||||
# Shortest paths (simplified for Kuzu)
|
||||
paths_query = """
|
||||
MATCH (n:Node), (m:Node)
|
||||
WHERE n.id < m.id
|
||||
MATCH path = (n)-[:EDGE*]-(m)
|
||||
RETURN MIN(LENGTH(path)) AS length
|
||||
"""
|
||||
paths = await self.query(paths_query)
|
||||
path_lengths = [p[0] for p in paths if p[0] is not None]
|
||||
|
||||
# Local clustering coefficient
|
||||
clustering_query = """
|
||||
MATCH (n:Node)-[:EDGE]-(neighbor)
|
||||
WITH n, COUNT(DISTINCT neighbor) as degree
|
||||
MATCH (n)-[:EDGE]-(n1)-[:EDGE]-(n2)-[:EDGE]-(n)
|
||||
WHERE n1 <> n2
|
||||
RETURN AVG(CASE WHEN degree <= 1 THEN 0 ELSE COUNT(DISTINCT n2) / (degree * (degree-1)) END)
|
||||
"""
|
||||
clustering = await self.query(clustering_query)
|
||||
|
||||
optional_metrics = {
|
||||
"num_selfloops": num_selfloops,
|
||||
"diameter": max(path_lengths) if path_lengths else -1,
|
||||
"avg_shortest_path_length": sum(path_lengths) / len(path_lengths)
|
||||
if path_lengths
|
||||
else -1,
|
||||
"avg_clustering": clustering[0][0] if clustering and clustering[0][0] else -1,
|
||||
}
|
||||
else:
|
||||
optional_metrics = {
|
||||
"num_selfloops": -1,
|
||||
"diameter": -1,
|
||||
"avg_shortest_path_length": -1,
|
||||
"avg_clustering": -1,
|
||||
}
|
||||
|
||||
return {**mandatory_metrics, **optional_metrics}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get graph metrics: {e}")
|
||||
return {
|
||||
"num_nodes": 0,
|
||||
"num_edges": 0,
|
||||
"mean_degree": 0,
|
||||
"edge_density": 0,
|
||||
"num_connected_components": 0,
|
||||
"sizes_of_connected_components": [],
|
||||
"num_selfloops": -1,
|
||||
"diameter": -1,
|
||||
"avg_shortest_path_length": -1,
|
||||
"avg_clustering": -1,
|
||||
}
|
||||
|
||||
async def get_disconnected_nodes(self) -> List[str]:
|
||||
"""Get nodes that are not connected to any other node."""
|
||||
query_str = """
|
||||
MATCH (n:Node)
|
||||
WHERE NOT EXISTS((n)-[]-())
|
||||
RETURN n.id
|
||||
"""
|
||||
result = await self.query(query_str)
|
||||
return [str(row[0]) for row in result]
|
||||
|
||||
# Graph Meta-Data Operations
|
||||
|
||||
async def get_model_independent_graph_data(self) -> Dict[str, List[str]]:
|
||||
"""Get graph data independent of any specific data model."""
|
||||
node_labels = await self.query("MATCH (n:Node) RETURN DISTINCT labels(n)")
|
||||
rel_types = await self.query("MATCH ()-[r:EDGE]->() RETURN DISTINCT r.relationship_name")
|
||||
return {
|
||||
"node_labels": [label[0] for label in node_labels],
|
||||
"relationship_types": [rel[0] for rel in rel_types],
|
||||
}
|
||||
|
||||
async def get_node_labels_string(self) -> str:
|
||||
"""Get all node labels as a string."""
|
||||
labels = await self.query("MATCH (n:Node) RETURN DISTINCT labels(n)")
|
||||
return "|".join(sorted(set([label[0] for label in labels])))
|
||||
|
||||
async def get_relationship_labels_string(self) -> str:
|
||||
"""Get all relationship types as a string."""
|
||||
types = await self.query("MATCH ()-[r:EDGE]->() RETURN DISTINCT r.relationship_name")
|
||||
return "|".join(sorted(set([t[0] for t in types])))
|
||||
|
||||
async def delete_graph(self) -> None:
|
||||
"""Delete all data from the graph while preserving the database structure."""
|
||||
try:
|
||||
# Delete relationships from the fixed table EDGE
|
||||
await self.query("MATCH ()-[r:EDGE]->() DELETE r")
|
||||
# Then delete nodes
|
||||
await self.query("MATCH (n:Node) DELETE n")
|
||||
logger.info("Cleared all data from graph while preserving structure")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete graph data: {e}")
|
||||
raise
|
||||
|
||||
async def clear_database(self) -> None:
|
||||
"""Clear all data from the database by deleting the database files and reinitializing."""
|
||||
try:
|
||||
if self.connection:
|
||||
self.connection = None
|
||||
if self.db:
|
||||
self.db.close()
|
||||
self.db = None
|
||||
if os.path.exists(self.db_path):
|
||||
shutil.rmtree(self.db_path)
|
||||
logger.info(f"Deleted Kuzu database files at {self.db_path}")
|
||||
# Reinitialize the database
|
||||
self._initialize_connection()
|
||||
# Verify the database is empty
|
||||
result = self.connection.execute("MATCH (n:Node) RETURN COUNT(n)")
|
||||
count = result.get_next()[0] if result.has_next() else 0
|
||||
if count > 0:
|
||||
logger.warning(
|
||||
f"Database still contains {count} nodes after clearing, forcing deletion"
|
||||
)
|
||||
self.connection.execute("MATCH (n:Node) DETACH DELETE n")
|
||||
logger.info("Database cleared successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during database clearing: {e}")
|
||||
raise
|
||||
|
||||
async def save_graph_to_file(self, file_path: str) -> None:
|
||||
"""Export the Kuzu database to a file.
|
||||
|
||||
Args:
|
||||
file_path: Path where to export the database
|
||||
"""
|
||||
try:
|
||||
# Ensure directory exists
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
# Use Kuzu's native EXPORT command, output is Parquet
|
||||
export_query = f"EXPORT DATABASE '{file_path}'"
|
||||
await self.query(export_query)
|
||||
|
||||
logger.info(f"Graph exported to {file_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to export graph to file: {e}")
|
||||
raise
|
||||
|
||||
async def load_graph_from_file(self, file_path: str) -> None:
|
||||
"""Import a Kuzu database from a file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the exported database file
|
||||
"""
|
||||
try:
|
||||
if not os.path.exists(file_path):
|
||||
logger.warning(f"File {file_path} not found")
|
||||
return
|
||||
|
||||
# Use Kuzu's native IMPORT command
|
||||
import_query = f"IMPORT DATABASE '{file_path}'"
|
||||
await self.query(import_query)
|
||||
|
||||
logger.info(f"Graph imported from {file_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to import graph from file: {e}")
|
||||
raise
|
||||
|
|
@ -293,6 +293,7 @@ class GraphDBType(Enum):
|
|||
NETWORKX = auto()
|
||||
NEO4J = auto()
|
||||
FALKORDB = auto()
|
||||
KUZU = auto()
|
||||
|
||||
|
||||
# Models for representing different entities
|
||||
|
|
|
|||
107
cognee/tests/test_kuzu.py
Normal file
107
cognee/tests/test_kuzu.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
import os
|
||||
import logging
|
||||
import pathlib
|
||||
import cognee
|
||||
import shutil
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from uuid import uuid4
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
|
||||
async def main():
|
||||
# Clean up test directories before starting
|
||||
data_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_kuzu")
|
||||
).resolve()
|
||||
)
|
||||
cognee_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_kuzu")
|
||||
).resolve()
|
||||
)
|
||||
|
||||
try:
|
||||
# Set Kuzu as the graph database provider
|
||||
cognee.config.set_graph_database_provider("kuzu")
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
cognee.config.system_root_directory(cognee_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
dataset_name = "cs_explanations"
|
||||
|
||||
explanation_file_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt"
|
||||
)
|
||||
await cognee.add([explanation_file_path], dataset_name)
|
||||
|
||||
text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena.
|
||||
At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the preparation and manipulation of quantum states.
|
||||
Classical physics cannot explain the operation of these quantum devices, and a scalable quantum computer could perform some calculations exponentially faster (with respect to input size scaling) than any modern "classical" computer. In particular, a large-scale quantum computer could break widely used encryption schemes and aid physicists in performing physical simulations; however, the current state of the technology is largely experimental and impractical, with several obstacles to useful applications. Moreover, scalable quantum computers do not hold promise for many practical tasks, and for many important tasks quantum speedups are proven impossible.
|
||||
The basic unit of information in quantum computing is the qubit, similar to the bit in traditional digital electronics. Unlike a classical bit, a qubit can exist in a superposition of its two "basis" states. When measuring a qubit, the result is a probabilistic output of a classical bit, therefore making quantum computers nondeterministic in general. If a quantum computer manipulates the qubit in a particular way, wave interference effects can amplify the desired measurement results. The design of quantum algorithms involves creating procedures that allow a quantum computer to perform calculations efficiently and quickly.
|
||||
Physically engineering high-quality qubits has proven challenging. If a physical qubit is not sufficiently isolated from its environment, it suffers from quantum decoherence, introducing noise into calculations. Paradoxically, perfectly isolating qubits is also undesirable because quantum computations typically need to initialize qubits, perform controlled qubit interactions, and measure the resulting quantum states. Each of those operations introduces errors and suffers from noise, and such inaccuracies accumulate.
|
||||
In principle, a non-quantum (classical) computer can solve the same computational problems as a quantum computer, given enough time. Quantum advantage comes in the form of time complexity rather than computability, and quantum complexity theory shows that some quantum algorithms for carefully selected tasks require exponentially fewer computational steps than the best known non-quantum algorithms. Such tasks can in theory be solved on a large-scale quantum computer whereas classical computers would not finish computations in any reasonable amount of time. However, quantum speedup is not universal or even typical across computational tasks, since basic tasks such as sorting are proven to not allow any asymptotic quantum speedup. Claims of quantum supremacy have drawn significant attention to the discipline, but are demonstrated on contrived tasks, while near-term practical use cases remain limited.
|
||||
"""
|
||||
await cognee.add([text], dataset_name)
|
||||
|
||||
await cognee.cognify([dataset_name])
|
||||
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.INSIGHTS, query_text=random_node_name
|
||||
)
|
||||
assert len(search_results) != 0, "The search results list is empty."
|
||||
print("\n\nExtracted sentences are:\n")
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.CHUNKS, query_text=random_node_name
|
||||
)
|
||||
assert len(search_results) != 0, "The search results list is empty."
|
||||
print("\n\nExtracted chunks are:\n")
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.SUMMARIES, query_text=random_node_name
|
||||
)
|
||||
assert len(search_results) != 0, "Query related summaries don't exist."
|
||||
print("\nExtracted summaries are:\n")
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
|
||||
history = await cognee.get_search_history()
|
||||
assert len(history) == 6, "Search history is not correct."
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) == 0 and len(edges) == 0, "Kuzu graph database is not empty"
|
||||
|
||||
finally:
|
||||
# Ensure cleanup even if tests fail
|
||||
for path in [data_directory_path, cognee_directory_path]:
|
||||
if os.path.exists(path):
|
||||
shutil.rmtree(path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
|
|
@ -3,7 +3,7 @@ FROM python:3.11-slim
|
|||
# Define Poetry extras to install
|
||||
ARG POETRY_EXTRAS="\
|
||||
# Storage & Databases \
|
||||
filesystem postgres weaviate qdrant neo4j falkordb milvus \
|
||||
filesystem postgres weaviate qdrant neo4j falkordb milvus kuzu \
|
||||
# Notebooks & Interactive Environments \
|
||||
notebook \
|
||||
# LLM & AI Frameworks \
|
||||
|
|
|
|||
823
poetry.lock
generated
823
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -31,6 +31,7 @@ nest_asyncio = "1.6.0"
|
|||
numpy = "1.26.4"
|
||||
datasets = "3.1.0"
|
||||
falkordb = {version = "1.0.9", optional = true}
|
||||
kuzu = {version = "0.8.2", optional = true}
|
||||
boto3 = "^1.26.125"
|
||||
botocore="^1.35.54"
|
||||
gunicorn = "^20.1.0"
|
||||
|
|
@ -107,6 +108,7 @@ mistral = ["mistral-common"]
|
|||
deepeval = ["deepeval"]
|
||||
posthog = ["posthog"]
|
||||
falkordb = ["falkordb"]
|
||||
kuzu = ["kuzu"]
|
||||
groq = ["groq"]
|
||||
milvus = ["pymilvus"]
|
||||
docs = ["unstructured"]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue