fix: PR comment fixes
This commit is contained in:
parent
c3f0cb95da
commit
70f3ced15a
5 changed files with 44 additions and 39 deletions
|
|
@ -70,7 +70,7 @@ def create_graph_engine(
|
|||
graph_database_url=graph_database_url,
|
||||
graph_database_username=graph_database_username,
|
||||
graph_database_password=graph_database_password,
|
||||
graph_name=graph_database_name,
|
||||
database_name=graph_database_name,
|
||||
)
|
||||
|
||||
if graph_database_provider == "neo4j":
|
||||
|
|
|
|||
4
cognee/infrastructure/databases/utils/constants.py
Normal file
4
cognee/infrastructure/databases/utils/constants.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"]
|
||||
GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"]
|
||||
|
||||
HYBRID_DBS = ["falkor"]
|
||||
|
|
@ -11,17 +11,11 @@ from cognee.infrastructure.databases.graph.config import get_graph_config
|
|||
from cognee.modules.data.methods import get_unique_dataset_id
|
||||
from cognee.modules.users.models import DatasetDatabase
|
||||
from cognee.modules.users.models import User
|
||||
|
||||
|
||||
# TODO: Find a better place to define these
|
||||
default_vector_db_provider = "lancedb"
|
||||
default_graph_db_provider = "kuzu"
|
||||
default_vector_db_url = None
|
||||
default_graph_db_url = None
|
||||
default_vector_db_key = None
|
||||
default_graph_db_key = None
|
||||
vector_dbs_with_multi_user_support = ["lancedb", "falkor"]
|
||||
graph_dbs_with_multi_user_support = ["kuzu", "falkor"]
|
||||
from .constants import (
|
||||
GRAPH_DBS_WITH_MULTI_USER_SUPPORT,
|
||||
VECTOR_DBS_WITH_MULTI_USER_SUPPORT,
|
||||
HYBRID_DBS,
|
||||
)
|
||||
|
||||
|
||||
async def get_or_create_dataset_database(
|
||||
|
|
@ -45,12 +39,19 @@ async def get_or_create_dataset_database(
|
|||
|
||||
dataset_id = await get_unique_dataset_id(dataset, user)
|
||||
|
||||
vector_db_name = f"{dataset_id}.db"
|
||||
graph_db_name = f"{dataset_id}.pkl"
|
||||
|
||||
vector_config = get_vectordb_config()
|
||||
graph_config = get_graph_config()
|
||||
|
||||
graph_db_name = f"{dataset_id}.pkl"
|
||||
|
||||
if graph_config.graph_database_provider in HYBRID_DBS:
|
||||
vector_db_name = graph_db_name
|
||||
else:
|
||||
if vector_config.vector_database_provider == "lancedb":
|
||||
vector_db_name = f"{dataset_id}.lance.db"
|
||||
else:
|
||||
vector_db_name = f"{dataset_id}.db"
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
# Create dataset if it doesn't exist
|
||||
if isinstance(dataset, str):
|
||||
|
|
@ -66,23 +67,18 @@ async def get_or_create_dataset_database(
|
|||
return existing
|
||||
|
||||
# Check if we support multi-user for this provider. If not, use default
|
||||
if graph_config.graph_database_provider in graph_dbs_with_multi_user_support:
|
||||
graph_provider = graph_config.graph_database_provider
|
||||
graph_url = graph_config.graph_database_url
|
||||
graph_key = graph_config.graph_database_key
|
||||
else:
|
||||
graph_provider = default_graph_db_provider
|
||||
graph_url = default_graph_db_url
|
||||
graph_key = default_graph_db_key
|
||||
|
||||
if vector_config.vector_db_provider in vector_dbs_with_multi_user_support:
|
||||
vector_provider = vector_config.vector_db_provider
|
||||
vector_url = vector_config.vector_db_url
|
||||
vector_key = vector_config.vector_db_key
|
||||
else:
|
||||
vector_provider = default_vector_db_provider
|
||||
vector_url = default_vector_db_url
|
||||
vector_key = default_vector_db_key
|
||||
if graph_config.graph_database_provider not in GRAPH_DBS_WITH_MULTI_USER_SUPPORT:
|
||||
raise EnvironmentError(
|
||||
f"Multi-user is currently not supported for the graph database provider: {graph_config.graph_database_provider}. "
|
||||
f"Supported providers are: {', '.join(GRAPH_DBS_WITH_MULTI_USER_SUPPORT)}. Either use one of these"
|
||||
f"providers, or disable BACKEND_ACCESS_CONTROL"
|
||||
)
|
||||
if vector_config.vector_db_provider not in VECTOR_DBS_WITH_MULTI_USER_SUPPORT:
|
||||
raise EnvironmentError(
|
||||
f"Multi-user is currently not supported for the vector database provider: {vector_config.vector_db_provider}. "
|
||||
f"Supported providers are: {', '.join(VECTOR_DBS_WITH_MULTI_USER_SUPPORT)}. Either use one of these"
|
||||
f"providers, or disable BACKEND_ACCESS_CONTROL"
|
||||
)
|
||||
|
||||
# If there are no existing rows build a new row
|
||||
record = DatasetDatabase(
|
||||
|
|
@ -90,12 +86,12 @@ async def get_or_create_dataset_database(
|
|||
dataset_id=dataset_id,
|
||||
vector_database_name=vector_db_name,
|
||||
graph_database_name=graph_db_name,
|
||||
vector_database_provider=vector_provider,
|
||||
graph_database_provider=graph_provider,
|
||||
vector_database_url=vector_url,
|
||||
graph_database_url=graph_url,
|
||||
vector_database_key=vector_key,
|
||||
graph_database_key=graph_key,
|
||||
vector_database_provider=vector_config.vector_db_provider,
|
||||
graph_database_provider=graph_config.graph_database_provider,
|
||||
vector_database_url=vector_config.vector_db_url,
|
||||
graph_database_url=graph_config.graph_database_url,
|
||||
vector_database_key=vector_config.vector_db_key,
|
||||
graph_database_key=graph_config.graph_database_key,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -18,12 +18,14 @@ class VectorConfig(BaseSettings):
|
|||
Instance variables:
|
||||
- vector_db_url: The URL of the vector database.
|
||||
- vector_db_port: The port for the vector database.
|
||||
- vector_db_name: The name of the vector database.
|
||||
- vector_db_key: The key for accessing the vector database.
|
||||
- vector_db_provider: The provider for the vector database.
|
||||
"""
|
||||
|
||||
vector_db_url: str = ""
|
||||
vector_db_port: int = 1234
|
||||
vector_db_name: str = ""
|
||||
vector_db_key: str = ""
|
||||
vector_db_provider: str = "lancedb"
|
||||
|
||||
|
|
@ -58,6 +60,7 @@ class VectorConfig(BaseSettings):
|
|||
return {
|
||||
"vector_db_url": self.vector_db_url,
|
||||
"vector_db_port": self.vector_db_port,
|
||||
"vector_db_name": self.vector_db_name,
|
||||
"vector_db_key": self.vector_db_key,
|
||||
"vector_db_provider": self.vector_db_provider,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from functools import lru_cache
|
|||
def create_vector_engine(
|
||||
vector_db_provider: str,
|
||||
vector_db_url: str,
|
||||
vector_db_name: str,
|
||||
vector_db_port: str = "",
|
||||
vector_db_key: str = "",
|
||||
):
|
||||
|
|
@ -28,6 +29,7 @@ def create_vector_engine(
|
|||
- vector_db_url (str): The URL for the vector database instance.
|
||||
- vector_db_port (str): The port for the vector database instance. Required for some
|
||||
providers.
|
||||
- vector_db_name (str): The name of the vector database instance.
|
||||
- vector_db_key (str): The API key or access token for the vector database instance.
|
||||
- vector_db_provider (str): The name of the vector database provider to use (e.g.,
|
||||
'pgvector').
|
||||
|
|
@ -46,7 +48,7 @@ def create_vector_engine(
|
|||
url=vector_db_url,
|
||||
api_key=vector_db_key,
|
||||
embedding_engine=embedding_engine,
|
||||
graph_name=get_graph_context_config()["graph_database_name"],
|
||||
database_name=vector_db_name,
|
||||
)
|
||||
|
||||
if vector_db_provider == "pgvector":
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue