fix: PR comment fixes

This commit is contained in:
Andrej Milicevic 2025-10-29 16:30:13 +01:00
parent c3f0cb95da
commit 70f3ced15a
5 changed files with 44 additions and 39 deletions

View file

@ -70,7 +70,7 @@ def create_graph_engine(
graph_database_url=graph_database_url,
graph_database_username=graph_database_username,
graph_database_password=graph_database_password,
graph_name=graph_database_name,
database_name=graph_database_name,
)
if graph_database_provider == "neo4j":

View file

@ -0,0 +1,4 @@
VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"]
GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"]
HYBRID_DBS = ["falkor"]

View file

@ -11,17 +11,11 @@ from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.modules.data.methods import get_unique_dataset_id
from cognee.modules.users.models import DatasetDatabase
from cognee.modules.users.models import User
# TODO: Find a better place to define these
default_vector_db_provider = "lancedb"
default_graph_db_provider = "kuzu"
default_vector_db_url = None
default_graph_db_url = None
default_vector_db_key = None
default_graph_db_key = None
vector_dbs_with_multi_user_support = ["lancedb", "falkor"]
graph_dbs_with_multi_user_support = ["kuzu", "falkor"]
from .constants import (
GRAPH_DBS_WITH_MULTI_USER_SUPPORT,
VECTOR_DBS_WITH_MULTI_USER_SUPPORT,
HYBRID_DBS,
)
async def get_or_create_dataset_database(
@ -45,12 +39,19 @@ async def get_or_create_dataset_database(
dataset_id = await get_unique_dataset_id(dataset, user)
vector_db_name = f"{dataset_id}.db"
graph_db_name = f"{dataset_id}.pkl"
vector_config = get_vectordb_config()
graph_config = get_graph_config()
graph_db_name = f"{dataset_id}.pkl"
if graph_config.graph_database_provider in HYBRID_DBS:
vector_db_name = graph_db_name
else:
if vector_config.vector_database_provider == "lancedb":
vector_db_name = f"{dataset_id}.lance.db"
else:
vector_db_name = f"{dataset_id}.db"
async with db_engine.get_async_session() as session:
# Create dataset if it doesn't exist
if isinstance(dataset, str):
@ -66,23 +67,18 @@ async def get_or_create_dataset_database(
return existing
# Check if we support multi-user for this provider. If not, use default
if graph_config.graph_database_provider in graph_dbs_with_multi_user_support:
graph_provider = graph_config.graph_database_provider
graph_url = graph_config.graph_database_url
graph_key = graph_config.graph_database_key
else:
graph_provider = default_graph_db_provider
graph_url = default_graph_db_url
graph_key = default_graph_db_key
if vector_config.vector_db_provider in vector_dbs_with_multi_user_support:
vector_provider = vector_config.vector_db_provider
vector_url = vector_config.vector_db_url
vector_key = vector_config.vector_db_key
else:
vector_provider = default_vector_db_provider
vector_url = default_vector_db_url
vector_key = default_vector_db_key
if graph_config.graph_database_provider not in GRAPH_DBS_WITH_MULTI_USER_SUPPORT:
raise EnvironmentError(
f"Multi-user is currently not supported for the graph database provider: {graph_config.graph_database_provider}. "
f"Supported providers are: {', '.join(GRAPH_DBS_WITH_MULTI_USER_SUPPORT)}. Either use one of these"
f"providers, or disable BACKEND_ACCESS_CONTROL"
)
if vector_config.vector_db_provider not in VECTOR_DBS_WITH_MULTI_USER_SUPPORT:
raise EnvironmentError(
f"Multi-user is currently not supported for the vector database provider: {vector_config.vector_db_provider}. "
f"Supported providers are: {', '.join(VECTOR_DBS_WITH_MULTI_USER_SUPPORT)}. Either use one of these"
f"providers, or disable BACKEND_ACCESS_CONTROL"
)
# If there are no existing rows build a new row
record = DatasetDatabase(
@ -90,12 +86,12 @@ async def get_or_create_dataset_database(
dataset_id=dataset_id,
vector_database_name=vector_db_name,
graph_database_name=graph_db_name,
vector_database_provider=vector_provider,
graph_database_provider=graph_provider,
vector_database_url=vector_url,
graph_database_url=graph_url,
vector_database_key=vector_key,
graph_database_key=graph_key,
vector_database_provider=vector_config.vector_db_provider,
graph_database_provider=graph_config.graph_database_provider,
vector_database_url=vector_config.vector_db_url,
graph_database_url=graph_config.graph_database_url,
vector_database_key=vector_config.vector_db_key,
graph_database_key=graph_config.graph_database_key,
)
try:

View file

@ -18,12 +18,14 @@ class VectorConfig(BaseSettings):
Instance variables:
- vector_db_url: The URL of the vector database.
- vector_db_port: The port for the vector database.
- vector_db_name: The name of the vector database.
- vector_db_key: The key for accessing the vector database.
- vector_db_provider: The provider for the vector database.
"""
vector_db_url: str = ""
vector_db_port: int = 1234
vector_db_name: str = ""
vector_db_key: str = ""
vector_db_provider: str = "lancedb"
@ -58,6 +60,7 @@ class VectorConfig(BaseSettings):
return {
"vector_db_url": self.vector_db_url,
"vector_db_port": self.vector_db_port,
"vector_db_name": self.vector_db_name,
"vector_db_key": self.vector_db_key,
"vector_db_provider": self.vector_db_provider,
}

View file

@ -9,6 +9,7 @@ from functools import lru_cache
def create_vector_engine(
vector_db_provider: str,
vector_db_url: str,
vector_db_name: str,
vector_db_port: str = "",
vector_db_key: str = "",
):
@ -28,6 +29,7 @@ def create_vector_engine(
- vector_db_url (str): The URL for the vector database instance.
- vector_db_port (str): The port for the vector database instance. Required for some
providers.
- vector_db_name (str): The name of the vector database instance.
- vector_db_key (str): The API key or access token for the vector database instance.
- vector_db_provider (str): The name of the vector database provider to use (e.g.,
'pgvector').
@ -46,7 +48,7 @@ def create_vector_engine(
url=vector_db_url,
api_key=vector_db_key,
embedding_engine=embedding_engine,
graph_name=get_graph_context_config()["graph_database_name"],
database_name=vector_db_name,
)
if vector_db_provider == "pgvector":