diff --git a/cognee/context_global_variables.py b/cognee/context_global_variables.py index d52de4b4e..9a4f49763 100644 --- a/cognee/context_global_variables.py +++ b/cognee/context_global_variables.py @@ -57,19 +57,34 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_ # Set vector and graph database configuration based on dataset database information vector_config = { - "vector_db_url": os.path.join( - databases_directory_path, dataset_database.vector_database_name - ), - "vector_db_key": "", - "vector_db_provider": "lancedb", + "vector_db_provider": dataset_database.vector_database_provider, + "vector_db_url": dataset_database.vector_database_url, + # TODO: Maybe add key to dataset_database, and put it here?? + "vector_db_key": "" } + # vector_config = { + # "vector_db_url": os.path.join( + # databases_directory_path, dataset_database.vector_database_name + # ), + # "vector_db_key": "", + # "vector_db_provider": "lancedb", + # } + graph_config = { - "graph_database_provider": "kuzu", + "graph_database_provider": dataset_database.graph_database_provider, + "graph_database_url": dataset_database.graph_database_url, + "graph_database_name": dataset_database.graph_database_name, "graph_file_path": os.path.join( databases_directory_path, dataset_database.graph_database_name ), } + # graph_config = { + # "graph_database_provider": "kuzu", + # "graph_file_path": os.path.join( + # databases_directory_path, dataset_database.graph_database_name + # ), + # } storage_config = { "data_root_directory": data_root_directory, diff --git a/cognee/infrastructure/databases/graph/get_graph_engine.py b/cognee/infrastructure/databases/graph/get_graph_engine.py index 1ea61d29f..217f63070 100644 --- a/cognee/infrastructure/databases/graph/get_graph_engine.py +++ b/cognee/infrastructure/databases/graph/get_graph_engine.py @@ -69,6 +69,7 @@ def create_graph_engine( graph_database_url=graph_database_url, graph_database_username=graph_database_username, 
graph_database_password=graph_database_password, + graph_name=graph_database_name, ) if graph_database_provider == "neo4j": diff --git a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py index 29156025d..2b9b00569 100644 --- a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +++ b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py @@ -6,11 +6,20 @@ from sqlalchemy.exc import IntegrityError from cognee.modules.data.methods import create_dataset from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.infrastructure.databases.vector import get_vectordb_config +from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.modules.data.methods import get_unique_dataset_id from cognee.modules.users.models import DatasetDatabase from cognee.modules.users.models import User +# TODO: Find a better place to define these +default_vector_db_name = "lance.db" +default_vector_db_provider = "lancedb" +default_graph_db_provider = "kuzu" +default_vector_db_url = None +default_graph_db_url = None + async def get_or_create_dataset_database( dataset: Union[str, UUID], user: User, @@ -32,9 +41,12 @@ async def get_or_create_dataset_database( dataset_id = await get_unique_dataset_id(dataset, user) - vector_db_name = f"{dataset_id}.lance.db" + vector_db_name = f"{dataset_id}.db" graph_db_name = f"{dataset_id}.pkl" + vector_config = get_vectordb_config() + graph_config = get_graph_config() + async with db_engine.get_async_session() as session: # Create dataset if it doesn't exist if isinstance(dataset, str): @@ -49,12 +61,19 @@ async def get_or_create_dataset_database( if existing: return existing + # TODO: Set the vector and graph database stuff (name, provider, etc.) based on the whether or + # TODO: not we support multi user for that db. If not, set to default, which is lance and/or kuzu. 
+ # If there are no existing rows build a new row record = DatasetDatabase( owner_id=user.id, dataset_id=dataset_id, vector_database_name=vector_db_name, graph_database_name=graph_db_name, + vector_database_provider=vector_config.vector_db_provider, + graph_database_provider=graph_config.graph_database_provider, + vector_database_url=vector_config.vector_db_url, + graph_database_url=graph_config.graph_database_url, ) try: diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py index 639bbb9f6..7e3fb367f 100644 --- a/cognee/infrastructure/databases/vector/create_vector_engine.py +++ b/cognee/infrastructure/databases/vector/create_vector_engine.py @@ -1,5 +1,6 @@ from .supported_databases import supported_databases from .embeddings import get_embedding_engine +from cognee.infrastructure.databases.graph.config import get_graph_config from functools import lru_cache @@ -45,6 +46,7 @@ def create_vector_engine( url=vector_db_url, api_key=vector_db_key, embedding_engine=embedding_engine, + graph_name=get_graph_config().graph_database_name ) if vector_db_provider == "pgvector": diff --git a/cognee/modules/users/models/DatasetDatabase.py b/cognee/modules/users/models/DatasetDatabase.py index 0d71d8413..3d3899f4c 100644 --- a/cognee/modules/users/models/DatasetDatabase.py +++ b/cognee/modules/users/models/DatasetDatabase.py @@ -12,8 +12,15 @@ class DatasetDatabase(Base): UUID, ForeignKey("datasets.id", ondelete="CASCADE"), primary_key=True, index=True ) + # TODO: Why is this unique? Isn't it fact that two or more datasets can have the same vector and graph store? 
vector_database_name = Column(String, unique=True, nullable=False) graph_database_name = Column(String, unique=True, nullable=False) + vector_database_provider = Column(String, nullable=False) + graph_database_provider = Column(String, nullable=False) + + vector_database_url = Column(String, nullable=True) + graph_database_url = Column(String, nullable=True) + created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))