refactor: Create new abstraction for dataset database mapping and handling

This commit is contained in:
Igor Ilic 2025-11-24 20:31:28 +01:00
parent 0800810713
commit 64a3ee96c4
16 changed files with 300 additions and 188 deletions

View file

@ -27,6 +27,7 @@ async def set_session_user_context_variable(user):
def multi_user_support_possible():
graph_db_config = get_graph_context_config()
vector_db_config = get_vectordb_context_config()
# TODO: Make sure dataset database handler and provider match, remove multi_user support check, add error if no dataset database handler exists for provider
return (
graph_db_config["graph_database_provider"] in GRAPH_DBS_WITH_MULTI_USER_SUPPORT
and vector_db_config["vector_db_provider"] in VECTOR_DBS_WITH_MULTI_USER_SUPPORT

View file

@ -0,0 +1,3 @@
from .dataset_database_handler_interface import DatasetDatabaseHandlerInterface
from .supported_dataset_database_handlers import supported_dataset_database_handlers
from .use_dataset_database_handler import use_dataset_database_handler

View file

@ -0,0 +1,43 @@
from typing import Optional
from uuid import UUID
from abc import ABC, abstractmethod
from cognee.modules.users.models.User import User
class DatasetDatabaseHandlerInterface(ABC):
    """
    Abstract contract for handlers that give a dataset its own graph or vector database.

    Both operations are abstract classmethods: a concrete handler has to override both
    of them, and they are called on the handler class itself.
    """

    @classmethod
    @abstractmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        """
        Provide the connection info of the graph or vector database backing a dataset.

        A handler may deploy the actual database as part of this call, but it does not
        have to: returning the connection info is sufficient. The returned dictionary is
        used to create a DatasetDatabase row in the relational database, and that row is
        the internal dataset -> database connection mapping used when the dataset is
        connected to in the future.

        Cognee multi-tenant/multi-user and backend access control support relies on this:
        with backend access control enabled, each dataset needs to map to a unique graph
        or vector database so that data stays separated.

        Args:
            dataset_id: UUID of the dataset, if the database creation logic needs it.
            user: User object, if the database creation logic needs it.

        Returns:
            dict: Connection info for the created graph or vector database instance.
        """
        ...

    @classmethod
    @abstractmethod
    async def delete_dataset(cls, dataset_id: UUID, user: User) -> None:
        """
        Remove the graph or vector database that belongs to a dataset.

        A handler should either delete the actual database itself or ask the responsible
        service to delete it / mark it as no longer needed for the dataset. This keeps the
        set of databases maintained for Cognee multi-tenant/multi-user and backend access
        control.

        Args:
            dataset_id: UUID of the dataset.
            user: User object.
        """
        ...

View file

@ -0,0 +1,15 @@
from cognee.infrastructure.databases.graph.neo4j_driver.Neo4jAuraDatasetDatabaseHandler import (
Neo4jAuraDatasetDatabaseHandler,
)
from cognee.infrastructure.databases.vector.lancedb.LanceDBDatasetDatabaseHandler import (
LanceDBDatasetDatabaseHandler,
)
from cognee.infrastructure.databases.graph.kuzu.KuzuDatasetDatabaseHandler import (
KuzuDatasetDatabaseHandler,
)
# Registry mapping a handler name to the dataset database handler class registered for it.
# NOTE(review): the keys look like the values expected by the
# `graph_dataset_database_handler` / `vector_dataset_database_handler` settings — confirm
# against the config modules before renaming a key.
supported_dataset_database_handlers = {
    "neo4j_aura": Neo4jAuraDatasetDatabaseHandler,
    "lancedb": LanceDBDatasetDatabaseHandler,
    "kuzu": KuzuDatasetDatabaseHandler,
}

View file

@ -0,0 +1,5 @@
from .supported_dataset_database_handlers import supported_dataset_database_handlers
def use_dataset_database_handler(dataset_database_handler_name, dataset_database_handler):
    """
    Register a dataset database handler under the given name.

    The handler is stored in `supported_dataset_database_handlers`; an entry that already
    exists under the same name is replaced.

    Args:
        dataset_database_handler_name: Key under which the handler is stored.
        dataset_database_handler: Handler to store under that key.
    """
    supported_dataset_database_handlers.update(
        {dataset_database_handler_name: dataset_database_handler}
    )

View file

@ -47,6 +47,7 @@ class GraphConfig(BaseSettings):
graph_filename: str = ""
graph_model: object = KnowledgeGraph
graph_topology: object = KnowledgeGraph
graph_dataset_database_handler: str = "kuzu"
model_config = SettingsConfigDict(env_file=".env", extra="allow", populate_by_name=True)
# Model validator updates graph_filename and path dynamically after class creation based on current database provider
@ -97,6 +98,7 @@ class GraphConfig(BaseSettings):
"graph_model": self.graph_model,
"graph_topology": self.graph_topology,
"model_config": self.model_config,
"graph_dataset_database_handler": self.graph_dataset_database_handler,
}
def to_hashable_dict(self) -> dict:
@ -121,6 +123,7 @@ class GraphConfig(BaseSettings):
"graph_database_port": self.graph_database_port,
"graph_database_key": self.graph_database_key,
"graph_file_path": self.graph_file_path,
"graph_dataset_database_handler": self.graph_dataset_database_handler,
}

View file

@ -34,6 +34,7 @@ def create_graph_engine(
graph_database_password="",
graph_database_port="",
graph_database_key="",
graph_dataset_database_handler="",
):
"""
Create a graph engine based on the specified provider type.

View file

@ -6,7 +6,6 @@ from typing import Optional, Dict, Any, List, Tuple, Type, Union
from uuid import NAMESPACE_OID, UUID, uuid5
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint
from cognee.modules.users.models.User import User
from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger
from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine
@ -399,36 +398,3 @@ class GraphDBInterface(ABC):
- node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections.
"""
raise NotImplementedError
@classmethod
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
    """
    Provide the connection info of the graph database backing a dataset.

    An implementation may deploy the actual database as part of this call, but it does
    not have to: returning the connection info is sufficient. The returned dictionary is
    used to create a DatasetDatabase row in the relational database, and that row is the
    internal dataset -> database connection mapping used when the dataset is connected to
    in the future.

    Cognee multi-tenant/multi-user and backend access control support relies on this:
    with backend access control enabled, each dataset needs to map to a unique graph
    database so that data stays separated.

    This default implementation does nothing and returns None.

    Args:
        dataset_id: UUID of the dataset, if the database creation logic needs it.
        user: User object, if the database creation logic needs it.

    Returns:
        dict: Connection info for the created graph database instance.
    """
    ...
async def delete_dataset(self, dataset_id: UUID, user: User) -> None:
    """
    Remove the graph database that belongs to a dataset.

    An implementation should either delete the actual database itself or send a request
    to the proper service to delete it. This keeps the set of databases maintained for
    Cognee multi-tenant/multi-user and backend access control.

    This default implementation does nothing.

    Args:
        dataset_id: UUID of the dataset.
        user: User object.
    """
    ...

View file

@ -0,0 +1,57 @@
import os
import asyncio
import requests
from uuid import UUID
from typing import Optional
from cognee.modules.users.models import User
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
class KuzuDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """
    Dataset database handler for the Kuzu graph database provider.
    """

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        """
        Build the connection info that maps a dataset to its own Kuzu database.

        Only the database name is derived from the dataset id; every other value is read
        from the current graph configuration. `user` is not used by this handler.

        Args:
            dataset_id: Dataset UUID
            user: User object who owns the dataset and is making the request

        Returns:
            dict: Connection details for the Kuzu instance of the dataset

        Raises:
            ValueError: If the configured graph database provider is not "kuzu".
        """
        # Imported at call time rather than at module level
        # (presumably to avoid a circular import — confirm before moving it).
        from cognee.infrastructure.databases.graph.config import get_graph_config

        config = get_graph_config()
        provider = config.graph_database_provider

        if provider != "kuzu":
            raise ValueError(
                "KuzuDatasetDatabaseHandler can only be used with Kuzu graph database provider."
            )

        # TODO: Add graph file path info for kuzu (also in DatasetDatabase model)
        return {
            "graph_database_name": f"{dataset_id}.pkl",
            "graph_database_url": config.graph_database_url,
            "graph_database_provider": provider,
            "graph_database_key": config.graph_database_key,
            "graph_database_username": config.graph_database_username,
            "graph_database_password": config.graph_database_password,
        }

    @classmethod
    async def delete_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]):
        """
        Placeholder for removing the Kuzu database of a dataset; currently does nothing.

        Args:
            dataset_id: Dataset UUID
            user: User object who owns the dataset and is making the request
        """
        return None

View file

@ -0,0 +1,118 @@
import os
import asyncio
import requests
from uuid import UUID
from typing import Optional
from cognee.infrastructure.databases.graph import get_graph_config
from cognee.modules.users.models import User
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
class Neo4jAuraDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """
    Handler for interacting with Neo4j Aura Dataset databases.
    """

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        """
        Create a new Neo4j Aura instance for the dataset. Return connection info that will be mapped to the dataset.

        The Aura API is reached through the blocking `requests` library, so every call is
        executed in the event loop's default executor; the loop stays free to run other
        tasks while the API is being waited on. `user` is not used by this handler.

        Args:
            dataset_id: Dataset UUID
            user: User object who owns the dataset and is making the request

        Returns:
            dict: Connection details for the created Neo4j instance

        Raises:
            ValueError: If the configured graph database provider is not "neo4j" or the
                Aura client credentials are missing from the environment.
            requests.RequestException: If an Aura API call fails, times out, or answers
                with an error status.
            TimeoutError: If the instance is not running after all polling attempts.
        """
        graph_config = get_graph_config()

        if graph_config.graph_database_provider != "neo4j":
            raise ValueError(
                "Neo4jAuraDatasetDatabaseHandler can only be used with Neo4j graph database provider."
            )

        # Client credentials
        client_id = os.environ.get("NEO4J_CLIENT_ID")
        client_secret = os.environ.get("NEO4J_CLIENT_SECRET")
        tenant_id = os.environ.get("NEO4J_TENANT_ID")

        if client_id is None or client_secret is None or tenant_id is None:
            raise ValueError(
                "NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling."
            )

        request_timeout_seconds = 60  # Without a timeout a stalled connection would hang forever
        poll_attempts = 30  # Together with the interval below: try for up to ~5 minutes
        poll_interval_seconds = 10

        # TODO: Find better name to name Neo4j instance within 30 character limit
        instance_name = f"{dataset_id}"[0:29]

        loop = asyncio.get_running_loop()

        async def _call_aura_api(send, url: str, **kwargs) -> dict:
            """Run one blocking `requests` call off the event loop and return its JSON body."""

            def _blocking_call() -> dict:
                response = send(url, timeout=request_timeout_seconds, **kwargs)
                response.raise_for_status()  # Surface API errors instead of a later KeyError
                return response.json()

            return await loop.run_in_executor(None, _blocking_call)

        # Request the access token with HTTP Basic Auth
        token = await _call_aura_api(
            requests.post,
            "https://api.neo4j.io/oauth/token",
            data={"grant_type": "client_credentials"},  # sent as application/x-www-form-urlencoded
            auth=(client_id, client_secret),
        )

        headers = {
            "accept": "application/json",
            "Authorization": f"Bearer {token['access_token']}",
            "Content-Type": "application/json",
        }

        # TODO: Maybe we can allow **kwargs parameter forwarding for cases like these
        #  To allow different configurations between datasets
        payload = {
            "version": "5",
            "region": "europe-west1",
            "memory": "1GB",
            "name": instance_name,
            "type": "professional-db",
            "tenant_id": tenant_id,
            "cloud_provider": "gcp",
        }

        created = await _call_aura_api(
            requests.post, "https://api.neo4j.io/v1/instances", headers=headers, json=payload
        )
        instance = created["data"]

        graph_db_name = "neo4j"  # Has to be 'neo4j' for Aura
        graph_db_url = instance["connection_url"]
        graph_db_key = token["access_token"]
        graph_db_username = instance["username"]
        graph_db_password = instance["password"]

        # Poll until the instance is running
        status_url = f"https://api.neo4j.io/v1/instances/{instance['id']}"
        status = ""
        for _ in range(poll_attempts):
            status_body = await _call_aura_api(requests.get, status_url, headers=headers)
            status = status_body["data"]["status"]
            if status.lower() == "running":
                break
            await asyncio.sleep(poll_interval_seconds)
        else:
            # Report the Aura instance name, not the fixed 'neo4j' database name
            raise TimeoutError(
                f"Neo4j instance '{instance_name}' did not become ready within 5 minutes. Status: {status}"
            )

        return {
            "graph_database_name": graph_db_name,
            "graph_database_url": graph_db_url,
            "graph_database_provider": "neo4j",
            "graph_database_key": graph_db_key,  # TODO: Hashing of keys/passwords in relational DB
            "graph_database_username": graph_db_username,
            "graph_database_password": graph_db_password,
        }

    @classmethod
    async def delete_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]):
        """
        Placeholder for removing the Neo4j Aura instance of a dataset; currently does nothing.

        NOTE(review): an instance created by `create_dataset` is therefore not removed
        through this handler yet.

        Args:
            dataset_id: Dataset UUID
            user: User object who owns the dataset and is making the request
        """
        return None

View file

@ -1,9 +1,7 @@
"""Neo4j Adapter for Graph Database"""
import os
import json
import asyncio
import requests
from uuid import UUID
from textwrap import dedent
from neo4j import AsyncSession
@ -14,7 +12,6 @@ from typing import Optional, Any, List, Dict, Type, Tuple
from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
from cognee.modules.users.models import User
from cognee.tasks.temporal_graph.models import Timestamp
from cognee.shared.logging_utils import get_logger, ERROR
from cognee.infrastructure.databases.graph.graph_db_interface import (
@ -1473,89 +1470,3 @@ class Neo4jAdapter(GraphDBInterface):
time_ids_list = [item["id"] for item in time_nodes if "id" in item]
return ", ".join(f"'{uid}'" for uid in time_ids_list)
@classmethod
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
    """
    Create a new Neo4j Aura instance for the dataset. Return connection info that will be mapped to the dataset.

    The Aura API is reached through the blocking `requests` library, so every call is
    executed in the event loop's default executor; the loop stays free to run other
    tasks while the API is being waited on. `user` is not used here.

    Args:
        dataset_id: Dataset UUID
        user: User object who owns the dataset and is making the request

    Returns:
        dict: Connection details for the created Neo4j instance

    Raises:
        ValueError: If the Aura client credentials are missing from the environment.
        requests.RequestException: If an Aura API call fails, times out, or answers
            with an error status.
        TimeoutError: If the instance is not running after all polling attempts.
    """
    # Client credentials
    client_id = os.environ.get("NEO4J_CLIENT_ID")
    client_secret = os.environ.get("NEO4J_CLIENT_SECRET")
    tenant_id = os.environ.get("NEO4J_TENANT_ID")

    # Fail early with a clear message instead of sending a request without credentials
    if client_id is None or client_secret is None or tenant_id is None:
        raise ValueError(
            "NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to create a Neo4j Aura instance for a dataset."
        )

    request_timeout_seconds = 60  # Without a timeout a stalled connection would hang forever
    poll_attempts = 30  # Together with the interval below: try for up to ~5 minutes
    poll_interval_seconds = 10

    # TODO: Find better name to name Neo4j instance within 30 character limit
    instance_name = f"{dataset_id}"[0:29]

    loop = asyncio.get_running_loop()

    async def _call_aura_api(send, url: str, **kwargs) -> dict:
        """Run one blocking `requests` call off the event loop and return its JSON body."""

        def _blocking_call() -> dict:
            response = send(url, timeout=request_timeout_seconds, **kwargs)
            response.raise_for_status()  # Surface API errors instead of a later KeyError
            return response.json()

        return await loop.run_in_executor(None, _blocking_call)

    # Request the access token with HTTP Basic Auth
    token = await _call_aura_api(
        requests.post,
        "https://api.neo4j.io/oauth/token",
        data={"grant_type": "client_credentials"},  # sent as application/x-www-form-urlencoded
        auth=(client_id, client_secret),
    )

    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {token['access_token']}",
        "Content-Type": "application/json",
    }

    # TODO: Maybe we can allow **kwargs parameter forwarding for cases like these
    #  To allow different configurations between datasets
    payload = {
        "version": "5",
        "region": "europe-west1",
        "memory": "1GB",
        "name": instance_name,
        "type": "professional-db",
        "tenant_id": tenant_id,
        "cloud_provider": "gcp",
    }

    created = await _call_aura_api(
        requests.post, "https://api.neo4j.io/v1/instances", headers=headers, json=payload
    )
    instance = created["data"]

    graph_db_name = "neo4j"
    graph_db_url = instance["connection_url"]
    graph_db_key = token["access_token"]
    graph_db_username = instance["username"]
    graph_db_password = instance["password"]

    # Poll until the instance is running
    status_url = f"https://api.neo4j.io/v1/instances/{instance['id']}"
    status = ""
    for _ in range(poll_attempts):
        status_body = await _call_aura_api(requests.get, status_url, headers=headers)
        status = status_body["data"]["status"]
        if status.lower() == "running":
            break
        await asyncio.sleep(poll_interval_seconds)
    else:
        # Report the Aura instance name, not the fixed 'neo4j' database name
        raise TimeoutError(
            f"Neo4j instance '{instance_name}' did not become ready within 5 minutes. Status: {status}"
        )

    return {
        "graph_database_name": graph_db_name,
        "graph_database_url": graph_db_url,
        "graph_database_provider": "neo4j",
        "graph_database_key": graph_db_key,  # TODO: Hashing of keys/passwords in relational DB
        "graph_database_username": graph_db_username,
        "graph_database_password": graph_db_password,
    }

View file

@ -20,61 +20,23 @@ from cognee.modules.users.models import User
async def _get_vector_db_info(dataset_id: UUID, user: User) -> dict:
    """Return the connection info dictionary for the vector database of a dataset."""
    # NOTE(review): this span comes from a diff hunk (-20,61 +20,23) shown without +/-
    # markers, so it appears to interleave the pre-refactor provider branching with the
    # new handler lookup. The indentation below is the only one under which these lines
    # parse; with it, both branches return before the handler lookup at the bottom, which
    # is then unreachable. Confirm the intended final body against the committed file.
    vector_config = get_vectordb_config()
    # Determine vector configuration
    if vector_config.vector_db_provider == "lancedb":
        # TODO: Have the create_database method be called from interface adapter automatically for all providers instead of specifically here
        from cognee.infrastructure.databases.vector.lancedb.LanceDBAdapter import LanceDBAdapter
        from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
            supported_dataset_database_handlers,
        )
        return await LanceDBAdapter.create_dataset(dataset_id, user)
    else:
        # Note: for hybrid databases both graph and vector DB name have to be the same
        vector_db_name = vector_config.vector_db_name
        vector_db_url = vector_config.vector_database_url
        return {
            "vector_database_name": vector_db_name,
            "vector_database_url": vector_db_url,
            "vector_database_provider": vector_config.vector_db_provider,
            "vector_database_key": vector_config.vector_db_key,
        }
    handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler]
    return await handler.create_dataset(dataset_id, user)
async def _get_graph_db_info(dataset_id: UUID, user: User) -> dict:
    """Return the connection info dictionary for the graph database of a dataset."""
    # NOTE(review): this span comes from a diff hunk (-20,61 +20,23) shown without +/-
    # markers, so it appears to interleave the pre-refactor provider branching with the
    # new handler lookup. The indentation below is the only one under which these lines
    # parse; with it, the import after the first `return` and the handler lookup at the
    # bottom are unreachable. Confirm the intended final body against the committed file.
    graph_config = get_graph_config()
    # Determine graph database URL
    if graph_config.graph_database_provider == "neo4j":
        from cognee.infrastructure.databases.graph.neo4j_driver.adapter import Neo4jAdapter
        return await Neo4jAdapter.create_dataset(dataset_id, user)
        from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
            supported_dataset_database_handlers,
        )
    elif graph_config.graph_database_provider == "kuzu":
        # TODO: Add graph file path info for kuzu (also in DatasetDatabase model)
        graph_db_name = f"{dataset_id}.pkl"
        graph_db_url = graph_config.graph_database_url
        graph_db_key = graph_config.graph_database_key
        graph_db_username = graph_config.graph_database_username
        graph_db_password = graph_config.graph_database_password
    elif graph_config.graph_database_provider == "falkor":
        # Note: for hybrid databases both graph and vector DB name have to be the same
        graph_db_name = f"{dataset_id}"
        graph_db_url = graph_config.graph_database_url
        graph_db_key = graph_config.graph_database_key
        graph_db_username = graph_config.graph_database_username
        graph_db_password = graph_config.graph_database_password
    else:
        raise EnvironmentError(
            f"Unsupported graph database provider for backend access control: {graph_config.graph_database_provider}"
        )
    return {
        "graph_database_name": graph_db_name,
        "graph_database_url": graph_db_url,
        "graph_database_provider": graph_config.graph_database_provider,
        "graph_database_key": graph_db_key,  # TODO: Hashing of keys/passwords in relational DB
        "graph_database_username": graph_db_username,
        "graph_database_password": graph_db_password,
    }
    handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler]
    return await handler.create_dataset(dataset_id, user)
async def _existing_dataset_database(

View file

@ -28,6 +28,7 @@ class VectorConfig(BaseSettings):
vector_db_name: str = ""
vector_db_key: str = ""
vector_db_provider: str = "lancedb"
vector_dataset_database_handler: str = "lancedb"
model_config = SettingsConfigDict(env_file=".env", extra="allow")
@ -63,6 +64,7 @@ class VectorConfig(BaseSettings):
"vector_db_name": self.vector_db_name,
"vector_db_key": self.vector_db_key,
"vector_db_provider": self.vector_db_provider,
"vector_dataset_database_handler": self.vector_dataset_database_handler,
}

View file

@ -12,6 +12,7 @@ def create_vector_engine(
vector_db_name: str,
vector_db_port: str = "",
vector_db_key: str = "",
vector_dataset_database_handler: str = "",
):
"""
Create a vector database engine based on the specified provider.

View file

@ -362,20 +362,3 @@ class LanceDBAdapter(VectorDBInterface):
},
exclude_fields=["metadata"] + related_models_fields,
)
@classmethod
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
    """
    Build the connection info that maps a dataset to its own LanceDB database.

    The database is placed under `<system_root_directory>/databases/<user id>/` and is
    named after the dataset id; provider and key are read from the vector configuration.

    Args:
        dataset_id: Dataset UUID, used to name the database.
        user: User object whose id selects the per-user databases directory.

    Returns:
        dict: Connection details (name, url, provider, key) for the LanceDB database.

    Raises:
        ValueError: If no user is given.
    """
    # The database path contains the user id, so a user is required even though the
    # signature declares it optional; fail clearly instead of with an AttributeError.
    if user is None:
        raise ValueError("A user is required to create a LanceDB database for a dataset.")

    vector_config = get_vectordb_config()
    base_config = get_base_config()

    databases_directory_path = os.path.join(
        base_config.system_root_directory, "databases", str(user.id)
    )
    vector_db_name = f"{dataset_id}.lance.db"

    return {
        "vector_database_name": vector_db_name,
        "vector_database_url": os.path.join(databases_directory_path, vector_db_name),
        "vector_database_provider": vector_config.vector_db_provider,
        "vector_database_key": vector_config.vector_db_key,
    }

View file

@ -0,0 +1,41 @@
import os
from uuid import UUID
from typing import Optional
from cognee.modules.users.models import User
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
class LanceDBDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """
    Handler for interacting with LanceDB Dataset databases.
    """

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        """
        Build the connection info that maps a dataset to its own LanceDB database.

        The database is placed under `<system_root_directory>/databases/<user id>/` and is
        named after the dataset id; provider and key are read from the vector configuration.

        Args:
            dataset_id: Dataset UUID, used to name the database.
            user: User object whose id selects the per-user databases directory.

        Returns:
            dict: Connection details (name, url, provider, key) for the LanceDB database.

        Raises:
            ValueError: If the configured vector database provider is not "lancedb" or
                no user is given.
        """
        vector_config = get_vectordb_config()
        base_config = get_base_config()

        if vector_config.vector_db_provider != "lancedb":
            raise ValueError(
                "LanceDBDatasetDatabaseHandler can only be used with LanceDB vector database provider."
            )

        # The database path contains the user id, so a user is required even though the
        # signature declares it optional; fail clearly instead of with an AttributeError.
        if user is None:
            raise ValueError("A user is required to create a LanceDB database for a dataset.")

        databases_directory_path = os.path.join(
            base_config.system_root_directory, "databases", str(user.id)
        )
        vector_db_name = f"{dataset_id}.lance.db"

        return {
            "vector_database_name": vector_db_name,
            "vector_database_url": os.path.join(databases_directory_path, vector_db_name),
            "vector_database_provider": vector_config.vector_db_provider,
            "vector_database_key": vector_config.vector_db_key,
        }

    @classmethod
    async def delete_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]):
        """
        Placeholder for removing the LanceDB database of a dataset; currently does nothing.

        Args:
            dataset_id: Dataset UUID
            user: User object who owns the dataset and is making the request
        """
        return None