refactor: Make neo4j auto scaling more readable

This commit is contained in:
Igor Ilic 2025-11-12 17:58:27 +01:00
parent a0a14e7ccc
commit b017fcc8d0

View file

@ -1,6 +1,8 @@
import os import os
import asyncio
import requests
from uuid import UUID from uuid import UUID
from typing import Union from typing import Union, Optional
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
@ -15,6 +17,157 @@ from cognee.modules.users.models import DatasetDatabase
from cognee.modules.users.models import User from cognee.modules.users.models import User
async def _get_vector_db_info(dataset_id: UUID, user: User) -> dict:
vector_config = get_vectordb_config()
base_config = get_base_config()
databases_directory_path = os.path.join(
base_config.system_root_directory, "databases", str(user.id)
)
# Determine vector configuration
if vector_config.vector_db_provider == "lancedb":
vector_db_name = f"{dataset_id}.lance.db"
vector_db_url = os.path.join(databases_directory_path, vector_db_name)
else:
# Note: for hybrid databases both graph and vector DB name have to be the same
vector_db_name = vector_config.vector_db_name
vector_db_url = vector_config.vector_database_url
return {
"vector_database_name": vector_db_name,
"vector_database_url": vector_db_url,
"vector_database_provider": vector_config.vector_db_provider,
"vector_database_key": vector_config.vector_db_key,
}
async def _get_graph_db_info(dataset_id: UUID, user: User) -> dict:
graph_config = get_graph_config()
# Determine graph database URL
if graph_config.graph_database_provider == "neo4j":
graph_db_name = f"{dataset_id}"
# Auto deploy instance to Aura DB
# OAuth2 token endpoint
# Your client credentials
client_id = os.environ.get("NEO4J_CLIENT_ID", None)
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
# Make the request with HTTP Basic Auth
def get_aura_token(client_id: str, client_secret: str) -> dict:
url = "https://api.neo4j.io/oauth/token"
data = {"grant_type": "client_credentials"} # sent as application/x-www-form-urlencoded
resp = requests.post(url, data=data, auth=(client_id, client_secret))
resp.raise_for_status() # raises if the request failed
return resp.json()
resp = get_aura_token(client_id, client_secret)
url = "https://api.neo4j.io/v1/instances"
headers = {
"accept": "application/json",
"Authorization": f"Bearer {resp['access_token']}",
"Content-Type": "application/json",
}
payload = {
"version": "5",
"region": "europe-west1",
"memory": "1GB",
"name": graph_db_name[0:29],
"type": "professional-db",
"tenant_id": tenant_id,
"cloud_provider": "gcp",
}
response = requests.post(url, headers=headers, json=payload)
print(response.status_code)
print(response.text)
# TODO: Find better name to name Neo4j instance within 30 character limit
print(graph_db_name[0:29])
graph_db_name = "neo4j"
graph_db_url = response.json()["data"]["connection_url"]
graph_db_key = resp["access_token"]
graph_db_username = response.json()["data"]["username"]
graph_db_password = response.json()["data"]["password"]
async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict):
# Poll until the instance is running
status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
status = ""
for attempt in range(30): # Try for up to ~5 minutes
status_resp = requests.get(status_url, headers=headers)
status = status_resp.json()["data"]["status"]
if status.lower() == "running":
return
await asyncio.sleep(10)
raise TimeoutError(
f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}"
)
instance_id = response.json()["data"]["id"]
await _wait_for_neo4j_instance_provisioning(instance_id, headers)
elif graph_config.graph_database_provider == "kuzu":
# TODO: Add graph file path info for kuzu (also in DatasetDatabase model)
graph_db_name = f"{dataset_id}.pkl"
graph_db_url = graph_config.graph_database_url
graph_db_key = graph_config.graph_database_key
graph_db_username = graph_config.graph_database_username
graph_db_password = graph_config.graph_database_password
elif graph_config.graph_database_provider == "falkor":
# Note: for hybrid databases both graph and vector DB name have to be the same
graph_db_name = f"{dataset_id}"
graph_db_url = graph_config.graph_database_url
graph_db_key = graph_config.graph_database_key
graph_db_username = graph_config.graph_database_username
graph_db_password = graph_config.graph_database_password
else:
raise EnvironmentError(
f"Unsupported graph database provider for backend access control: {graph_config.graph_database_provider}"
)
return {
"graph_database_name": graph_db_name,
"graph_database_url": graph_db_url,
"graph_database_provider": graph_config.graph_database_provider,
"graph_database_key": graph_db_key,
"graph_database_username": graph_db_username,
"graph_database_password": graph_db_password,
}
async def _existing_dataset_database(
dataset_id: UUID,
user: User,
) -> Optional[DatasetDatabase]:
"""
Check if a DatasetDatabase row already exists for the given owner + dataset.
Return None if it doesn't exist, return the row if it does.
Args:
dataset_id:
user:
Returns:
DatasetDatabase or None
"""
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
stmt = select(DatasetDatabase).where(
DatasetDatabase.owner_id == user.id,
DatasetDatabase.dataset_id == dataset_id,
)
existing: DatasetDatabase = await session.scalar(stmt)
return existing
async def get_or_create_dataset_database( async def get_or_create_dataset_database(
dataset: Union[str, UUID], dataset: Union[str, UUID],
user: User, user: User,
@ -36,150 +189,27 @@ async def get_or_create_dataset_database(
dataset_id = await get_unique_dataset_id(dataset, user) dataset_id = await get_unique_dataset_id(dataset, user)
vector_config = get_vectordb_config() # If dataset is given as name make sure the dataset is created first
graph_config = get_graph_config() if isinstance(dataset, str):
async with db_engine.get_async_session() as session:
await create_dataset(dataset, user, session)
# Note: for hybrid databases both graph and vector DB name have to be the same # If dataset database already exists return it
if graph_config.graph_database_provider == "kuzu": existing_dataset_database = await _existing_dataset_database(dataset_id, user)
graph_db_name = f"{dataset_id}.pkl" if existing_dataset_database:
else: return existing_dataset_database
graph_db_name = f"{dataset_id}"
if vector_config.vector_db_provider == "lancedb": graph_config_dict = await _get_graph_db_info(dataset_id, user)
vector_db_name = f"{dataset_id}.lance.db" vector_config_dict = await _get_vector_db_info(dataset_id, user)
else:
vector_db_name = f"{dataset_id}"
base_config = get_base_config()
databases_directory_path = os.path.join(
base_config.system_root_directory, "databases", str(user.id)
)
# Determine vector database URL
if vector_config.vector_db_provider == "lancedb":
vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name)
else:
vector_db_url = vector_config.vector_database_url
# Determine graph database URL
async with db_engine.get_async_session() as session: async with db_engine.get_async_session() as session:
# Create dataset if it doesn't exist
if isinstance(dataset, str):
dataset = await create_dataset(dataset, user, session)
# Try to fetch an existing row first
stmt = select(DatasetDatabase).where(
DatasetDatabase.owner_id == user.id,
DatasetDatabase.dataset_id == dataset_id,
)
existing: DatasetDatabase = await session.scalar(stmt)
if existing:
return existing
# Note: for hybrid databases both graph and vector DB name have to be the same
if graph_config.graph_database_provider == "kuzu":
graph_db_name = f"{dataset_id}.pkl"
else:
graph_db_name = f"{dataset_id}"
if vector_config.vector_db_provider == "lancedb":
vector_db_name = f"{dataset_id}.lance.db"
else:
vector_db_name = f"{dataset_id}"
base_config = get_base_config()
databases_directory_path = os.path.join(
base_config.system_root_directory, "databases", str(user.id)
)
# Determine vector database URL
if vector_config.vector_db_provider == "lancedb":
vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name)
else:
vector_db_url = vector_config.vector_database_url
# Determine graph database URL
if graph_config.graph_database_provider == "neo4j":
# Auto deploy instance to Aura DB
# OAuth2 token endpoint
# Your client credentials
client_id = os.environ.get("NEO4J_CLIENT_ID", None)
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
# Make the request with HTTP Basic Auth
import requests
def get_aura_token(client_id: str, client_secret: str) -> dict:
url = "https://api.neo4j.io/oauth/token"
data = {
"grant_type": "client_credentials"
} # sent as application/x-www-form-urlencoded
resp = requests.post(url, data=data, auth=(client_id, client_secret))
resp.raise_for_status() # raises if the request failed
return resp.json()
resp = get_aura_token(client_id, client_secret)
url = "https://api.neo4j.io/v1/instances"
headers = {
"accept": "application/json",
"Authorization": f"Bearer {resp['access_token']}",
"Content-Type": "application/json",
}
payload = {
"version": "5",
"region": "europe-west1",
"memory": "1GB",
"name": graph_db_name[0:29],
"type": "professional-db",
"tenant_id": tenant_id,
"cloud_provider": "gcp",
}
response = requests.post(url, headers=headers, json=payload)
# Wait for instance to be provisioned
# TODO: Find better way to check when instance is ready
import asyncio
await asyncio.sleep(180)
print(response.status_code)
print(response.text)
# TODO: Find better name to name Neo4j instance within 30 character limit
print(graph_db_name[0:29])
graph_db_name = "neo4j"
graph_db_url = response.json()["data"]["connection_url"]
graph_db_key = resp["access_token"]
graph_db_username = response.json()["data"]["username"]
graph_db_password = response.json()["data"]["password"]
else:
graph_db_url = graph_config.graph_database_url
graph_db_key = graph_config.graph_database_key
graph_db_username = graph_config.graph_database_username
graph_db_password = graph_config.graph_database_password
# If there are no existing rows build a new row # If there are no existing rows build a new row
# TODO: Update Dataset Database migrations, also make sure database_name is not unique anymore # TODO: Update Dataset Database migrations, also make sure database_name is not unique anymore
record = DatasetDatabase( record = DatasetDatabase(
owner_id=user.id, owner_id=user.id,
dataset_id=dataset_id, dataset_id=dataset_id,
vector_database_name=vector_db_name, **graph_config_dict, # Unpack graph db config
graph_database_name=graph_db_name, **vector_config_dict, # Unpack vector db config
vector_database_provider=vector_config.vector_db_provider,
graph_database_provider=graph_config.graph_database_provider,
vector_database_url=vector_db_url,
graph_database_url=graph_db_url,
vector_database_key=vector_config.vector_db_key,
graph_database_key=graph_db_key,
graph_database_username=graph_db_username,
graph_database_password=graph_db_password,
) )
try: try: