diff --git a/cognee/context_global_variables.py b/cognee/context_global_variables.py index 62e06fc64..44ead95af 100644 --- a/cognee/context_global_variables.py +++ b/cognee/context_global_variables.py @@ -17,7 +17,7 @@ graph_db_config = ContextVar("graph_db_config", default=None) session_user = ContextVar("session_user", default=None) VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"] -GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"] +GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor", "neo4j"] async def set_session_user_context_variable(user): @@ -101,6 +101,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_ "graph_file_path": os.path.join( databases_directory_path, dataset_database.graph_database_name ), + "graph_database_username": dataset_database.graph_database_username, + "graph_database_password": dataset_database.graph_database_password, } storage_config = { diff --git a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py index 3684bb100..0a2638dc5 100644 --- a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +++ b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py @@ -39,30 +39,6 @@ async def get_or_create_dataset_database( vector_config = get_vectordb_config() graph_config = get_graph_config() - # Note: for hybrid databases both graph and vector DB name have to be the same - if graph_config.graph_database_provider == "kuzu": - graph_db_name = f"{dataset_id}.pkl" - else: - graph_db_name = f"{dataset_id}" - - if vector_config.vector_db_provider == "lancedb": - vector_db_name = f"{dataset_id}.lance.db" - else: - vector_db_name = f"{dataset_id}" - - base_config = get_base_config() - databases_directory_path = os.path.join( - base_config.system_root_directory, "databases", str(user.id) - ) - - # Determine vector database URL - if vector_config.vector_db_provider == "lancedb": - vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name) - else: - vector_db_url = vector_config.vector_database_url - - # Determine graph database URL - async with db_engine.get_async_session() as session: # Create dataset if it doesn't exist if isinstance(dataset, str): @@ -77,7 +53,96 @@ async def get_or_create_dataset_database( if existing: return existing + # Note: for hybrid databases both graph and vector DB name have to be the same + if graph_config.graph_database_provider == "kuzu": + graph_db_name = f"{dataset_id}.pkl" + else: + graph_db_name = f"{dataset_id}" + + if vector_config.vector_db_provider == "lancedb": + vector_db_name = f"{dataset_id}.lance.db" + else: + vector_db_name = f"{dataset_id}" + + base_config = get_base_config() + databases_directory_path = os.path.join( + base_config.system_root_directory, "databases", str(user.id) + ) + + # Determine vector database URL + if vector_config.vector_db_provider == "lancedb": + vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name) + else: + vector_db_url = vector_config.vector_database_url + + # Determine graph database URL + if graph_config.graph_database_provider == "neo4j": + # Auto deploy instance to Aura DB + # OAuth2 token endpoint + + # Your client credentials + client_id = os.environ.get("NEO4J_CLIENT_ID", None) + client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None) + tenant_id = os.environ.get("NEO4J_TENANT_ID", None) + + # Make the request with HTTP Basic Auth + import requests + + def get_aura_token(client_id: str, client_secret: str) -> dict: + url = "https://api.neo4j.io/oauth/token" + data = { + "grant_type": "client_credentials" + } # sent as application/x-www-form-urlencoded + + resp = requests.post(url, data=data, auth=(client_id, client_secret)) + resp.raise_for_status() # raises if the request failed + return resp.json() + + resp = get_aura_token(client_id, client_secret) + + url = "https://api.neo4j.io/v1/instances" + + headers = { + "accept": "application/json", + "Authorization": f"Bearer {resp['access_token']}", + "Content-Type": "application/json", + } + + payload = { + "version": "5", + "region": "europe-west1", + "memory": "1GB", + "name": graph_db_name[0:29], + "type": "professional-db", + "tenant_id": tenant_id, + "cloud_provider": "gcp", + } + + response = requests.post(url, headers=headers, json=payload) + + # Wait for instance to be provisioned + # TODO: Find better way to check when instance is ready + import asyncio + + await asyncio.sleep(180) + + print(response.status_code) + print(response.text) + # TODO: Find better name to name Neo4j instance within 30 character limit + print(graph_db_name[0:29]) + graph_db_name = "neo4j" + graph_db_url = response.json()["data"]["connection_url"] + graph_db_key = resp["access_token"] + graph_db_username = response.json()["data"]["username"] + graph_db_password = response.json()["data"]["password"] + else: + graph_db_url = graph_config.graph_database_url + graph_db_key = graph_config.graph_database_key + graph_db_username = graph_config.graph_database_username + graph_db_password = graph_config.graph_database_password + # If there are no existing rows build a new row + # TODO: Update Dataset Database migrations, also make sure database_name is not unique anymore record = DatasetDatabase( owner_id=user.id, dataset_id=dataset_id, @@ -86,9 +151,11 @@ async def get_or_create_dataset_database( vector_database_provider=vector_config.vector_db_provider, graph_database_provider=graph_config.graph_database_provider, vector_database_url=vector_db_url, - graph_database_url=graph_config.graph_database_url, + graph_database_url=graph_db_url, vector_database_key=vector_config.vector_db_key, - graph_database_key=graph_config.graph_database_key, + graph_database_key=graph_db_key, + graph_database_username=graph_db_username, + graph_database_password=graph_db_password, ) try: diff --git a/cognee/modules/users/models/DatasetDatabase.py b/cognee/modules/users/models/DatasetDatabase.py index 25d610ab9..5d2e4fcd5 100644 --- a/cognee/modules/users/models/DatasetDatabase.py +++ b/cognee/modules/users/models/DatasetDatabase.py @@ -13,7 +13,7 @@ class DatasetDatabase(Base): ) vector_database_name = Column(String, unique=True, nullable=False) - graph_database_name = Column(String, unique=True, nullable=False) + graph_database_name = Column(String, unique=False, nullable=False) vector_database_provider = Column(String, unique=False, nullable=False) graph_database_provider = Column(String, unique=False, nullable=False) @@ -24,5 +24,8 @@ class DatasetDatabase(Base): vector_database_key = Column(String, unique=False, nullable=True) graph_database_key = Column(String, unique=False, nullable=True) + graph_database_username = Column(String, unique=False, nullable=True) + graph_database_password = Column(String, unique=False, nullable=True) + created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))