feat: Add initial multi tenant neo4j support

This commit is contained in:
Igor Ilic 2025-11-11 19:44:34 +01:00
parent a8706a2fd2
commit 432d4a1578
3 changed files with 100 additions and 28 deletions

View file

@ -17,7 +17,7 @@ graph_db_config = ContextVar("graph_db_config", default=None)
session_user = ContextVar("session_user", default=None) session_user = ContextVar("session_user", default=None)
VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"] VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"]
GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"] GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor", "neo4j"]
async def set_session_user_context_variable(user): async def set_session_user_context_variable(user):
@ -101,6 +101,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
"graph_file_path": os.path.join( "graph_file_path": os.path.join(
databases_directory_path, dataset_database.graph_database_name databases_directory_path, dataset_database.graph_database_name
), ),
"graph_database_username": dataset_database.graph_database_username,
"graph_database_password": dataset_database.graph_database_password,
} }
storage_config = { storage_config = {

View file

@ -39,6 +39,20 @@ async def get_or_create_dataset_database(
vector_config = get_vectordb_config() vector_config = get_vectordb_config()
graph_config = get_graph_config() graph_config = get_graph_config()
async with db_engine.get_async_session() as session:
# Create dataset if it doesn't exist
if isinstance(dataset, str):
dataset = await create_dataset(dataset, user, session)
# Try to fetch an existing row first
stmt = select(DatasetDatabase).where(
DatasetDatabase.owner_id == user.id,
DatasetDatabase.dataset_id == dataset_id,
)
existing: DatasetDatabase = await session.scalar(stmt)
if existing:
return existing
# Note: for hybrid databases both graph and vector DB name have to be the same # Note: for hybrid databases both graph and vector DB name have to be the same
if graph_config.graph_database_provider == "kuzu": if graph_config.graph_database_provider == "kuzu":
graph_db_name = f"{dataset_id}.pkl" graph_db_name = f"{dataset_id}.pkl"
@ -62,22 +76,73 @@ async def get_or_create_dataset_database(
vector_db_url = vector_config.vector_database_url vector_db_url = vector_config.vector_database_url
# Determine graph database URL # Determine graph database URL
if graph_config.graph_database_provider == "neo4j":
# Auto deploy instance to Aura DB
# OAuth2 token endpoint
async with db_engine.get_async_session() as session: # Your client credentials
# Create dataset if it doesn't exist client_id = os.environ.get("NEO4J_CLIENT_ID", None)
if isinstance(dataset, str): client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
dataset = await create_dataset(dataset, user, session) tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
# Try to fetch an existing row first # Make the request with HTTP Basic Auth
stmt = select(DatasetDatabase).where( import requests
DatasetDatabase.owner_id == user.id,
DatasetDatabase.dataset_id == dataset_id, def get_aura_token(client_id: str, client_secret: str) -> dict:
) url = "https://api.neo4j.io/oauth/token"
existing: DatasetDatabase = await session.scalar(stmt) data = {
if existing: "grant_type": "client_credentials"
return existing } # sent as application/x-www-form-urlencoded
resp = requests.post(url, data=data, auth=(client_id, client_secret))
resp.raise_for_status() # raises if the request failed
return resp.json()
resp = get_aura_token(client_id, client_secret)
url = "https://api.neo4j.io/v1/instances"
headers = {
"accept": "application/json",
"Authorization": f"Bearer {resp['access_token']}",
"Content-Type": "application/json",
}
payload = {
"version": "5",
"region": "europe-west1",
"memory": "1GB",
"name": graph_db_name[0:29],
"type": "professional-db",
"tenant_id": tenant_id,
"cloud_provider": "gcp",
}
response = requests.post(url, headers=headers, json=payload)
# Wait for instance to be provisioned
# TODO: Find better way to check when instance is ready
import asyncio
await asyncio.sleep(180)
print(response.status_code)
print(response.text)
# TODO: Find better name to name Neo4j instance within 30 character limit
print(graph_db_name[0:29])
graph_db_name = "neo4j"
graph_db_url = response.json()["data"]["connection_url"]
graph_db_key = resp["access_token"]
graph_db_username = response.json()["data"]["username"]
graph_db_password = response.json()["data"]["password"]
else:
graph_db_url = graph_config.graph_database_url
graph_db_key = graph_config.graph_database_key
graph_db_username = graph_config.graph_database_username
graph_db_password = graph_config.graph_database_password
# If there are no existing rows build a new row # If there are no existing rows build a new row
# TODO: Update Dataset Database migrations, also make sure database_name is not unique anymore
record = DatasetDatabase( record = DatasetDatabase(
owner_id=user.id, owner_id=user.id,
dataset_id=dataset_id, dataset_id=dataset_id,
@ -86,9 +151,11 @@ async def get_or_create_dataset_database(
vector_database_provider=vector_config.vector_db_provider, vector_database_provider=vector_config.vector_db_provider,
graph_database_provider=graph_config.graph_database_provider, graph_database_provider=graph_config.graph_database_provider,
vector_database_url=vector_db_url, vector_database_url=vector_db_url,
graph_database_url=graph_config.graph_database_url, graph_database_url=graph_db_url,
vector_database_key=vector_config.vector_db_key, vector_database_key=vector_config.vector_db_key,
graph_database_key=graph_config.graph_database_key, graph_database_key=graph_db_key,
graph_database_username=graph_db_username,
graph_database_password=graph_db_password,
) )
try: try:

View file

@ -13,7 +13,7 @@ class DatasetDatabase(Base):
) )
vector_database_name = Column(String, unique=True, nullable=False) vector_database_name = Column(String, unique=True, nullable=False)
graph_database_name = Column(String, unique=True, nullable=False) graph_database_name = Column(String, unique=False, nullable=False)
vector_database_provider = Column(String, unique=False, nullable=False) vector_database_provider = Column(String, unique=False, nullable=False)
graph_database_provider = Column(String, unique=False, nullable=False) graph_database_provider = Column(String, unique=False, nullable=False)
@ -24,5 +24,8 @@ class DatasetDatabase(Base):
vector_database_key = Column(String, unique=False, nullable=True) vector_database_key = Column(String, unique=False, nullable=True)
graph_database_key = Column(String, unique=False, nullable=True) graph_database_key = Column(String, unique=False, nullable=True)
graph_database_username = Column(String, unique=False, nullable=True)
graph_database_password = Column(String, unique=False, nullable=True)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc)) updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))