diff --git a/CLAUDE.md b/CLAUDE.md index 7ac4f01d0..4303582c2 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -427,10 +427,12 @@ git checkout -b feature/your-feature-name ## Code Style -- Ruff for linting and formatting (configured in `pyproject.toml`) -- Line length: 100 characters -- Pre-commit hooks run ruff automatically -- Type hints encouraged (mypy checks enabled) +- **Formatter**: Ruff (configured in `pyproject.toml`) +- **Line length**: 100 characters +- **String quotes**: Use double quotes `"` not single quotes `'` (enforced by ruff-format) +- **Pre-commit hooks**: Run ruff linting and formatting automatically +- **Type hints**: Encouraged (mypy checks enabled) +- **Important**: Always run `pre-commit run --all-files` before committing to catch formatting issues ## Testing Strategy diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py index ffc903d68..bbe00c35f 100644 --- a/cognee/api/v1/cognify/cognify.py +++ b/cognee/api/v1/cognify/cognify.py @@ -252,7 +252,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's chunk_size: int = None, config: Config = None, custom_prompt: Optional[str] = None, - chunks_per_batch: int = 100, + chunks_per_batch: int = None, **kwargs, ) -> list[Task]: if config is None: @@ -272,12 +272,14 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's "ontology_config": {"ontology_resolver": get_default_ontology_resolver()} } - if chunks_per_batch is None: - chunks_per_batch = 100 - cognify_config = get_cognify_config() embed_triplets = cognify_config.triplet_embedding + if chunks_per_batch is None: + chunks_per_batch = ( + cognify_config.chunks_per_batch if cognify_config.chunks_per_batch is not None else 100 + ) + default_tasks = [ Task(classify_documents), Task( @@ -308,7 +310,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's async def get_temporal_tasks( - user: User = None, chunker=TextChunker, chunk_size: int = None, chunks_per_batch: int = 10 + user: User = None, chunker=TextChunker, chunk_size: int = None, chunks_per_batch: int = None ) -> list[Task]: """ Builds and returns a list of temporal processing tasks to be executed in sequence. @@ -330,7 +332,10 @@ async def get_temporal_tasks( list[Task]: A list of Task objects representing the temporal processing pipeline. 
""" if chunks_per_batch is None: - chunks_per_batch = 10 + from cognee.modules.cognify.config import get_cognify_config + + configured = get_cognify_config().chunks_per_batch + chunks_per_batch = configured if configured is not None else 10 temporal_tasks = [ Task(classify_documents), diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py index a499b3ca3..0e2bf2bda 100644 --- a/cognee/api/v1/cognify/routers/get_cognify_router.py +++ b/cognee/api/v1/cognify/routers/get_cognify_router.py @@ -46,6 +46,11 @@ class CognifyPayloadDTO(InDTO): examples=[[]], description="Reference to one or more previously uploaded ontologies", ) + chunks_per_batch: Optional[int] = Field( + default=None, + description="Number of chunks to process per task batch in Cognify (overrides default).", + examples=[10, 20, 50, 100], + ) def get_cognify_router() -> APIRouter: @@ -146,6 +151,7 @@ def get_cognify_router() -> APIRouter: config=config_to_use, run_in_background=payload.run_in_background, custom_prompt=payload.custom_prompt, + chunks_per_batch=payload.chunks_per_batch, ) # If any cognify run errored return JSONResponse with proper error status code diff --git a/cognee/cli/commands/cognify_command.py b/cognee/cli/commands/cognify_command.py index b89c1f70e..c310b88b7 100644 --- a/cognee/cli/commands/cognify_command.py +++ b/cognee/cli/commands/cognify_command.py @@ -62,6 +62,11 @@ After successful cognify processing, use `cognee search` to query the knowledge parser.add_argument( "--verbose", "-v", action="store_true", help="Show detailed progress information" ) + parser.add_argument( + "--chunks-per-batch", + type=int, + help="Number of chunks to process per task batch (try 50 for large single documents).", + ) def execute(self, args: argparse.Namespace) -> None: try: @@ -111,6 +116,7 @@ After successful cognify processing, use `cognee search` to query the knowledge chunk_size=args.chunk_size, ontology_file_path=args.ontology_file, run_in_background=args.background, + chunks_per_batch=getattr(args, "chunks_per_batch", None), ) return result except Exception as e: diff --git a/cognee/infrastructure/databases/graph/get_graph_engine.py b/cognee/infrastructure/databases/graph/get_graph_engine.py index c37af2102..bd2a6f68d 100644 --- a/cognee/infrastructure/databases/graph/get_graph_engine.py +++ b/cognee/infrastructure/databases/graph/get_graph_engine.py @@ -24,7 +24,6 @@ async def get_graph_engine() -> GraphDBInterface: return graph_client -@lru_cache def create_graph_engine( graph_database_provider, graph_file_path, @@ -35,6 +34,35 @@ def create_graph_engine( graph_database_port="", graph_database_key="", graph_dataset_database_handler="", +): + """ + Wrapper function to call create graph engine with caching. + For a detailed description, see _create_graph_engine. + """ + return _create_graph_engine( + graph_database_provider, + graph_file_path, + graph_database_url, + graph_database_name, + graph_database_username, + graph_database_password, + graph_database_port, + graph_database_key, + graph_dataset_database_handler, + ) + + +@lru_cache +def _create_graph_engine( + graph_database_provider, + graph_file_path, + graph_database_url="", + graph_database_name="", + graph_database_username="", + graph_database_password="", + graph_database_port="", + graph_database_key="", + graph_dataset_database_handler="", ): """ Create a graph engine based on the specified provider type. 
diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py b/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py
index eb6cbc55a..bccf5020e 100644
--- a/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py
+++ b/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py
@@ -1,11 +1,13 @@
 import os
+import aiohttp
 import asyncio
-import requests
 import base64
 import hashlib
 from uuid import UUID
 from typing import Optional
+from urllib.parse import urlparse
 from cryptography.fernet import Fernet
+from aiohttp import BasicAuth
 
 from cognee.infrastructure.databases.graph import get_graph_config
 from cognee.modules.users.models import User, DatasetDatabase
@@ -23,7 +25,6 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
 
     Quality of life improvements:
         - Allow configuration of different Neo4j Aura plans and regions.
-        - Requests should be made async, currently a blocking requests library is used.
     """
 
     @classmethod
@@ -49,6 +50,7 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
        graph_db_name = f"{dataset_id}"
 
         # Client credentials and encryption
+        # Note: Kept as local variables rather than class attributes so the credentials are not held in memory longer than needed
         client_id = os.environ.get("NEO4J_CLIENT_ID", None)
         client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
         tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
@@ -63,22 +65,13 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
                 "NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling."
             )
 
-        # Make the request with HTTP Basic Auth
-        def get_aura_token(client_id: str, client_secret: str) -> dict:
-            url = "https://api.neo4j.io/oauth/token"
-            data = {"grant_type": "client_credentials"}  # sent as application/x-www-form-urlencoded
-
-            resp = requests.post(url, data=data, auth=(client_id, client_secret))
-            resp.raise_for_status()  # raises if the request failed
-            return resp.json()
-
-        resp = get_aura_token(client_id, client_secret)
+        resp_token = await cls._get_aura_token(client_id, client_secret)
 
         url = "https://api.neo4j.io/v1/instances"
 
         headers = {
             "accept": "application/json",
-            "Authorization": f"Bearer {resp['access_token']}",
+            "Authorization": f"Bearer {resp_token['access_token']}",
             "Content-Type": "application/json",
         }
 
@@ -96,31 +89,38 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
             "cloud_provider": "gcp",
         }
 
-        response = requests.post(url, headers=headers, json=payload)
+        async def _create_database_instance_request():
+            async with aiohttp.ClientSession() as session:
+                async with session.post(url, headers=headers, json=payload) as resp:
+                    resp.raise_for_status()
+                    return await resp.json()
+
+        resp_create = await _create_database_instance_request()
 
         graph_db_name = "neo4j"  # Has to be 'neo4j' for Aura
-        graph_db_url = response.json()["data"]["connection_url"]
-        graph_db_key = resp["access_token"]
-        graph_db_username = response.json()["data"]["username"]
-        graph_db_password = response.json()["data"]["password"]
+        graph_db_url = resp_create["data"]["connection_url"]
+        graph_db_key = resp_token["access_token"]
+        graph_db_username = resp_create["data"]["username"]
+        graph_db_password = resp_create["data"]["password"]
 
         async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict):
             # Poll until the instance is running
             status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
             status = ""
             for attempt in range(30):  # Try for up to ~5 minutes
-                status_resp = requests.get(
-                    status_url, headers=headers
-                )  # TODO: Use async requests with httpx
-                status = status_resp.json()["data"]["status"]
-                if status.lower() == "running":
-                    return
-                await asyncio.sleep(10)
+                async with aiohttp.ClientSession() as session:
+                    async with session.get(status_url, headers=headers) as resp:
+                        resp.raise_for_status()
+                        status_resp = await resp.json()
+                        status = status_resp["data"]["status"]
+                        if status.lower() == "running":
+                            return
+                await asyncio.sleep(10)
             raise TimeoutError(
                 f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}"
             )
 
-        instance_id = response.json()["data"]["id"]
+        instance_id = resp_create["data"]["id"]
         await _wait_for_neo4j_instance_provisioning(instance_id, headers)
 
         encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode())
@@ -165,4 +165,39 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
 
     @classmethod
     async def delete_dataset(cls, dataset_database: DatasetDatabase):
-        pass
+        # Get dataset database information and credentials
+        dataset_database = await cls.resolve_dataset_connection_info(dataset_database)
+
+        parsed_url = urlparse(dataset_database.graph_database_url)
+        instance_id = parsed_url.hostname.split(".")[0]
+
+        url = f"https://api.neo4j.io/v1/instances/{instance_id}"
+
+        # Get access token for Neo4j Aura API
+        # Client credentials
+        client_id = os.environ.get("NEO4J_CLIENT_ID", None)
+        client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
+        resp_token = await cls._get_aura_token(client_id, client_secret)
+
+        headers = {
+            "accept": "application/json",
+            "Authorization": f"Bearer {resp_token['access_token']}",
+            "Content-Type": "application/json",
+        }
+
+        async with aiohttp.ClientSession() as session:
+            async with session.delete(url, headers=headers) as resp:
+                resp.raise_for_status()
+                return await resp.json()
+
+    @classmethod
+    async def _get_aura_token(cls, client_id: str, client_secret: str) -> dict:
+        url = "https://api.neo4j.io/oauth/token"
+        data = {"grant_type": "client_credentials"}  # sent as application/x-www-form-urlencoded
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                url, data=data, auth=BasicAuth(client_id, client_secret)
+            ) as resp:
+                resp.raise_for_status()
+                return await resp.json()
diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py
index 36c6ef09e..5a7ac64f2 100644
--- a/cognee/infrastructure/databases/vector/create_vector_engine.py
+++ b/cognee/infrastructure/databases/vector/create_vector_engine.py
@@ -7,7 +7,6 @@ from cognee.infrastructure.databases.graph.config import get_graph_context_confi
 from functools import lru_cache
 
 
-@lru_cache
 def create_vector_engine(
     vector_db_provider: str,
     vector_db_url: str,
@@ -17,6 +16,33 @@
     vector_db_name: str,
     vector_db_port: str = "",
     vector_db_key: str = "",
     vector_dataset_database_handler: str = "",
     vector_db_username: str = "",
     vector_db_password: str = "",
+):
+    """
+    Wrapper function to call create vector engine with caching.
+    For a detailed description, see _create_vector_engine.
+ """ + return _create_vector_engine( + vector_db_provider, + vector_db_url, + vector_db_name, + vector_db_port, + vector_db_key, + vector_dataset_database_handler, + ) + + +@lru_cache +def _create_vector_engine( + vector_db_provider: str, + vector_db_url: str, + vector_db_name: str, + vector_db_port: str = "", + vector_db_key: str = "", + vector_dataset_database_handler: str = "", ): """ Create a vector database engine based on the specified provider. diff --git a/cognee/modules/cognify/config.py b/cognee/modules/cognify/config.py index ec03225e8..223392375 100644 --- a/cognee/modules/cognify/config.py +++ b/cognee/modules/cognify/config.py @@ -9,6 +9,7 @@ class CognifyConfig(BaseSettings): classification_model: object = DefaultContentPrediction summarization_model: object = SummarizedContent triplet_embedding: bool = False + chunks_per_batch: Optional[int] = None model_config = SettingsConfigDict(env_file=".env", extra="allow") def to_dict(self) -> dict: @@ -16,6 +17,7 @@ class CognifyConfig(BaseSettings): "classification_model": self.classification_model, "summarization_model": self.summarization_model, "triplet_embedding": self.triplet_embedding, + "chunks_per_batch": self.chunks_per_batch, } diff --git a/cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py b/cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py index 7654a781a..1301a8eaa 100644 --- a/cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py +++ b/cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py @@ -238,6 +238,7 @@ class TestCognifyCommand: ontology_file_path=None, chunker=TextChunker, run_in_background=False, + chunks_per_batch=None, ) @patch("cognee.cli.commands.cognify_command.asyncio.run") diff --git a/cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py b/cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py index ca27c0f67..466a9e458 100644 --- a/cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py +++ b/cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py @@ -262,6 +262,7 @@ class TestCognifyCommandEdgeCases: ontology_file_path=None, chunker=TextChunker, run_in_background=False, + chunks_per_batch=None, ) @patch("cognee.cli.commands.cognify_command.asyncio.run", side_effect=_mock_run) @@ -295,6 +296,7 @@ class TestCognifyCommandEdgeCases: ontology_file_path="/nonexistent/path/ontology.owl", chunker=TextChunker, run_in_background=False, + chunks_per_batch=None, ) @patch("cognee.cli.commands.cognify_command.asyncio.run") @@ -373,6 +375,7 @@ class TestCognifyCommandEdgeCases: ontology_file_path=None, chunker=TextChunker, run_in_background=False, + chunks_per_batch=None, ) diff --git a/cognee/tests/test_permissions.py b/cognee/tests/test_permissions.py index 10696441e..9d949c92b 100644 --- a/cognee/tests/test_permissions.py +++ b/cognee/tests/test_permissions.py @@ -41,14 +41,14 @@ async def _reset_engines_and_prune() -> None: except Exception: pass - from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine from cognee.infrastructure.databases.relational.create_relational_engine import ( create_relational_engine, ) - from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine + from cognee.infrastructure.databases.vector.create_vector_engine import _create_vector_engine + from cognee.infrastructure.databases.graph.get_graph_engine import _create_graph_engine - create_graph_engine.cache_clear() - create_vector_engine.cache_clear() + _create_graph_engine.cache_clear() + 
_create_vector_engine.cache_clear() create_relational_engine.cache_clear() await cognee.prune.prune_data() diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index 37b8ae45b..cdb2bbc64 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -48,14 +48,14 @@ async def _reset_engines_and_prune() -> None: # Engine might not exist yet pass - from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine - from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine + from cognee.infrastructure.databases.graph.get_graph_engine import _create_graph_engine + from cognee.infrastructure.databases.vector.create_vector_engine import _create_vector_engine from cognee.infrastructure.databases.relational.create_relational_engine import ( create_relational_engine, ) - create_graph_engine.cache_clear() - create_vector_engine.cache_clear() + _create_graph_engine.cache_clear() + _create_vector_engine.cache_clear() create_relational_engine.cache_clear() await cognee.prune.prune_data()
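End to end, the new `chunks_per_batch` setting resolves in a fixed order: an explicit value (the `--chunks-per-batch` CLI flag, the API payload field, or a function argument) wins; otherwise `CognifyConfig.chunks_per_batch` applies (a pydantic setting, so it can also come from the environment or `.env`); otherwise the hard-coded fallbacks kick in (100 for the default pipeline, 10 for the temporal pipeline). A usage sketch, assuming the top-level `cognee.cognify` API accepts the keyword the same way the CLI wiring above forwards it:

    import asyncio

    import cognee


    async def main():
        await cognee.add("One very large document worth of text ...")
        # Explicit override: takes precedence over CognifyConfig.chunks_per_batch
        # and over the built-in defaults (100 default tasks / 10 temporal tasks).
        await cognee.cognify(chunks_per_batch=50)


    asyncio.run(main())

The equivalent override from the command line is `cognee cognify --chunks-per-batch 50`.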