fix: PR comment fixes

This commit is contained in:
Andrej Milicevic 2025-10-29 16:30:13 +01:00
parent c3f0cb95da
commit 70f3ced15a
5 changed files with 44 additions and 39 deletions

View file

@ -70,7 +70,7 @@ def create_graph_engine(
graph_database_url=graph_database_url, graph_database_url=graph_database_url,
graph_database_username=graph_database_username, graph_database_username=graph_database_username,
graph_database_password=graph_database_password, graph_database_password=graph_database_password,
graph_name=graph_database_name, database_name=graph_database_name,
) )
if graph_database_provider == "neo4j": if graph_database_provider == "neo4j":

View file

@ -0,0 +1,4 @@
VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"]
GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"]
HYBRID_DBS = ["falkor"]

View file

@ -11,17 +11,11 @@ from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.modules.data.methods import get_unique_dataset_id from cognee.modules.data.methods import get_unique_dataset_id
from cognee.modules.users.models import DatasetDatabase from cognee.modules.users.models import DatasetDatabase
from cognee.modules.users.models import User from cognee.modules.users.models import User
from .constants import (
GRAPH_DBS_WITH_MULTI_USER_SUPPORT,
# TODO: Find a better place to define these VECTOR_DBS_WITH_MULTI_USER_SUPPORT,
default_vector_db_provider = "lancedb" HYBRID_DBS,
default_graph_db_provider = "kuzu" )
default_vector_db_url = None
default_graph_db_url = None
default_vector_db_key = None
default_graph_db_key = None
vector_dbs_with_multi_user_support = ["lancedb", "falkor"]
graph_dbs_with_multi_user_support = ["kuzu", "falkor"]
async def get_or_create_dataset_database( async def get_or_create_dataset_database(
@ -45,12 +39,19 @@ async def get_or_create_dataset_database(
dataset_id = await get_unique_dataset_id(dataset, user) dataset_id = await get_unique_dataset_id(dataset, user)
vector_db_name = f"{dataset_id}.db"
graph_db_name = f"{dataset_id}.pkl"
vector_config = get_vectordb_config() vector_config = get_vectordb_config()
graph_config = get_graph_config() graph_config = get_graph_config()
graph_db_name = f"{dataset_id}.pkl"
if graph_config.graph_database_provider in HYBRID_DBS:
vector_db_name = graph_db_name
else:
if vector_config.vector_database_provider == "lancedb":
vector_db_name = f"{dataset_id}.lance.db"
else:
vector_db_name = f"{dataset_id}.db"
async with db_engine.get_async_session() as session: async with db_engine.get_async_session() as session:
# Create dataset if it doesn't exist # Create dataset if it doesn't exist
if isinstance(dataset, str): if isinstance(dataset, str):
@ -66,23 +67,18 @@ async def get_or_create_dataset_database(
return existing return existing
# Check if we support multi-user for this provider. If not, use default # Check if we support multi-user for this provider. If not, use default
if graph_config.graph_database_provider in graph_dbs_with_multi_user_support: if graph_config.graph_database_provider not in GRAPH_DBS_WITH_MULTI_USER_SUPPORT:
graph_provider = graph_config.graph_database_provider raise EnvironmentError(
graph_url = graph_config.graph_database_url f"Multi-user is currently not supported for the graph database provider: {graph_config.graph_database_provider}. "
graph_key = graph_config.graph_database_key f"Supported providers are: {', '.join(GRAPH_DBS_WITH_MULTI_USER_SUPPORT)}. Either use one of these"
else: f"providers, or disable BACKEND_ACCESS_CONTROL"
graph_provider = default_graph_db_provider )
graph_url = default_graph_db_url if vector_config.vector_db_provider not in VECTOR_DBS_WITH_MULTI_USER_SUPPORT:
graph_key = default_graph_db_key raise EnvironmentError(
f"Multi-user is currently not supported for the vector database provider: {vector_config.vector_db_provider}. "
if vector_config.vector_db_provider in vector_dbs_with_multi_user_support: f"Supported providers are: {', '.join(VECTOR_DBS_WITH_MULTI_USER_SUPPORT)}. Either use one of these"
vector_provider = vector_config.vector_db_provider f"providers, or disable BACKEND_ACCESS_CONTROL"
vector_url = vector_config.vector_db_url )
vector_key = vector_config.vector_db_key
else:
vector_provider = default_vector_db_provider
vector_url = default_vector_db_url
vector_key = default_vector_db_key
# If there are no existing rows build a new row # If there are no existing rows build a new row
record = DatasetDatabase( record = DatasetDatabase(
@ -90,12 +86,12 @@ async def get_or_create_dataset_database(
dataset_id=dataset_id, dataset_id=dataset_id,
vector_database_name=vector_db_name, vector_database_name=vector_db_name,
graph_database_name=graph_db_name, graph_database_name=graph_db_name,
vector_database_provider=vector_provider, vector_database_provider=vector_config.vector_db_provider,
graph_database_provider=graph_provider, graph_database_provider=graph_config.graph_database_provider,
vector_database_url=vector_url, vector_database_url=vector_config.vector_db_url,
graph_database_url=graph_url, graph_database_url=graph_config.graph_database_url,
vector_database_key=vector_key, vector_database_key=vector_config.vector_db_key,
graph_database_key=graph_key, graph_database_key=graph_config.graph_database_key,
) )
try: try:

View file

@ -18,12 +18,14 @@ class VectorConfig(BaseSettings):
Instance variables: Instance variables:
- vector_db_url: The URL of the vector database. - vector_db_url: The URL of the vector database.
- vector_db_port: The port for the vector database. - vector_db_port: The port for the vector database.
- vector_db_name: The name of the vector database.
- vector_db_key: The key for accessing the vector database. - vector_db_key: The key for accessing the vector database.
- vector_db_provider: The provider for the vector database. - vector_db_provider: The provider for the vector database.
""" """
vector_db_url: str = "" vector_db_url: str = ""
vector_db_port: int = 1234 vector_db_port: int = 1234
vector_db_name: str = ""
vector_db_key: str = "" vector_db_key: str = ""
vector_db_provider: str = "lancedb" vector_db_provider: str = "lancedb"
@ -58,6 +60,7 @@ class VectorConfig(BaseSettings):
return { return {
"vector_db_url": self.vector_db_url, "vector_db_url": self.vector_db_url,
"vector_db_port": self.vector_db_port, "vector_db_port": self.vector_db_port,
"vector_db_name": self.vector_db_name,
"vector_db_key": self.vector_db_key, "vector_db_key": self.vector_db_key,
"vector_db_provider": self.vector_db_provider, "vector_db_provider": self.vector_db_provider,
} }

View file

@ -9,6 +9,7 @@ from functools import lru_cache
def create_vector_engine( def create_vector_engine(
vector_db_provider: str, vector_db_provider: str,
vector_db_url: str, vector_db_url: str,
vector_db_name: str,
vector_db_port: str = "", vector_db_port: str = "",
vector_db_key: str = "", vector_db_key: str = "",
): ):
@ -28,6 +29,7 @@ def create_vector_engine(
- vector_db_url (str): The URL for the vector database instance. - vector_db_url (str): The URL for the vector database instance.
- vector_db_port (str): The port for the vector database instance. Required for some - vector_db_port (str): The port for the vector database instance. Required for some
providers. providers.
- vector_db_name (str): The name of the vector database instance.
- vector_db_key (str): The API key or access token for the vector database instance. - vector_db_key (str): The API key or access token for the vector database instance.
- vector_db_provider (str): The name of the vector database provider to use (e.g., - vector_db_provider (str): The name of the vector database provider to use (e.g.,
'pgvector'). 'pgvector').
@ -46,7 +48,7 @@ def create_vector_engine(
url=vector_db_url, url=vector_db_url,
api_key=vector_db_key, api_key=vector_db_key,
embedding_engine=embedding_engine, embedding_engine=embedding_engine,
graph_name=get_graph_context_config()["graph_database_name"], database_name=vector_db_name,
) )
if vector_db_provider == "pgvector": if vector_db_provider == "pgvector":