Merge branch 'dev' into feature/cog-3645-enable-multi-user-support-for-pgvector

This commit is contained in:
Igor Ilic 2026-01-20 15:51:57 +01:00
commit 05084e6779
12 changed files with 157 additions and 47 deletions

View file

@ -427,10 +427,12 @@ git checkout -b feature/your-feature-name
## Code Style
- Ruff for linting and formatting (configured in `pyproject.toml`)
- Line length: 100 characters
- Pre-commit hooks run ruff automatically
- Type hints encouraged (mypy checks enabled)
- **Formatter**: Ruff (configured in `pyproject.toml`)
- **Line length**: 100 characters
- **String quotes**: Use double quotes `"` not single quotes `'` (enforced by ruff-format)
- **Pre-commit hooks**: Run ruff linting and formatting automatically
- **Type hints**: Encouraged (mypy checks enabled)
- **Important**: Always run `pre-commit run --all-files` before committing to catch formatting issues
## Testing Strategy

View file

@ -252,7 +252,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
chunk_size: int = None,
config: Config = None,
custom_prompt: Optional[str] = None,
chunks_per_batch: int = 100,
chunks_per_batch: int = None,
**kwargs,
) -> list[Task]:
if config is None:
@ -272,12 +272,14 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
"ontology_config": {"ontology_resolver": get_default_ontology_resolver()}
}
if chunks_per_batch is None:
chunks_per_batch = 100
cognify_config = get_cognify_config()
embed_triplets = cognify_config.triplet_embedding
if chunks_per_batch is None:
chunks_per_batch = (
cognify_config.chunks_per_batch if cognify_config.chunks_per_batch is not None else 100
)
default_tasks = [
Task(classify_documents),
Task(
@ -308,7 +310,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
async def get_temporal_tasks(
user: User = None, chunker=TextChunker, chunk_size: int = None, chunks_per_batch: int = 10
user: User = None, chunker=TextChunker, chunk_size: int = None, chunks_per_batch: int = None
) -> list[Task]:
"""
Builds and returns a list of temporal processing tasks to be executed in sequence.
@ -330,7 +332,10 @@ async def get_temporal_tasks(
list[Task]: A list of Task objects representing the temporal processing pipeline.
"""
if chunks_per_batch is None:
chunks_per_batch = 10
from cognee.modules.cognify.config import get_cognify_config
configured = get_cognify_config().chunks_per_batch
chunks_per_batch = configured if configured is not None else 10
temporal_tasks = [
Task(classify_documents),

View file

@ -46,6 +46,11 @@ class CognifyPayloadDTO(InDTO):
examples=[[]],
description="Reference to one or more previously uploaded ontologies",
)
chunks_per_batch: Optional[int] = Field(
default=None,
description="Number of chunks to process per task batch in Cognify (overrides default).",
examples=[10, 20, 50, 100],
)
def get_cognify_router() -> APIRouter:
@ -146,6 +151,7 @@ def get_cognify_router() -> APIRouter:
config=config_to_use,
run_in_background=payload.run_in_background,
custom_prompt=payload.custom_prompt,
chunks_per_batch=payload.chunks_per_batch,
)
# If any cognify run errored return JSONResponse with proper error status code

View file

@ -62,6 +62,11 @@ After successful cognify processing, use `cognee search` to query the knowledge
parser.add_argument(
"--verbose", "-v", action="store_true", help="Show detailed progress information"
)
parser.add_argument(
"--chunks-per-batch",
type=int,
help="Number of chunks to process per task batch (try 50 for large single documents).",
)
def execute(self, args: argparse.Namespace) -> None:
try:
@ -111,6 +116,7 @@ After successful cognify processing, use `cognee search` to query the knowledge
chunk_size=args.chunk_size,
ontology_file_path=args.ontology_file,
run_in_background=args.background,
chunks_per_batch=getattr(args, "chunks_per_batch", None),
)
return result
except Exception as e:

View file

@ -24,7 +24,6 @@ async def get_graph_engine() -> GraphDBInterface:
return graph_client
@lru_cache
def create_graph_engine(
graph_database_provider,
graph_file_path,
@ -35,6 +34,35 @@ def create_graph_engine(
graph_database_port="",
graph_database_key="",
graph_dataset_database_handler="",
):
"""
Wrapper function to call create graph engine with caching.
For a detailed description, see _create_graph_engine.
"""
return _create_graph_engine(
graph_database_provider,
graph_file_path,
graph_database_url,
graph_database_name,
graph_database_username,
graph_database_password,
graph_database_port,
graph_database_key,
graph_dataset_database_handler,
)
@lru_cache
def _create_graph_engine(
graph_database_provider,
graph_file_path,
graph_database_url="",
graph_database_name="",
graph_database_username="",
graph_database_password="",
graph_database_port="",
graph_database_key="",
graph_dataset_database_handler="",
):
"""
Create a graph engine based on the specified provider type.

View file

@ -1,11 +1,13 @@
import os
import aiohttp
import asyncio
import requests
import base64
import hashlib
from uuid import UUID
from typing import Optional
from urllib.parse import urlparse
from cryptography.fernet import Fernet
from aiohttp import BasicAuth
from cognee.infrastructure.databases.graph import get_graph_config
from cognee.modules.users.models import User, DatasetDatabase
@ -23,7 +25,6 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
Quality of life improvements:
- Allow configuration of different Neo4j Aura plans and regions.
- Requests should be made async, currently a blocking requests library is used.
"""
@classmethod
@ -49,6 +50,7 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
graph_db_name = f"{dataset_id}"
# Client credentials and encryption
# Note: Should not be used as class variables so that they are not persisted in memory longer than needed
client_id = os.environ.get("NEO4J_CLIENT_ID", None)
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
@ -63,22 +65,13 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
"NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling."
)
# Make the request with HTTP Basic Auth
def get_aura_token(client_id: str, client_secret: str) -> dict:
url = "https://api.neo4j.io/oauth/token"
data = {"grant_type": "client_credentials"} # sent as application/x-www-form-urlencoded
resp = requests.post(url, data=data, auth=(client_id, client_secret))
resp.raise_for_status() # raises if the request failed
return resp.json()
resp = get_aura_token(client_id, client_secret)
resp_token = await cls._get_aura_token(client_id, client_secret)
url = "https://api.neo4j.io/v1/instances"
headers = {
"accept": "application/json",
"Authorization": f"Bearer {resp['access_token']}",
"Authorization": f"Bearer {resp_token['access_token']}",
"Content-Type": "application/json",
}
@ -96,31 +89,38 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
"cloud_provider": "gcp",
}
response = requests.post(url, headers=headers, json=payload)
async def _create_database_instance_request():
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=payload) as resp:
resp.raise_for_status()
return await resp.json()
resp_create = await _create_database_instance_request()
graph_db_name = "neo4j" # Has to be 'neo4j' for Aura
graph_db_url = response.json()["data"]["connection_url"]
graph_db_key = resp["access_token"]
graph_db_username = response.json()["data"]["username"]
graph_db_password = response.json()["data"]["password"]
graph_db_url = resp_create["data"]["connection_url"]
graph_db_key = resp_token["access_token"]
graph_db_username = resp_create["data"]["username"]
graph_db_password = resp_create["data"]["password"]
async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict):
# Poll until the instance is running
status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
status = ""
for attempt in range(30): # Try for up to ~5 minutes
status_resp = requests.get(
status_url, headers=headers
) # TODO: Use async requests with httpx
status = status_resp.json()["data"]["status"]
if status.lower() == "running":
return
await asyncio.sleep(10)
async with aiohttp.ClientSession() as session:
async with session.get(status_url, headers=headers) as resp:
resp.raise_for_status()
status_resp = await resp.json()
status = status_resp["data"]["status"]
if status.lower() == "running":
return
await asyncio.sleep(10)
raise TimeoutError(
f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}"
)
instance_id = response.json()["data"]["id"]
instance_id = resp_create["data"]["id"]
await _wait_for_neo4j_instance_provisioning(instance_id, headers)
encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode())
@ -165,4 +165,39 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
@classmethod
async def delete_dataset(cls, dataset_database: DatasetDatabase):
    """
    Delete the Neo4j Aura instance that backs the given dataset database.

    The connection info is resolved first, then the instance id is read from the first
    label of the host name in `graph_database_url` and a DELETE request is sent to the
    Neo4j Aura instances API using a freshly requested OAuth token.

    NOTE(review): this assumes Aura connection URLs look like
    `<scheme>://<instance_id>.<rest-of-host>` - confirm against the URLs stored at creation.

    Parameters:
        dataset_database (DatasetDatabase): Record holding the graph database URL of the
            dataset whose Aura instance should be removed.

    Returns:
        The decoded JSON body of the Aura API response to the delete request.

    Raises:
        ValueError: If the stored URL has no host name, or if NEO4J_CLIENT_ID or
            NEO4J_CLIENT_SECRET is not set in the environment.
        aiohttp.ClientResponseError: If the token or delete request returns an error status.
    """
    # Get dataset database information and credentials
    dataset_database = await cls.resolve_dataset_connection_info(dataset_database)

    hostname = urlparse(dataset_database.graph_database_url).hostname
    if not hostname:
        # Fail with a readable message instead of an AttributeError on `None.split`.
        raise ValueError(
            "Cannot determine the Neo4j Aura instance id: graph database URL has no host name."
        )
    instance_id = hostname.split(".")[0]
    url = f"https://api.neo4j.io/v1/instances/{instance_id}"

    # Client credentials are read per call rather than cached on the class.
    client_id = os.environ.get("NEO4J_CLIENT_ID", None)
    client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
    if client_id is None or client_secret is None:
        # aiohttp's BasicAuth would raise a generic ValueError here; raise the same
        # exception type early with a message that names the missing configuration.
        raise ValueError(
            "NEO4J_CLIENT_ID and NEO4J_CLIENT_SECRET environment variables must be set "
            "to delete a Neo4j Aura dataset database."
        )

    # Get access token for Neo4j Aura API
    token = await cls._get_aura_token(client_id, client_secret)

    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {token['access_token']}",
        "Content-Type": "application/json",
    }

    async with aiohttp.ClientSession() as session:
        # Distinct name so the token payload above is not shadowed by the HTTP response.
        async with session.delete(url, headers=headers) as delete_resp:
            delete_resp.raise_for_status()
            return await delete_resp.json()
@classmethod
async def _get_aura_token(cls, client_id: str, client_secret: str) -> dict:
    """
    Request an OAuth token from the Neo4j Aura API using the client-credentials grant.

    The client id and secret are sent as HTTP Basic Auth and the grant type as form data.

    Parameters:
        client_id (str): Neo4j Aura API client id.
        client_secret (str): Neo4j Aura API client secret.

    Returns:
        dict: The decoded JSON body of the token response; callers in this class read
            its "access_token" entry.
    """
    token_url = "https://api.neo4j.io/oauth/token"
    # A dict passed as `data` is sent as application/x-www-form-urlencoded
    form_fields = {"grant_type": "client_credentials"}

    async with aiohttp.ClientSession() as session:
        credentials = BasicAuth(client_id, client_secret)
        async with session.post(token_url, data=form_fields, auth=credentials) as token_resp:
            # Fail on any error status before trying to decode the body
            token_resp.raise_for_status()
            payload = await token_resp.json()
            return payload

View file

@ -7,7 +7,6 @@ from cognee.infrastructure.databases.graph.config import get_graph_context_confi
from functools import lru_cache
@lru_cache
def create_vector_engine(
vector_db_provider: str,
vector_db_url: str,
@ -17,6 +16,29 @@ def create_vector_engine(
vector_dataset_database_handler: str = "",
vector_db_username: str = "",
vector_db_password: str = "",
):
"""
Wrapper function to call create vector engine with caching.
For a detailed description, see _create_vector_engine.
"""
return _create_vector_engine(
vector_db_provider,
vector_db_url,
vector_db_name,
vector_db_port,
vector_db_key,
vector_dataset_database_handler,
)
@lru_cache
def _create_vector_engine(
vector_db_provider: str,
vector_db_url: str,
vector_db_name: str,
vector_db_port: str = "",
vector_db_key: str = "",
vector_dataset_database_handler: str = "",
):
"""
Create a vector database engine based on the specified provider.

View file

@ -9,6 +9,7 @@ class CognifyConfig(BaseSettings):
classification_model: object = DefaultContentPrediction
summarization_model: object = SummarizedContent
triplet_embedding: bool = False
chunks_per_batch: Optional[int] = None
model_config = SettingsConfigDict(env_file=".env", extra="allow")
def to_dict(self) -> dict:
@ -16,6 +17,7 @@ class CognifyConfig(BaseSettings):
"classification_model": self.classification_model,
"summarization_model": self.summarization_model,
"triplet_embedding": self.triplet_embedding,
"chunks_per_batch": self.chunks_per_batch,
}

View file

@ -238,6 +238,7 @@ class TestCognifyCommand:
ontology_file_path=None,
chunker=TextChunker,
run_in_background=False,
chunks_per_batch=None,
)
@patch("cognee.cli.commands.cognify_command.asyncio.run")

View file

@ -262,6 +262,7 @@ class TestCognifyCommandEdgeCases:
ontology_file_path=None,
chunker=TextChunker,
run_in_background=False,
chunks_per_batch=None,
)
@patch("cognee.cli.commands.cognify_command.asyncio.run", side_effect=_mock_run)
@ -295,6 +296,7 @@ class TestCognifyCommandEdgeCases:
ontology_file_path="/nonexistent/path/ontology.owl",
chunker=TextChunker,
run_in_background=False,
chunks_per_batch=None,
)
@patch("cognee.cli.commands.cognify_command.asyncio.run")
@ -373,6 +375,7 @@ class TestCognifyCommandEdgeCases:
ontology_file_path=None,
chunker=TextChunker,
run_in_background=False,
chunks_per_batch=None,
)

View file

@ -41,14 +41,14 @@ async def _reset_engines_and_prune() -> None:
except Exception:
pass
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
from cognee.infrastructure.databases.relational.create_relational_engine import (
create_relational_engine,
)
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
from cognee.infrastructure.databases.vector.create_vector_engine import _create_vector_engine
from cognee.infrastructure.databases.graph.get_graph_engine import _create_graph_engine
create_graph_engine.cache_clear()
create_vector_engine.cache_clear()
_create_graph_engine.cache_clear()
_create_vector_engine.cache_clear()
create_relational_engine.cache_clear()
await cognee.prune.prune_data()

View file

@ -48,14 +48,14 @@ async def _reset_engines_and_prune() -> None:
# Engine might not exist yet
pass
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
from cognee.infrastructure.databases.graph.get_graph_engine import _create_graph_engine
from cognee.infrastructure.databases.vector.create_vector_engine import _create_vector_engine
from cognee.infrastructure.databases.relational.create_relational_engine import (
create_relational_engine,
)
create_graph_engine.cache_clear()
create_vector_engine.cache_clear()
_create_graph_engine.cache_clear()
_create_vector_engine.cache_clear()
create_relational_engine.cache_clear()
await cognee.prune.prune_data()