refactor: Create new abstraction for dataset database mapping and handling
This commit is contained in:
parent
0800810713
commit
64a3ee96c4
16 changed files with 300 additions and 188 deletions
|
|
@ -27,6 +27,7 @@ async def set_session_user_context_variable(user):
|
|||
def multi_user_support_possible():
|
||||
graph_db_config = get_graph_context_config()
|
||||
vector_db_config = get_vectordb_context_config()
|
||||
# TODO: Make sure dataset database handler and provider match, remove multi_user support check, add error if no dataset database handler exists for provider
|
||||
return (
|
||||
graph_db_config["graph_database_provider"] in GRAPH_DBS_WITH_MULTI_USER_SUPPORT
|
||||
and vector_db_config["vector_db_provider"] in VECTOR_DBS_WITH_MULTI_USER_SUPPORT
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
from .dataset_database_handler_interface import DatasetDatabaseHandlerInterface
|
||||
from .supported_dataset_database_handlers import supported_dataset_database_handlers
|
||||
from .use_dataset_database_handler import use_dataset_database_handler
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from cognee.modules.users.models.User import User
|
||||
|
||||
|
||||
class DatasetDatabaseHandlerInterface(ABC):
    """
    Abstract interface for per-dataset database lifecycle handlers.

    Implementations create (and later delete) the graph or vector database
    backing a single dataset. Needed for Cognee multi-tenant/multi-user and
    backend access control support, where each dataset must map to its own
    database to keep data separated.
    """

    @classmethod
    @abstractmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        """
        Return a dictionary with connection info for a graph or vector database
        for the given dataset.

        The handler may also deploy the actual database, but that is not
        required — returning connection info is sufficient. The returned
        dictionary is used to create a DatasetDatabase row in the relational
        database, from which the internal dataset -> database connection info
        mapping is resolved when connecting to the dataset in the future.

        Each dataset needs to map to a unique graph or vector database when
        backend access control is enabled, to facilitate a separation of
        concern for data.

        Args:
            dataset_id: UUID of the dataset, if needed by the database creation logic.
            user: User object, if needed by the database creation logic.

        Returns:
            dict: Connection info for the created graph or vector database instance.
        """
        pass

    @classmethod
    @abstractmethod
    async def delete_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> None:
        """
        Delete the graph or vector database for the given dataset.

        Implementations should delete the actual database, or send a request to
        the proper service to delete/mark the database as no longer needed for
        the given dataset. Needed for maintaining databases for Cognee
        multi-tenant/multi-user and backend access control.

        Args:
            dataset_id: UUID of the dataset.
            user: User object.

        Note: parameters are Optional to match the concrete handler
        implementations (Kuzu, Neo4j Aura, LanceDB), which all accept None.
        """
        pass
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
from cognee.infrastructure.databases.graph.neo4j_driver.Neo4jAuraDatasetDatabaseHandler import (
|
||||
Neo4jAuraDatasetDatabaseHandler,
|
||||
)
|
||||
from cognee.infrastructure.databases.vector.lancedb.LanceDBDatasetDatabaseHandler import (
|
||||
LanceDBDatasetDatabaseHandler,
|
||||
)
|
||||
from cognee.infrastructure.databases.graph.kuzu.KuzuDatasetDatabaseHandler import (
|
||||
KuzuDatasetDatabaseHandler,
|
||||
)
|
||||
|
||||
# Registry mapping a handler name (as used in graph/vector config) to its
# DatasetDatabaseHandlerInterface implementation.
supported_dataset_database_handlers = dict(
    neo4j_aura=Neo4jAuraDatasetDatabaseHandler,
    lancedb=LanceDBDatasetDatabaseHandler,
    kuzu=KuzuDatasetDatabaseHandler,
)
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
from .supported_dataset_database_handlers import supported_dataset_database_handlers
|
||||
|
||||
|
||||
def use_dataset_database_handler(dataset_database_handler_name: str, dataset_database_handler) -> None:
    """
    Register a custom dataset database handler under the given name.

    Adds (or replaces) an entry in the supported_dataset_database_handlers
    registry so the handler can be selected via configuration.

    Args:
        dataset_database_handler_name: Key under which the handler is registered.
        dataset_database_handler: A DatasetDatabaseHandlerInterface subclass.
    """
    supported_dataset_database_handlers[dataset_database_handler_name] = dataset_database_handler
|
||||
|
|
@ -47,6 +47,7 @@ class GraphConfig(BaseSettings):
|
|||
graph_filename: str = ""
|
||||
graph_model: object = KnowledgeGraph
|
||||
graph_topology: object = KnowledgeGraph
|
||||
graph_dataset_database_handler: str = "kuzu"
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow", populate_by_name=True)
|
||||
|
||||
# Model validator updates graph_filename and path dynamically after class creation based on current database provider
|
||||
|
|
@ -97,6 +98,7 @@ class GraphConfig(BaseSettings):
|
|||
"graph_model": self.graph_model,
|
||||
"graph_topology": self.graph_topology,
|
||||
"model_config": self.model_config,
|
||||
"graph_dataset_database_handler": self.graph_dataset_database_handler,
|
||||
}
|
||||
|
||||
def to_hashable_dict(self) -> dict:
|
||||
|
|
@ -121,6 +123,7 @@ class GraphConfig(BaseSettings):
|
|||
"graph_database_port": self.graph_database_port,
|
||||
"graph_database_key": self.graph_database_key,
|
||||
"graph_file_path": self.graph_file_path,
|
||||
"graph_dataset_database_handler": self.graph_dataset_database_handler,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ def create_graph_engine(
|
|||
graph_database_password="",
|
||||
graph_database_port="",
|
||||
graph_database_key="",
|
||||
graph_dataset_database_handler="",
|
||||
):
|
||||
"""
|
||||
Create a graph engine based on the specified provider type.
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from typing import Optional, Dict, Any, List, Tuple, Type, Union
|
|||
from uuid import NAMESPACE_OID, UUID, uuid5
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.users.models.User import User
|
||||
from cognee.modules.data.models.graph_relationship_ledger import GraphRelationshipLedger
|
||||
from cognee.infrastructure.databases.relational.get_relational_engine import get_relational_engine
|
||||
|
||||
|
|
@ -399,36 +398,3 @@ class GraphDBInterface(ABC):
|
|||
- node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||
"""
|
||||
Return a dictionary with connection info for a graph database for the given dataset.
|
||||
Function can auto handle deploying of the actual database if needed, but is not necessary.
|
||||
Only providing connection info is sufficient, this info will be mapped when trying to connect to the provided dataset in the future.
|
||||
Needed for Cognee multi-tenant/multi-user and backend access control support.
|
||||
|
||||
Dictionary returned from this function will be used to create a DatasetDatabase row in the relational database.
|
||||
From which internal mapping of dataset -> database connection info will be done.
|
||||
|
||||
Each dataset needs to map to a unique graph database when backend access control is enabled to facilitate a separation of concern for data.
|
||||
|
||||
Args:
|
||||
dataset_id: UUID of the dataset if needed by the database creation logic
|
||||
user: User object if needed by the database creation logic
|
||||
Returns:
|
||||
dict: Connection info for the created graph database instance.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def delete_dataset(self, dataset_id: UUID, user: User) -> None:
|
||||
"""
|
||||
Delete the graph database for the given dataset.
|
||||
Function should auto handle deleting of the actual database or send a request to the proper service to delete the database.
|
||||
Needed for maintaining a database for Cognee multi-tenant/multi-user and backend access control.
|
||||
|
||||
Args:
|
||||
dataset_id: UUID of the dataset
|
||||
user: User object
|
||||
"""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -0,0 +1,57 @@
|
|||
import os
|
||||
import asyncio
|
||||
import requests
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
|
||||
from cognee.modules.users.models import User
|
||||
|
||||
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||
|
||||
|
||||
class KuzuDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """
    Handler for interacting with Kuzu Dataset databases.
    """

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        """
        Create a new Kuzu instance for the dataset. Return connection info that will be mapped to the dataset.

        Args:
            dataset_id: Dataset UUID
            user: User object who owns the dataset and is making the request

        Returns:
            dict: Connection details for the created Kuzu instance

        """
        # Imported locally rather than at module level, presumably to avoid a
        # circular import with the graph config module — verify before moving.
        from cognee.infrastructure.databases.graph.config import get_graph_config

        graph_config = get_graph_config()

        if graph_config.graph_database_provider != "kuzu":
            raise ValueError(
                "KuzuDatasetDatabaseHandler can only be used with Kuzu graph database provider."
            )

        # TODO: Add graph file path info for kuzu (also in DatasetDatabase model)
        return {
            "graph_database_name": f"{dataset_id}.pkl",
            "graph_database_url": graph_config.graph_database_url,
            "graph_database_provider": graph_config.graph_database_provider,
            "graph_database_key": graph_config.graph_database_key,
            "graph_database_username": graph_config.graph_database_username,
            "graph_database_password": graph_config.graph_database_password,
        }

    @classmethod
    async def delete_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]):
        # No-op: this handler does not clean up Kuzu dataset databases yet.
        pass
|
||||
|
|
@ -0,0 +1,118 @@
|
|||
import os
|
||||
import asyncio
|
||||
import requests
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_config
|
||||
from cognee.modules.users.models import User
|
||||
|
||||
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||
|
||||
|
||||
class Neo4jAuraDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """
    Handler for interacting with Neo4j Aura Dataset databases.

    Provisions a dedicated Aura instance per dataset via the Neo4j Aura API
    and returns its connection info for the dataset -> database mapping.
    """

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        """
        Create a new Neo4j Aura instance for the dataset. Return connection info that will be mapped to the dataset.

        Args:
            dataset_id: Dataset UUID
            user: User object who owns the dataset and is making the request

        Returns:
            dict: Connection details for the created Neo4j instance

        Raises:
            ValueError: If the configured graph provider is not neo4j, or the
                NEO4J_CLIENT_ID / NEO4J_CLIENT_SECRET / NEO4J_TENANT_ID
                environment variables are missing.
            requests.HTTPError: If an Aura API request fails.
            TimeoutError: If the instance does not become ready within ~5 minutes.
        """
        graph_config = get_graph_config()

        if graph_config.graph_database_provider != "neo4j":
            raise ValueError(
                "Neo4jAuraDatasetDatabaseHandler can only be used with Neo4j graph database provider."
            )

        graph_db_name = f"{dataset_id}"

        # Client credentials
        client_id = os.environ.get("NEO4J_CLIENT_ID", None)
        client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
        tenant_id = os.environ.get("NEO4J_TENANT_ID", None)

        if client_id is None or client_secret is None or tenant_id is None:
            raise ValueError(
                "NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling."
            )

        # Make the request with HTTP Basic Auth
        def get_aura_token(client_id: str, client_secret: str) -> dict:
            url = "https://api.neo4j.io/oauth/token"
            data = {"grant_type": "client_credentials"}  # sent as application/x-www-form-urlencoded

            resp = requests.post(url, data=data, auth=(client_id, client_secret))
            resp.raise_for_status()  # raises if the request failed
            return resp.json()

        # NOTE(review): requests calls block the event loop in this async
        # method; consider asyncio.to_thread for the HTTP calls.
        resp = get_aura_token(client_id, client_secret)

        url = "https://api.neo4j.io/v1/instances"

        headers = {
            "accept": "application/json",
            "Authorization": f"Bearer {resp['access_token']}",
            "Content-Type": "application/json",
        }

        # TODO: Maybe we can allow **kwargs parameter forwarding for cases like these
        # To allow different configurations between datasets
        payload = {
            "version": "5",
            "region": "europe-west1",
            "memory": "1GB",
            "name": graph_db_name[
                0:29
            ],  # TODO: Find a better name for the Neo4j instance within the 30 character limit
            "type": "professional-db",
            "tenant_id": tenant_id,
            "cloud_provider": "gcp",
        }

        response = requests.post(url, headers=headers, json=payload)
        # Fail fast with the API's HTTP error instead of an opaque KeyError
        # on a missing "data" field when instance creation is rejected.
        response.raise_for_status()
        # Parse the creation response once instead of re-parsing it per field.
        instance_data = response.json()["data"]

        graph_db_name = "neo4j"  # Has to be 'neo4j' for Aura
        graph_db_url = instance_data["connection_url"]
        graph_db_key = resp["access_token"]
        graph_db_username = instance_data["username"]
        graph_db_password = instance_data["password"]

        async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict):
            # Poll until the instance is running
            status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
            status = ""
            for attempt in range(30):  # Try for up to ~5 minutes
                status_resp = requests.get(status_url, headers=headers)
                status = status_resp.json()["data"]["status"]
                if status.lower() == "running":
                    return
                await asyncio.sleep(10)
            raise TimeoutError(
                f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}"
            )

        instance_id = instance_data["id"]
        await _wait_for_neo4j_instance_provisioning(instance_id, headers)
        return {
            "graph_database_name": graph_db_name,
            "graph_database_url": graph_db_url,
            "graph_database_provider": "neo4j",
            "graph_database_key": graph_db_key,  # TODO: Hashing of keys/passwords in relational DB
            "graph_database_username": graph_db_username,
            "graph_database_password": graph_db_password,
        }

    @classmethod
    async def delete_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]):
        # TODO: delete (or request deletion of) the Aura instance backing the dataset.
        pass
|
||||
|
|
@ -1,9 +1,7 @@
|
|||
"""Neo4j Adapter for Graph Database"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import requests
|
||||
from uuid import UUID
|
||||
from textwrap import dedent
|
||||
from neo4j import AsyncSession
|
||||
|
|
@ -14,7 +12,6 @@ from typing import Optional, Any, List, Dict, Type, Tuple
|
|||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.tasks.temporal_graph.models import Timestamp
|
||||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
||||
|
|
@ -1473,89 +1470,3 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
time_ids_list = [item["id"] for item in time_nodes if "id" in item]
|
||||
|
||||
return ", ".join(f"'{uid}'" for uid in time_ids_list)
|
||||
|
||||
@classmethod
|
||||
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||
"""
|
||||
Create a new Neo4j Aura instance for the dataset. Return connection info that will be mapped to the dataset.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset UUID
|
||||
user: User object who owns the dataset and is making the request
|
||||
|
||||
Returns:
|
||||
dict: Connection details for the created Neo4j instance
|
||||
|
||||
"""
|
||||
graph_db_name = f"{dataset_id}"
|
||||
|
||||
# Client credentials
|
||||
client_id = os.environ.get("NEO4J_CLIENT_ID", None)
|
||||
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
|
||||
tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
|
||||
|
||||
# Make the request with HTTP Basic Auth
|
||||
def get_aura_token(client_id: str, client_secret: str) -> dict:
|
||||
url = "https://api.neo4j.io/oauth/token"
|
||||
data = {"grant_type": "client_credentials"} # sent as application/x-www-form-urlencoded
|
||||
|
||||
resp = requests.post(url, data=data, auth=(client_id, client_secret))
|
||||
resp.raise_for_status() # raises if the request failed
|
||||
return resp.json()
|
||||
|
||||
resp = get_aura_token(client_id, client_secret)
|
||||
|
||||
url = "https://api.neo4j.io/v1/instances"
|
||||
|
||||
headers = {
|
||||
"accept": "application/json",
|
||||
"Authorization": f"Bearer {resp['access_token']}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# TODO: Maybe we can allow **kwargs parameter forwarding for cases like these
|
||||
# To allow different configurations between datasets
|
||||
payload = {
|
||||
"version": "5",
|
||||
"region": "europe-west1",
|
||||
"memory": "1GB",
|
||||
"name": graph_db_name[
|
||||
0:29
|
||||
], # TODO: Find a better name for the Neo4j instance within the 30 character limit
|
||||
"type": "professional-db",
|
||||
"tenant_id": tenant_id,
|
||||
"cloud_provider": "gcp",
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
|
||||
graph_db_name = "neo4j"
|
||||
graph_db_url = response.json()["data"]["connection_url"]
|
||||
graph_db_key = resp["access_token"]
|
||||
graph_db_username = response.json()["data"]["username"]
|
||||
graph_db_password = response.json()["data"]["password"]
|
||||
|
||||
async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict):
|
||||
# Poll until the instance is running
|
||||
status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
|
||||
status = ""
|
||||
for attempt in range(30): # Try for up to ~5 minutes
|
||||
status_resp = requests.get(status_url, headers=headers)
|
||||
status = status_resp.json()["data"]["status"]
|
||||
if status.lower() == "running":
|
||||
return
|
||||
await asyncio.sleep(10)
|
||||
raise TimeoutError(
|
||||
f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}"
|
||||
)
|
||||
|
||||
instance_id = response.json()["data"]["id"]
|
||||
await _wait_for_neo4j_instance_provisioning(instance_id, headers)
|
||||
return {
|
||||
"graph_database_name": graph_db_name,
|
||||
"graph_database_url": graph_db_url,
|
||||
"graph_database_provider": "neo4j",
|
||||
"graph_database_key": graph_db_key, # TODO: Hashing of keys/passwords in relational DB
|
||||
"graph_database_username": graph_db_username,
|
||||
"graph_database_password": graph_db_password,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,61 +20,23 @@ from cognee.modules.users.models import User
|
|||
async def _get_vector_db_info(dataset_id: UUID, user: User) -> dict:
|
||||
vector_config = get_vectordb_config()
|
||||
|
||||
# Determine vector configuration
|
||||
if vector_config.vector_db_provider == "lancedb":
|
||||
# TODO: Have the create_database method be called from interface adapter automatically for all providers instead of specifically here
|
||||
from cognee.infrastructure.databases.vector.lancedb.LanceDBAdapter import LanceDBAdapter
|
||||
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
||||
supported_dataset_database_handlers,
|
||||
)
|
||||
|
||||
return await LanceDBAdapter.create_dataset(dataset_id, user)
|
||||
|
||||
else:
|
||||
# Note: for hybrid databases both graph and vector DB name have to be the same
|
||||
vector_db_name = vector_config.vector_db_name
|
||||
vector_db_url = vector_config.vector_database_url
|
||||
|
||||
return {
|
||||
"vector_database_name": vector_db_name,
|
||||
"vector_database_url": vector_db_url,
|
||||
"vector_database_provider": vector_config.vector_db_provider,
|
||||
"vector_database_key": vector_config.vector_db_key,
|
||||
}
|
||||
handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler]
|
||||
return await handler.create_dataset(dataset_id, user)
|
||||
|
||||
|
||||
async def _get_graph_db_info(dataset_id: UUID, user: User) -> dict:
|
||||
graph_config = get_graph_config()
|
||||
# Determine graph database URL
|
||||
if graph_config.graph_database_provider == "neo4j":
|
||||
from cognee.infrastructure.databases.graph.neo4j_driver.adapter import Neo4jAdapter
|
||||
|
||||
return await Neo4jAdapter.create_dataset(dataset_id, user)
|
||||
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
|
||||
supported_dataset_database_handlers,
|
||||
)
|
||||
|
||||
elif graph_config.graph_database_provider == "kuzu":
|
||||
# TODO: Add graph file path info for kuzu (also in DatasetDatabase model)
|
||||
graph_db_name = f"{dataset_id}.pkl"
|
||||
graph_db_url = graph_config.graph_database_url
|
||||
graph_db_key = graph_config.graph_database_key
|
||||
graph_db_username = graph_config.graph_database_username
|
||||
graph_db_password = graph_config.graph_database_password
|
||||
elif graph_config.graph_database_provider == "falkor":
|
||||
# Note: for hybrid databases both graph and vector DB name have to be the same
|
||||
graph_db_name = f"{dataset_id}"
|
||||
graph_db_url = graph_config.graph_database_url
|
||||
graph_db_key = graph_config.graph_database_key
|
||||
graph_db_username = graph_config.graph_database_username
|
||||
graph_db_password = graph_config.graph_database_password
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"Unsupported graph database provider for backend access control: {graph_config.graph_database_provider}"
|
||||
)
|
||||
|
||||
return {
|
||||
"graph_database_name": graph_db_name,
|
||||
"graph_database_url": graph_db_url,
|
||||
"graph_database_provider": graph_config.graph_database_provider,
|
||||
"graph_database_key": graph_db_key, # TODO: Hashing of keys/passwords in relational DB
|
||||
"graph_database_username": graph_db_username,
|
||||
"graph_database_password": graph_db_password,
|
||||
}
|
||||
handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler]
|
||||
return await handler.create_dataset(dataset_id, user)
|
||||
|
||||
|
||||
async def _existing_dataset_database(
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ class VectorConfig(BaseSettings):
|
|||
vector_db_name: str = ""
|
||||
vector_db_key: str = ""
|
||||
vector_db_provider: str = "lancedb"
|
||||
vector_dataset_database_handler: str = "lancedb"
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
|
|
@ -63,6 +64,7 @@ class VectorConfig(BaseSettings):
|
|||
"vector_db_name": self.vector_db_name,
|
||||
"vector_db_key": self.vector_db_key,
|
||||
"vector_db_provider": self.vector_db_provider,
|
||||
"vector_dataset_database_handler": self.vector_dataset_database_handler,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ def create_vector_engine(
|
|||
vector_db_name: str,
|
||||
vector_db_port: str = "",
|
||||
vector_db_key: str = "",
|
||||
vector_dataset_database_handler: str = "",
|
||||
):
|
||||
"""
|
||||
Create a vector database engine based on the specified provider.
|
||||
|
|
|
|||
|
|
@ -362,20 +362,3 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
},
|
||||
exclude_fields=["metadata"] + related_models_fields,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||
vector_config = get_vectordb_config()
|
||||
base_config = get_base_config()
|
||||
databases_directory_path = os.path.join(
|
||||
base_config.system_root_directory, "databases", str(user.id)
|
||||
)
|
||||
|
||||
vector_db_name = f"{dataset_id}.lance.db"
|
||||
|
||||
return {
|
||||
"vector_database_name": vector_db_name,
|
||||
"vector_database_url": os.path.join(databases_directory_path, vector_db_name),
|
||||
"vector_database_provider": vector_config.vector_db_provider,
|
||||
"vector_database_key": vector_config.vector_db_key,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,41 @@
|
|||
import os
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||
|
||||
|
||||
class LanceDBDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """
    Handler for interacting with LanceDB Dataset databases.
    """

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        """
        Build connection info for a per-dataset LanceDB database.

        The database file is named after the dataset id and placed under the
        user's directory inside the system databases root.
        """
        config = get_vectordb_config()
        base_config = get_base_config()

        if config.vector_db_provider != "lancedb":
            raise ValueError(
                "LanceDBDatasetDatabaseHandler can only be used with LanceDB vector database provider."
            )

        db_file_name = f"{dataset_id}.lance.db"
        user_databases_dir = os.path.join(
            base_config.system_root_directory, "databases", str(user.id)
        )

        return {
            "vector_database_name": db_file_name,
            "vector_database_url": os.path.join(user_databases_dir, db_file_name),
            "vector_database_provider": config.vector_db_provider,
            "vector_database_key": config.vector_db_key,
        }

    @classmethod
    async def delete_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]):
        # No-op: this handler does not clean up LanceDB dataset files yet.
        pass
|
||||
Loading…
Add table
Reference in a new issue