diff --git a/cognee/context_global_variables.py b/cognee/context_global_variables.py index 44ead95af..2b6ffa058 100644 --- a/cognee/context_global_variables.py +++ b/cognee/context_global_variables.py @@ -27,6 +27,7 @@ async def set_session_user_context_variable(user): def multi_user_support_possible(): graph_db_config = get_graph_context_config() vector_db_config = get_vectordb_context_config() + # TODO: Make sure dataset database handler and provider match, remove multi_user support check, add error if no dataset database handler exists for provider return ( graph_db_config["graph_database_provider"] in GRAPH_DBS_WITH_MULTI_USER_SUPPORT and vector_db_config["vector_db_provider"] in VECTOR_DBS_WITH_MULTI_USER_SUPPORT diff --git a/cognee/infrastructure/databases/dataset_database_handler/__init__.py b/cognee/infrastructure/databases/dataset_database_handler/__init__.py new file mode 100644 index 000000000..a74017113 --- /dev/null +++ b/cognee/infrastructure/databases/dataset_database_handler/__init__.py @@ -0,0 +1,3 @@ +from .dataset_database_handler_interface import DatasetDatabaseHandlerInterface +from .supported_dataset_database_handlers import supported_dataset_database_handlers +from .use_dataset_database_handler import use_dataset_database_handler diff --git a/cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py b/cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py new file mode 100644 index 000000000..6dadee6cf --- /dev/null +++ b/cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py @@ -0,0 +1,43 @@ +from typing import Optional +from uuid import UUID +from abc import ABC, abstractmethod + +from cognee.modules.users.models.User import User + + +class DatasetDatabaseHandlerInterface(ABC): + @classmethod + @abstractmethod + async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict: + """ + Return a 
dictionary with connection info for a graph or vector database for the given dataset. + Function can auto handle deploying of the actual database if needed, but is not necessary. + Only providing connection info is sufficient, this info will be mapped when trying to connect to the provided dataset in the future. + Needed for Cognee multi-tenant/multi-user and backend access control support. + + Dictionary returned from this function will be used to create a DatasetDatabase row in the relational database. + From which internal mapping of dataset -> database connection info will be done. + + Each dataset needs to map to a unique graph or vector database when backend access control is enabled to facilitate a separation of concern for data. + + Args: + dataset_id: UUID of the dataset if needed by the database creation logic + user: User object if needed by the database creation logic + Returns: + dict: Connection info for the created graph or vector database instance. + """ + pass + + @classmethod + @abstractmethod + async def delete_dataset(cls, dataset_id: UUID, user: User) -> None: + """ + Delete the graph or vector database for the given dataset. + Function should auto handle deleting of the actual database or send a request to the proper service to delete/mark the database as not needed for the given dataset. + Needed for maintaining a database for Cognee multi-tenant/multi-user and backend access control. 
+ + Args: + dataset_id: UUID of the dataset + user: User object + """ + pass diff --git a/cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py b/cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py new file mode 100644 index 000000000..9cc7d9f93 --- /dev/null +++ b/cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py @@ -0,0 +1,15 @@ +from cognee.infrastructure.databases.graph.neo4j_driver.Neo4jAuraDatasetDatabaseHandler import ( + Neo4jAuraDatasetDatabaseHandler, +) +from cognee.infrastructure.databases.vector.lancedb.LanceDBDatasetDatabaseHandler import ( + LanceDBDatasetDatabaseHandler, +) +from cognee.infrastructure.databases.graph.kuzu.KuzuDatasetDatabaseHandler import ( + KuzuDatasetDatabaseHandler, +) + +supported_dataset_database_handlers = { + "neo4j_aura": Neo4jAuraDatasetDatabaseHandler, + "lancedb": LanceDBDatasetDatabaseHandler, + "kuzu": KuzuDatasetDatabaseHandler, +} diff --git a/cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py b/cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py new file mode 100644 index 000000000..a583de354 --- /dev/null +++ b/cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py @@ -0,0 +1,5 @@ +from .supported_dataset_database_handlers import supported_dataset_database_handlers + + +def use_dataset_database_handler(dataset_database_handler_name, dataset_database_handler): + supported_dataset_database_handlers[dataset_database_handler_name] = dataset_database_handler diff --git a/cognee/infrastructure/databases/graph/config.py b/cognee/infrastructure/databases/graph/config.py index 23687b359..bcf97ebfa 100644 --- a/cognee/infrastructure/databases/graph/config.py +++ b/cognee/infrastructure/databases/graph/config.py @@ -47,6 +47,7 @@ class GraphConfig(BaseSettings): graph_filename: str = "" 
graph_model: object = KnowledgeGraph graph_topology: object = KnowledgeGraph + graph_dataset_database_handler: str = "kuzu" model_config = SettingsConfigDict(env_file=".env", extra="allow", populate_by_name=True) # Model validator updates graph_filename and path dynamically after class creation based on current database provider @@ -97,6 +98,7 @@ class GraphConfig(BaseSettings): "graph_model": self.graph_model, "graph_topology": self.graph_topology, "model_config": self.model_config, + "graph_dataset_database_handler": self.graph_dataset_database_handler, } def to_hashable_dict(self) -> dict: @@ -121,6 +123,7 @@ class GraphConfig(BaseSettings): "graph_database_port": self.graph_database_port, "graph_database_key": self.graph_database_key, "graph_file_path": self.graph_file_path, + "graph_dataset_database_handler": self.graph_dataset_database_handler, } diff --git a/cognee/infrastructure/databases/graph/get_graph_engine.py b/cognee/infrastructure/databases/graph/get_graph_engine.py index 82e3cad6e..c37af2102 100644 --- a/cognee/infrastructure/databases/graph/get_graph_engine.py +++ b/cognee/infrastructure/databases/graph/get_graph_engine.py @@ -34,6 +34,7 @@ def create_graph_engine( graph_database_password="", graph_database_port="", graph_database_key="", + graph_dataset_database_handler="", ): """ Create a graph engine based on the specified provider type. 
diff --git a/cognee/infrastructure/databases/graph/graph_db_interface.py b/cognee/infrastructure/databases/graph/graph_db_interface.py index 6d323764b..67df1a27c 100644 --- a/cognee/infrastructure/databases/graph/graph_db_interface.py +++ b/cognee/infrastructure/databases/graph/graph_db_interface.py @@ -6,7 +6,6 @@ from typing import Optional, Dict, Any, List, Tuple, Type, Union from uuid import NAMESPACE_OID, UUID, uuid5 from cognee.shared.logging_utils import get_logger from cognee.infrastructure.engine import DataPoint -from cognee.modules.users.models.User import User from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine @@ -399,36 +398,3 @@ class GraphDBInterface(ABC): - node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections. """ raise NotImplementedError - - @classmethod - async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict: - """ - Return a dictionary with connection info for a graph database for the given dataset. - Function can auto handle deploying of the actual database if needed, but is not necessary. - Only providing connection info is sufficient, this info will be mapped when trying to connect to the provided dataset in the future. - Needed for Cognee multi-tenant/multi-user and backend access control support. - - Dictionary returned from this function will be used to create a DatasetDatabase row in the relational database. - From which internal mapping of dataset -> database connection info will be done. - - Each dataset needs to map to a unique graph database when backend access control is enabled to facilitate a separation of concern for data. 
- - Args: - dataset_id: UUID of the dataset if needed by the database creation logic - user: User object if needed by the database creation logic - Returns: - dict: Connection info for the created graph database instance. - """ - pass - - async def delete_dataset(self, dataset_id: UUID, user: User) -> None: - """ - Delete the graph database for the given dataset. - Function should auto handle deleting of the actual database or send a request to the proper service to delete the database. - Needed for maintaining a database for Cognee multi-tenant/multi-user and backend access control. - - Args: - dataset_id: UUID of the dataset - user: User object - """ - pass diff --git a/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py b/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py new file mode 100644 index 000000000..8859422f9 --- /dev/null +++ b/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py @@ -0,0 +1,57 @@ +import os +import asyncio +import requests +from uuid import UUID +from typing import Optional + +from cognee.modules.users.models import User + +from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface + + +class KuzuDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): + """ + Handler for interacting with Kuzu Dataset databases. + """ + + @classmethod + async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict: + """ + Create a new Kuzu instance for the dataset. Return connection info that will be mapped to the dataset. 
+ + Args: + dataset_id: Dataset UUID + user: User object who owns the dataset and is making the request + + Returns: + dict: Connection details for the created Kuzu instance + + """ + from cognee.infrastructure.databases.graph.config import get_graph_config + + graph_config = get_graph_config() + + if graph_config.graph_database_provider != "kuzu": + raise ValueError( + "KuzuDatasetDatabaseHandler can only be used with Kuzu graph database provider." + ) + + # TODO: Add graph file path info for kuzu (also in DatasetDatabase model) + graph_db_name = f"{dataset_id}.pkl" + graph_db_url = graph_config.graph_database_url + graph_db_key = graph_config.graph_database_key + graph_db_username = graph_config.graph_database_username + graph_db_password = graph_config.graph_database_password + + return { + "graph_database_name": graph_db_name, + "graph_database_url": graph_db_url, + "graph_database_provider": graph_config.graph_database_provider, + "graph_database_key": graph_db_key, + "graph_database_username": graph_db_username, + "graph_database_password": graph_db_password, + } + + @classmethod + async def delete_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]): + pass diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDatasetDatabaseHandler.py b/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDatasetDatabaseHandler.py new file mode 100644 index 000000000..cc38abed0 --- /dev/null +++ b/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDatasetDatabaseHandler.py @@ -0,0 +1,118 @@ +import os +import asyncio +import requests +from uuid import UUID +from typing import Optional + +from cognee.infrastructure.databases.graph import get_graph_config +from cognee.modules.users.models import User + +from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface + + +class Neo4jAuraDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): + """ + Handler for interacting with Neo4j Aura 
Dataset databases. + """ + + @classmethod + async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict: + """ + Create a new Neo4j Aura instance for the dataset. Return connection info that will be mapped to the dataset. + + Args: + dataset_id: Dataset UUID + user: User object who owns the dataset and is making the request + + Returns: + dict: Connection details for the created Neo4j instance + + """ + graph_config = get_graph_config() + + if graph_config.graph_database_provider != "neo4j": + raise ValueError( + "Neo4jAuraDatasetDatabaseHandler can only be used with Neo4j graph database provider." + ) + + graph_db_name = f"{dataset_id}" + + # Client credentials + client_id = os.environ.get("NEO4J_CLIENT_ID", None) + client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None) + tenant_id = os.environ.get("NEO4J_TENANT_ID", None) + + if client_id is None or client_secret is None or tenant_id is None: + raise ValueError( + "NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling." 
) + + # Make the request with HTTP Basic Auth + def get_aura_token(client_id: str, client_secret: str) -> dict: + url = "https://api.neo4j.io/oauth/token" + data = {"grant_type": "client_credentials"} # sent as application/x-www-form-urlencoded + + resp = requests.post(url, data=data, auth=(client_id, client_secret)) + resp.raise_for_status() # raises if the request failed + return resp.json() + + resp = get_aura_token(client_id, client_secret) + + url = "https://api.neo4j.io/v1/instances" + + headers = { + "accept": "application/json", + "Authorization": f"Bearer {resp['access_token']}", + "Content-Type": "application/json", + } + + # TODO: Maybe we can allow **kwargs parameter forwarding for cases like these + # To allow different configurations between datasets + payload = { + "version": "5", + "region": "europe-west1", + "memory": "1GB", + "name": graph_db_name[ + 0:29 + ], # TODO: Find better name to name Neo4j instance within 30 character limit + "type": "professional-db", + "tenant_id": tenant_id, + "cloud_provider": "gcp", + } + + response = requests.post(url, headers=headers, json=payload) + + graph_db_name = "neo4j" # Has to be 'neo4j' for Aura + graph_db_url = response.json()["data"]["connection_url"] + graph_db_key = resp["access_token"] + graph_db_username = response.json()["data"]["username"] + graph_db_password = response.json()["data"]["password"] + + async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict): + # Poll until the instance is running + status_url = f"https://api.neo4j.io/v1/instances/{instance_id}" + status = "" + for attempt in range(30): # Try for up to ~5 minutes + status_resp = requests.get(status_url, headers=headers) + status = status_resp.json()["data"]["status"] + if status.lower() == "running": + return + await asyncio.sleep(10) + raise TimeoutError( + f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. 
Status: {status}" + ) + + instance_id = response.json()["data"]["id"] + await _wait_for_neo4j_instance_provisioning(instance_id, headers) + return { + "graph_database_name": graph_db_name, + "graph_database_url": graph_db_url, + "graph_database_provider": "neo4j", + "graph_database_key": graph_db_key, # TODO: Hashing of keys/passwords in relational DB + "graph_database_username": graph_db_username, + "graph_database_password": graph_db_password, + } + + @classmethod + async def delete_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]): + pass diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 43e5ea654..6216e107e 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -1,9 +1,7 @@ """Neo4j Adapter for Graph Database""" -import os import json import asyncio -import requests from uuid import UUID from textwrap import dedent from neo4j import AsyncSession @@ -14,7 +12,6 @@ from typing import Optional, Any, List, Dict, Type, Tuple from cognee.infrastructure.engine import DataPoint from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int -from cognee.modules.users.models import User from cognee.tasks.temporal_graph.models import Timestamp from cognee.shared.logging_utils import get_logger, ERROR from cognee.infrastructure.databases.graph.graph_db_interface import ( @@ -1473,89 +1470,3 @@ class Neo4jAdapter(GraphDBInterface): time_ids_list = [item["id"] for item in time_nodes if "id" in item] return ", ".join(f"'{uid}'" for uid in time_ids_list) - - @classmethod - async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict: - """ - Create a new Neo4j Aura instance for the dataset. Return connection info that will be mapped to the dataset. 
- - Args: - dataset_id: Dataset UUID - user: User object who owns the dataset and is making the request - - Returns: - dict: Connection details for the created Neo4j instance - - """ - graph_db_name = f"{dataset_id}" - - # Client credentials - client_id = os.environ.get("NEO4J_CLIENT_ID", None) - client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None) - tenant_id = os.environ.get("NEO4J_TENANT_ID", None) - - # Make the request with HTTP Basic Auth - def get_aura_token(client_id: str, client_secret: str) -> dict: - url = "https://api.neo4j.io/oauth/token" - data = {"grant_type": "client_credentials"} # sent as application/x-www-form-urlencoded - - resp = requests.post(url, data=data, auth=(client_id, client_secret)) - resp.raise_for_status() # raises if the request failed - return resp.json() - - resp = get_aura_token(client_id, client_secret) - - url = "https://api.neo4j.io/v1/instances" - - headers = { - "accept": "application/json", - "Authorization": f"Bearer {resp['access_token']}", - "Content-Type": "application/json", - } - - # TODO: Maybe we can allow **kwargs parameter forwarding for cases like these - # Too allow different configurations between datasets - payload = { - "version": "5", - "region": "europe-west1", - "memory": "1GB", - "name": graph_db_name[ - 0:29 - ], # TODO: Find better name to name Neo4j instance within 30 character limit - "type": "professional-db", - "tenant_id": tenant_id, - "cloud_provider": "gcp", - } - - response = requests.post(url, headers=headers, json=payload) - - graph_db_name = "neo4j" - graph_db_url = response.json()["data"]["connection_url"] - graph_db_key = resp["access_token"] - graph_db_username = response.json()["data"]["username"] - graph_db_password = response.json()["data"]["password"] - - async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict): - # Poll until the instance is running - status_url = f"https://api.neo4j.io/v1/instances/{instance_id}" - status = "" - for attempt in 
range(30): # Try for up to ~5 minutes - status_resp = requests.get(status_url, headers=headers) - status = status_resp.json()["data"]["status"] - if status.lower() == "running": - return - await asyncio.sleep(10) - raise TimeoutError( - f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}" - ) - - instance_id = response.json()["data"]["id"] - await _wait_for_neo4j_instance_provisioning(instance_id, headers) - return { - "graph_database_name": graph_db_name, - "graph_database_url": graph_db_url, - "graph_database_provider": "neo4j", - "graph_database_key": graph_db_key, # TODO: Hashing of keys/passwords in relational DB - "graph_database_username": graph_db_username, - "graph_database_password": graph_db_password, - } diff --git a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py index b60640d4c..f4bacca7e 100644 --- a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +++ b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py @@ -20,61 +20,23 @@ from cognee.modules.users.models import User async def _get_vector_db_info(dataset_id: UUID, user: User) -> dict: vector_config = get_vectordb_config() - # Determine vector configuration - if vector_config.vector_db_provider == "lancedb": - # TODO: Have the create_database method be called from interface adapter automatically for all providers instead of specifically here - from cognee.infrastructure.databases.vector.lancedb.LanceDBAdapter import LanceDBAdapter + from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( + supported_dataset_database_handlers, + ) - return await LanceDBAdapter.create_dataset(dataset_id, user) - - else: - # Note: for hybrid databases both graph and vector DB name have to be the same - vector_db_name = vector_config.vector_db_name - vector_db_url = 
vector_config.vector_database_url - - return { - "vector_database_name": vector_db_name, - "vector_database_url": vector_db_url, - "vector_database_provider": vector_config.vector_db_provider, - "vector_database_key": vector_config.vector_db_key, - } + handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler] + return await handler.create_dataset(dataset_id, user) async def _get_graph_db_info(dataset_id: UUID, user: User) -> dict: graph_config = get_graph_config() - # Determine graph database URL - if graph_config.graph_database_provider == "neo4j": - from cognee.infrastructure.databases.graph.neo4j_driver.adapter import Neo4jAdapter - return await Neo4jAdapter.create_dataset(dataset_id, user) + from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( + supported_dataset_database_handlers, + ) - elif graph_config.graph_database_provider == "kuzu": - # TODO: Add graph file path info for kuzu (also in DatasetDatabase model) - graph_db_name = f"{dataset_id}.pkl" - graph_db_url = graph_config.graph_database_url - graph_db_key = graph_config.graph_database_key - graph_db_username = graph_config.graph_database_username - graph_db_password = graph_config.graph_database_password - elif graph_config.graph_database_provider == "falkor": - # Note: for hybrid databases both graph and vector DB name have to be the same - graph_db_name = f"{dataset_id}" - graph_db_url = graph_config.graph_database_url - graph_db_key = graph_config.graph_database_key - graph_db_username = graph_config.graph_database_username - graph_db_password = graph_config.graph_database_password - else: - raise EnvironmentError( - f"Unsupported graph database provider for backend access control: {graph_config.graph_database_provider}" - ) - - return { - "graph_database_name": graph_db_name, - "graph_database_url": graph_db_url, - "graph_database_provider": graph_config.graph_database_provider, - "graph_database_key": 
graph_db_key, # TODO: Hashing of keys/passwords in relational DB - "graph_database_username": graph_db_username, - "graph_database_password": graph_db_password, - } + handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler] + return await handler.create_dataset(dataset_id, user) async def _existing_dataset_database( diff --git a/cognee/infrastructure/databases/vector/config.py b/cognee/infrastructure/databases/vector/config.py index 7d28f1668..86b2a0fce 100644 --- a/cognee/infrastructure/databases/vector/config.py +++ b/cognee/infrastructure/databases/vector/config.py @@ -28,6 +28,7 @@ class VectorConfig(BaseSettings): vector_db_name: str = "" vector_db_key: str = "" vector_db_provider: str = "lancedb" + vector_dataset_database_handler: str = "lancedb" model_config = SettingsConfigDict(env_file=".env", extra="allow") @@ -63,6 +64,7 @@ class VectorConfig(BaseSettings): "vector_db_name": self.vector_db_name, "vector_db_key": self.vector_db_key, "vector_db_provider": self.vector_db_provider, + "vector_dataset_database_handler": self.vector_dataset_database_handler, } diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py index b182f084b..02e01e288 100644 --- a/cognee/infrastructure/databases/vector/create_vector_engine.py +++ b/cognee/infrastructure/databases/vector/create_vector_engine.py @@ -12,6 +12,7 @@ def create_vector_engine( vector_db_name: str, vector_db_port: str = "", vector_db_key: str = "", + vector_dataset_database_handler: str = "", ): """ Create a vector database engine based on the specified provider. 
diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index a93fbc818..b52f78517 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -362,20 +362,3 @@ class LanceDBAdapter(VectorDBInterface): }, exclude_fields=["metadata"] + related_models_fields, ) - - @classmethod - async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict: - vector_config = get_vectordb_config() - base_config = get_base_config() - databases_directory_path = os.path.join( - base_config.system_root_directory, "databases", str(user.id) - ) - - vector_db_name = f"{dataset_id}.lance.db" - - return { - "vector_database_name": vector_db_name, - "vector_database_url": os.path.join(databases_directory_path, vector_db_name), - "vector_database_provider": vector_config.vector_db_provider, - "vector_database_key": vector_config.vector_db_key, - } diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py new file mode 100644 index 000000000..8a80dddcf --- /dev/null +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py @@ -0,0 +1,41 @@ +import os +from uuid import UUID +from typing import Optional + +from cognee.modules.users.models import User +from cognee.base_config import get_base_config +from cognee.infrastructure.databases.vector import get_vectordb_config +from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface + + +class LanceDBDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): + """ + Handler for interacting with LanceDB Dataset databases. 
+ """ + + @classmethod + async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict: + vector_config = get_vectordb_config() + base_config = get_base_config() + + if vector_config.vector_db_provider != "lancedb": + raise ValueError( + "LanceDBDatasetDatabaseHandler can only be used with LanceDB vector database provider." + ) + + databases_directory_path = os.path.join( + base_config.system_root_directory, "databases", str(user.id) + ) + + vector_db_name = f"{dataset_id}.lance.db" + + return { + "vector_database_name": vector_db_name, + "vector_database_url": os.path.join(databases_directory_path, vector_db_name), + "vector_database_provider": vector_config.vector_db_provider, + "vector_database_key": vector_config.vector_db_key, + } + + @classmethod + async def delete_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]): + pass