From f58ba86e7c56a9d39a82fa785a53ebed6d5f281b Mon Sep 17 00:00:00 2001
From: Daulet Amirkhanov
Date: Thu, 18 Sep 2025 17:07:05 +0100
Subject: [PATCH] feat: add welcome tutorial notebook for new users (#1425)

## Description

Update the default tutorial:

1. Use the tutorial from the [notebook_tutorial branch](https://github.com/topoteretes/cognee/blob/notebook_tutorial/notebooks/tutorial.ipynb), specifically its .zip version with all necessary data files
2. Use Jupyter notebook abstractions to read the `.ipynb` file and map it into our `Notebook` model
3. Dynamically update starter-notebook code cells that reference the starter data files, swapping them for local paths to the downloaded copies
4. Add test coverage

| Before | After (storage backend = local) | After (s3) |
|--------|---------------------------------|------------|
| Screenshot 2025-09-17 at 01 00 58 | Screenshot 2025-09-18 at 13 01 57 | Screenshot 2025-09-18 at 12 56
08 | ## File Replacements ### S3 Demo https://github.com/user-attachments/assets/bd46eec9-ef77-4f69-9ef0-e7d1612ff9b3 --- ### Local FS Demo https://github.com/user-attachments/assets/8251cea0-81b3-4cac-a968-9576c358f334 ## Type of Change - [ ] Bug fix (non-breaking change that fixes an issue) - [x] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Changes Made - - - ## Testing ## Screenshots/Videos (if applicable) ## Pre-submission Checklist - [ ] **I have tested my changes thoroughly before submitting this PR** - [ ] **This PR contains minimal changes necessary to address the issue/feature** - [ ] My code follows the project's coding standards and style guidelines - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have added necessary documentation (if applicable) - [ ] All new and existing tests pass - [ ] I have searched existing PRs to ensure this change hasn't been submitted already - [ ] I have linked any relevant issues in the description - [ ] My commits have clear and descriptive messages ## Related Issues ## Additional Notes ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --- .env.template | 22 + .gitignore | 1 + cognee/api/v1/prune/prune.py | 4 +- cognee/api/v1/sync/sync.py | 21 +- cognee/base_config.py | 15 + .../graph/kuzu/remote_kuzu_adapter.py | 5 +- .../embeddings/OllamaEmbeddingEngine.py | 5 +- .../files/storage/LocalFileStorage.py | 50 +++ .../files/storage/S3FileStorage.py | 65 ++- .../files/storage/StorageManager.py | 18 + .../modules/cloud/operations/check_api_key.py | 5 +- cognee/modules/data/deletion/prune_system.py | 6 +- .../notebooks/methods/create_notebook.py | 32 ++ cognee/modules/notebooks/models/Notebook.py | 207 ++++++++- cognee/modules/users/methods/create_user.py | 28 +- cognee/shared/cache.py | 346 +++++++++++++++ cognee/shared/utils.py | 12 + .../users/test_tutorial_notebook_creation.py | 399 ++++++++++++++++++ pyproject.toml | 1 + uv.lock | 2 + 20 files changed, 1200 insertions(+), 44 deletions(-) create mode 100644 cognee/shared/cache.py create mode 100644 cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py diff --git a/.env.template b/.env.template index 916a1ef76..781e82428 100644 --- a/.env.template +++ b/.env.template @@ -47,6 +47,28 @@ BAML_LLM_API_VERSION="" # DATA_ROOT_DIRECTORY='/Users//Desktop/cognee/.cognee_data/' # SYSTEM_ROOT_DIRECTORY='/Users//Desktop/cognee/.cognee_system/' +################################################################################ +# ☁️ Storage Backend Settings +################################################################################ +# Configure storage backend (local filesystem or S3) +# STORAGE_BACKEND="local" # Default: uses local filesystem +# +# -- To switch to S3 storage, uncomment and fill these: --------------------- +# STORAGE_BACKEND="s3" +# STORAGE_BUCKET_NAME="your-bucket-name" +# AWS_REGION="us-east-1" +# AWS_ACCESS_KEY_ID="your-access-key" +# AWS_SECRET_ACCESS_KEY="your-secret-key" +# +# -- S3 Root Directories (optional) ----------------------------------------- +# DATA_ROOT_DIRECTORY="s3://your-bucket/cognee/data" +# SYSTEM_ROOT_DIRECTORY="s3://your-bucket/cognee/system" +# +# -- Cache Directory (auto-configured for S3) 
------------------------------- +# When STORAGE_BACKEND=s3, cache automatically uses S3: s3://BUCKET/cognee/cache +# To override the automatic S3 cache location, uncomment: +# CACHE_ROOT_DIRECTORY="s3://your-bucket/cognee/cache" + ################################################################################ # 🗄️ Relational database settings ################################################################################ diff --git a/.gitignore b/.gitignore index 8441d4b24..ff42efc97 100644 --- a/.gitignore +++ b/.gitignore @@ -186,6 +186,7 @@ cognee/cache/ # Default cognee system directory, used in development .cognee_system/ .data_storage/ +.cognee_cache/ .artifacts/ .anon_id diff --git a/cognee/api/v1/prune/prune.py b/cognee/api/v1/prune/prune.py index 61438e54d..206d8d8cb 100644 --- a/cognee/api/v1/prune/prune.py +++ b/cognee/api/v1/prune/prune.py @@ -7,8 +7,8 @@ class prune: await _prune_data() @staticmethod - async def prune_system(graph=True, vector=True, metadata=False): - await _prune_system(graph, vector, metadata) + async def prune_system(graph=True, vector=True, metadata=False, cache=True): + await _prune_system(graph, vector, metadata, cache) if __name__ == "__main__": diff --git a/cognee/api/v1/sync/sync.py b/cognee/api/v1/sync/sync.py index 54339c1c4..58ff3a34b 100644 --- a/cognee/api/v1/sync/sync.py +++ b/cognee/api/v1/sync/sync.py @@ -23,6 +23,7 @@ from cognee.modules.sync.methods import ( mark_sync_completed, mark_sync_failed, ) +from cognee.shared.utils import create_secure_ssl_context logger = get_logger("sync") @@ -583,7 +584,9 @@ async def _check_hashes_diff( logger.info(f"Checking missing hashes on cloud for dataset {dataset.id}") try: - async with aiohttp.ClientSession() as session: + ssl_context = create_secure_ssl_context() + connector = aiohttp.TCPConnector(ssl=ssl_context) + async with aiohttp.ClientSession(connector=connector) as session: async with session.post(url, json=payload.dict(), headers=headers) as response: if response.status == 200: data = await response.json() @@ -630,7 +633,9 @@ async def _download_missing_files( headers = {"X-Api-Key": auth_token} - async with aiohttp.ClientSession() as session: + ssl_context = create_secure_ssl_context() + connector = aiohttp.TCPConnector(ssl=ssl_context) + async with aiohttp.ClientSession(connector=connector) as session: for file_hash in hashes_missing_on_local: try: # Download file from cloud by hash @@ -749,7 +754,9 @@ async def _upload_missing_files( headers = {"X-Api-Key": auth_token} - async with aiohttp.ClientSession() as session: + ssl_context = create_secure_ssl_context() + connector = aiohttp.TCPConnector(ssl=ssl_context) + async with aiohttp.ClientSession(connector=connector) as session: for file_info in files_to_upload: try: file_dir = os.path.dirname(file_info.raw_data_location) @@ -809,7 +816,9 @@ async def _prune_cloud_dataset( logger.info("Pruning cloud dataset to match local state") try: - async with aiohttp.ClientSession() as session: + ssl_context = create_secure_ssl_context() + connector = aiohttp.TCPConnector(ssl=ssl_context) + async with aiohttp.ClientSession(connector=connector) as session: async with session.put(url, json=payload.dict(), headers=headers) as response: if response.status == 200: data = await response.json() @@ -852,7 +861,9 @@ async def _trigger_remote_cognify( logger.info(f"Triggering cognify processing for dataset {dataset_id}") try: - async with aiohttp.ClientSession() as session: + ssl_context = create_secure_ssl_context() + connector = 
aiohttp.TCPConnector(ssl=ssl_context) + async with aiohttp.ClientSession(connector=connector) as session: async with session.post(url, json=payload, headers=headers) as response: if response.status == 200: data = await response.json() diff --git a/cognee/base_config.py b/cognee/base_config.py index 940846128..2e2afb2de 100644 --- a/cognee/base_config.py +++ b/cognee/base_config.py @@ -10,13 +10,27 @@ import pydantic class BaseConfig(BaseSettings): data_root_directory: str = get_absolute_path(".data_storage") system_root_directory: str = get_absolute_path(".cognee_system") + cache_root_directory: str = get_absolute_path(".cognee_cache") monitoring_tool: object = Observer.LANGFUSE @pydantic.model_validator(mode="after") def validate_paths(self): + # Adding this here temporarily to ensure that the cache root directory is set correctly for S3 storage automatically + # I'll remove this after we update documentation for S3 storage + # Auto-configure cache root directory for S3 storage if not explicitly set + storage_backend = os.getenv("STORAGE_BACKEND", "").lower() + cache_root_env = os.getenv("CACHE_ROOT_DIRECTORY") + + if storage_backend == "s3" and not cache_root_env: + # Auto-generate S3 cache path when using S3 storage + bucket_name = os.getenv("STORAGE_BUCKET_NAME") + if bucket_name: + self.cache_root_directory = f"s3://{bucket_name}/cognee/cache" + # Require absolute paths for root directories self.data_root_directory = ensure_absolute_path(self.data_root_directory) self.system_root_directory = ensure_absolute_path(self.system_root_directory) + self.cache_root_directory = ensure_absolute_path(self.cache_root_directory) return self langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY") @@ -31,6 +45,7 @@ class BaseConfig(BaseSettings): "data_root_directory": self.data_root_directory, "system_root_directory": self.system_root_directory, "monitoring_tool": self.monitoring_tool, + "cache_root_directory": self.cache_root_directory, } diff --git a/cognee/infrastructure/databases/graph/kuzu/remote_kuzu_adapter.py b/cognee/infrastructure/databases/graph/kuzu/remote_kuzu_adapter.py index c75b70f75..260043743 100644 --- a/cognee/infrastructure/databases/graph/kuzu/remote_kuzu_adapter.py +++ b/cognee/infrastructure/databases/graph/kuzu/remote_kuzu_adapter.py @@ -7,6 +7,7 @@ import aiohttp from uuid import UUID from cognee.infrastructure.databases.graph.kuzu.adapter import KuzuAdapter +from cognee.shared.utils import create_secure_ssl_context logger = get_logger() @@ -42,7 +43,9 @@ class RemoteKuzuAdapter(KuzuAdapter): async def _get_session(self) -> aiohttp.ClientSession: """Get or create an aiohttp session.""" if self._session is None or self._session.closed: - self._session = aiohttp.ClientSession() + ssl_context = create_secure_ssl_context() + connector = aiohttp.TCPConnector(ssl=ssl_context) + self._session = aiohttp.ClientSession(connector=connector) return self._session async def close(self): diff --git a/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py index cf56dba1f..e6e590597 100644 --- a/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py @@ -14,6 +14,7 @@ from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter im embedding_rate_limit_async, embedding_sleep_and_retry_async, ) +from cognee.shared.utils import create_secure_ssl_context logger = 
get_logger("OllamaEmbeddingEngine") @@ -101,7 +102,9 @@ class OllamaEmbeddingEngine(EmbeddingEngine): if api_key: headers["Authorization"] = f"Bearer {api_key}" - async with aiohttp.ClientSession() as session: + ssl_context = create_secure_ssl_context() + connector = aiohttp.TCPConnector(ssl=ssl_context) + async with aiohttp.ClientSession(connector=connector) as session: async with session.post( self.endpoint, json=payload, headers=headers, timeout=60.0 ) as response: diff --git a/cognee/infrastructure/files/storage/LocalFileStorage.py b/cognee/infrastructure/files/storage/LocalFileStorage.py index 34e97d827..c48d9a45d 100644 --- a/cognee/infrastructure/files/storage/LocalFileStorage.py +++ b/cognee/infrastructure/files/storage/LocalFileStorage.py @@ -253,6 +253,56 @@ class LocalFileStorage(Storage): if os.path.exists(full_file_path): os.remove(full_file_path) + def list_files(self, directory_path: str, recursive: bool = False) -> list[str]: + """ + List all files in the specified directory. + + Parameters: + ----------- + - directory_path (str): The directory path to list files from + - recursive (bool): If True, list files recursively in subdirectories + + Returns: + -------- + - list[str]: List of file paths relative to the storage root + """ + from pathlib import Path + + parsed_storage_path = get_parsed_path(self.storage_path) + + if directory_path: + full_directory_path = os.path.join(parsed_storage_path, directory_path) + else: + full_directory_path = parsed_storage_path + + directory_pathlib = Path(full_directory_path) + + if not directory_pathlib.exists() or not directory_pathlib.is_dir(): + return [] + + files = [] + + if recursive: + # Use rglob for recursive search + for file_path in directory_pathlib.rglob("*"): + if file_path.is_file(): + # Get relative path from storage root + relative_path = os.path.relpath(str(file_path), parsed_storage_path) + # Normalize path separators for consistency + relative_path = relative_path.replace(os.sep, "/") + files.append(relative_path) + else: + # Use iterdir for just immediate directory + for file_path in directory_pathlib.iterdir(): + if file_path.is_file(): + # Get relative path from storage root + relative_path = os.path.relpath(str(file_path), parsed_storage_path) + # Normalize path separators for consistency + relative_path = relative_path.replace(os.sep, "/") + files.append(relative_path) + + return files + def remove_all(self, tree_path: str = None): """ Remove an entire directory tree at the specified path, including all files and diff --git a/cognee/infrastructure/files/storage/S3FileStorage.py b/cognee/infrastructure/files/storage/S3FileStorage.py index 6218d6240..ca2a73291 100644 --- a/cognee/infrastructure/files/storage/S3FileStorage.py +++ b/cognee/infrastructure/files/storage/S3FileStorage.py @@ -155,21 +155,19 @@ class S3FileStorage(Storage): """ Ensure that the specified directory exists, creating it if necessary. - If the directory already exists, no action is taken. + For S3 storage, this is a no-op since directories are created implicitly + when files are written to paths. S3 doesn't have actual directories, + just object keys with prefixes that appear as directories. Parameters: ----------- - directory_path (str): The path of the directory to check or create. 
""" - if not directory_path.strip(): - directory_path = self.storage_path.replace("s3://", "") - - def ensure_directory(): - if not self.s3.exists(directory_path): - self.s3.makedirs(directory_path, exist_ok=True) - - await run_async(ensure_directory) + # In S3, directories don't exist as separate entities - they're just prefixes + # When you write a file to s3://bucket/path/to/file.txt, the "directories" + # path/ and path/to/ are implicitly created. No explicit action needed. + pass async def copy_file(self, source_file_path: str, destination_file_path: str): """ @@ -213,6 +211,55 @@ class S3FileStorage(Storage): await run_async(remove_file) + async def list_files(self, directory_path: str, recursive: bool = False) -> list[str]: + """ + List all files in the specified directory. + + Parameters: + ----------- + - directory_path (str): The directory path to list files from + - recursive (bool): If True, list files recursively in subdirectories + + Returns: + -------- + - list[str]: List of file paths relative to the storage root + """ + + def list_files_sync(): + if directory_path: + # Combine storage path with directory path + full_path = os.path.join(self.storage_path.replace("s3://", ""), directory_path) + else: + full_path = self.storage_path.replace("s3://", "") + + if recursive: + # Use ** for recursive search + pattern = f"{full_path}/**" + else: + # Just files in the immediate directory + pattern = f"{full_path}/*" + + # Use s3fs glob to find files + try: + all_paths = self.s3.glob(pattern) + # Filter to only files (not directories) + files = [path for path in all_paths if self.s3.isfile(path)] + + # Convert back to relative paths from storage root + storage_prefix = self.storage_path.replace("s3://", "") + relative_files = [] + for file_path in files: + if file_path.startswith(storage_prefix): + relative_path = file_path[len(storage_prefix) :].lstrip("/") + relative_files.append(relative_path) + + return relative_files + except Exception: + # If directory doesn't exist or other error, return empty list + return [] + + return await run_async(list_files_sync) + async def remove_all(self, tree_path: str): """ Remove an entire directory tree at the specified path, including all files and diff --git a/cognee/infrastructure/files/storage/StorageManager.py b/cognee/infrastructure/files/storage/StorageManager.py index 5ac45f14f..e65a4ecf7 100644 --- a/cognee/infrastructure/files/storage/StorageManager.py +++ b/cognee/infrastructure/files/storage/StorageManager.py @@ -135,6 +135,24 @@ class StorageManager: else: return self.storage.remove(file_path) + async def list_files(self, directory_path: str, recursive: bool = False) -> list[str]: + """ + List all files in the specified directory. 
+ + Parameters: + ----------- + - directory_path (str): The directory path to list files from + - recursive (bool): If True, list files recursively in subdirectories + + Returns: + -------- + - list[str]: List of file paths relative to the storage root + """ + if inspect.iscoroutinefunction(self.storage.list_files): + return await self.storage.list_files(directory_path, recursive) + else: + return self.storage.list_files(directory_path, recursive) + async def remove_all(self, tree_path: str = None): """ Remove an entire directory tree at the specified path, including all files and diff --git a/cognee/modules/cloud/operations/check_api_key.py b/cognee/modules/cloud/operations/check_api_key.py index 67c1eac3c..6d986801e 100644 --- a/cognee/modules/cloud/operations/check_api_key.py +++ b/cognee/modules/cloud/operations/check_api_key.py @@ -1,6 +1,7 @@ import aiohttp from cognee.modules.cloud.exceptions import CloudConnectionError +from cognee.shared.utils import create_secure_ssl_context async def check_api_key(auth_token: str): @@ -10,7 +11,9 @@ async def check_api_key(auth_token: str): headers = {"X-Api-Key": auth_token} try: - async with aiohttp.ClientSession() as session: + ssl_context = create_secure_ssl_context() + connector = aiohttp.TCPConnector(ssl=ssl_context) + async with aiohttp.ClientSession(connector=connector) as session: async with session.post(url, headers=headers) as response: if response.status == 200: return diff --git a/cognee/modules/data/deletion/prune_system.py b/cognee/modules/data/deletion/prune_system.py index 5bbd7c22f..a1b60988f 100644 --- a/cognee/modules/data/deletion/prune_system.py +++ b/cognee/modules/data/deletion/prune_system.py @@ -1,9 +1,10 @@ from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.shared.cache import delete_cache -async def prune_system(graph=True, vector=True, metadata=True): +async def prune_system(graph=True, vector=True, metadata=True, cache=True): if graph: graph_engine = await get_graph_engine() await graph_engine.delete_graph() @@ -15,3 +16,6 @@ async def prune_system(graph=True, vector=True, metadata=True): if metadata: db_engine = get_relational_engine() await db_engine.delete_database() + + if cache: + await delete_cache() diff --git a/cognee/modules/notebooks/methods/create_notebook.py b/cognee/modules/notebooks/methods/create_notebook.py index b4915da23..22ed047f1 100644 --- a/cognee/modules/notebooks/methods/create_notebook.py +++ b/cognee/modules/notebooks/methods/create_notebook.py @@ -7,6 +7,38 @@ from cognee.infrastructure.databases.relational import with_async_session from ..models.Notebook import Notebook, NotebookCell +async def _create_tutorial_notebook( + user_id: UUID, session: AsyncSession, force_refresh: bool = False +) -> None: + """ + Create the default tutorial notebook for new users. 
+ Dynamically fetches from: https://github.com/topoteretes/cognee/blob/notebook_tutorial/notebooks/starter_tutorial.zip + """ + TUTORIAL_ZIP_URL = ( + "https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/starter_tutorial.zip" + ) + + try: + # Create notebook from remote zip file (includes notebook + data files) + notebook = await Notebook.from_ipynb_zip_url( + zip_url=TUTORIAL_ZIP_URL, + owner_id=user_id, + notebook_filename="tutorial.ipynb", + name="Python Development with Cognee Tutorial 🧠", + deletable=False, + force=force_refresh, + ) + + # Add to session and commit + session.add(notebook) + await session.commit() + + except Exception as e: + print(f"Failed to fetch tutorial notebook from {TUTORIAL_ZIP_URL}: {e}") + + raise e + + @with_async_session async def create_notebook( user_id: UUID, diff --git a/cognee/modules/notebooks/models/Notebook.py b/cognee/modules/notebooks/models/Notebook.py index 7bf26d4a7..68d85a07e 100644 --- a/cognee/modules/notebooks/models/Notebook.py +++ b/cognee/modules/notebooks/models/Notebook.py @@ -1,13 +1,24 @@ import json -from typing import List, Literal +import nbformat +import asyncio +from nbformat.notebooknode import NotebookNode +from typing import List, Literal, Optional, cast, Tuple from uuid import uuid4, UUID as UUID_t from pydantic import BaseModel, ConfigDict from datetime import datetime, timezone from fastapi.encoders import jsonable_encoder from sqlalchemy import Boolean, Column, DateTime, JSON, UUID, String, TypeDecorator from sqlalchemy.orm import mapped_column, Mapped +from pathlib import Path from cognee.infrastructure.databases.relational import Base +from cognee.shared.cache import ( + download_and_extract_zip, + get_tutorial_data_dir, + generate_content_hash, +) +from cognee.infrastructure.files.storage.get_file_storage import get_file_storage +from cognee.base_config import get_base_config class NotebookCell(BaseModel): @@ -51,3 +62,197 @@ class Notebook(Base): deletable: Mapped[bool] = mapped_column(Boolean, default=True) created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) + + @classmethod + async def from_ipynb_zip_url( + cls, + zip_url: str, + owner_id: UUID_t, + notebook_filename: str = "tutorial.ipynb", + name: Optional[str] = None, + deletable: bool = True, + force: bool = False, + ) -> "Notebook": + """ + Create a Notebook instance from a remote zip file containing notebook + data files. 
+ + Args: + zip_url: Remote URL to fetch the .zip file from + owner_id: UUID of the notebook owner + notebook_filename: Name of the .ipynb file within the zip + name: Optional custom name for the notebook + deletable: Whether the notebook can be deleted + force: If True, re-download even if already cached + + Returns: + Notebook instance + """ + # Generate a cache key based on the zip URL + content_hash = generate_content_hash(zip_url, notebook_filename) + + # Download and extract the zip file to tutorial_data/{content_hash} + try: + extracted_cache_dir = await download_and_extract_zip( + url=zip_url, + cache_dir_name=f"tutorial_data/{content_hash}", + version_or_hash=content_hash, + force=force, + ) + except Exception as e: + raise RuntimeError(f"Failed to download tutorial zip from {zip_url}") from e + + # Use cache system to access the notebook file + from cognee.shared.cache import cache_file_exists, read_cache_file + + notebook_file_path = f"{extracted_cache_dir}/{notebook_filename}" + + # Check if the notebook file exists in cache + if not await cache_file_exists(notebook_file_path): + raise FileNotFoundError(f"Notebook file '{notebook_filename}' not found in zip") + + # Read and parse the notebook using cache system + async with await read_cache_file(notebook_file_path, encoding="utf-8") as f: + notebook_content = await asyncio.to_thread(f.read) + notebook = cls.from_ipynb_string(notebook_content, owner_id, name, deletable) + + # Update file paths in notebook cells to point to actual cached data files + await cls._update_file_paths_in_cells(notebook, extracted_cache_dir) + + return notebook + + @staticmethod + async def _update_file_paths_in_cells(notebook: "Notebook", cache_dir: str) -> None: + """ + Update file paths in code cells to use actual cached data files. + Works with both local filesystem and S3 storage. 
+ + Args: + notebook: Parsed Notebook instance with cells to update + cache_dir: Path to the cached tutorial directory containing data files + """ + import re + from cognee.shared.cache import list_cache_files, cache_file_exists + from cognee.shared.logging_utils import get_logger + + logger = get_logger() + + # Look for data files in the data subdirectory + data_dir = f"{cache_dir}/data" + + try: + # Get all data files in the cache directory using cache system + data_files = {} + if await cache_file_exists(data_dir): + file_list = await list_cache_files(data_dir) + else: + file_list = [] + + for file_path in file_list: + # Extract just the filename + filename = file_path.split("/")[-1] + # Use the file path as provided by cache system + data_files[filename] = file_path + + except Exception as e: + # If we can't list files, skip updating paths + logger.error(f"Error listing data files in {data_dir}: {e}") + return + + # Pattern to match file://data/filename patterns in code cells + file_pattern = r'"file://data/([^"]+)"' + + def replace_path(match): + filename = match.group(1) + if filename in data_files: + file_path = data_files[filename] + # For local filesystem, preserve file:// prefix + if not file_path.startswith("s3://"): + return f'"file://{file_path}"' + else: + # For S3, return the S3 URL as-is + return f'"{file_path}"' + return match.group(0) # Keep original if file not found + + # Update only code cells + updated_cells = 0 + for cell in notebook.cells: + if cell.type == "code": + original_content = cell.content + # Update file paths in the cell content + cell.content = re.sub(file_pattern, replace_path, cell.content) + if original_content != cell.content: + updated_cells += 1 + + # Log summary of updates (useful for monitoring) + if updated_cells > 0: + logger.info(f"Updated file paths in {updated_cells} notebook cells") + + @classmethod + def from_ipynb_string( + cls, + notebook_content: str, + owner_id: UUID_t, + name: Optional[str] = None, + deletable: bool = True, + ) -> "Notebook": + """ + Create a Notebook instance from Jupyter notebook string content. 
+ + Args: + notebook_content: Raw Jupyter notebook content as string + owner_id: UUID of the notebook owner + name: Optional custom name for the notebook + deletable: Whether the notebook can be deleted + + Returns: + Notebook instance ready to be saved to database + """ + # Parse and validate the Jupyter notebook using nbformat + # Note: nbformat.reads() has loose typing, so we cast to NotebookNode + jupyter_nb = cast( + NotebookNode, nbformat.reads(notebook_content, as_version=nbformat.NO_CONVERT) + ) + + # Convert Jupyter cells to NotebookCell objects + cells = [] + for jupyter_cell in jupyter_nb.cells: + # Each cell is also a NotebookNode with dynamic attributes + cell = cast(NotebookNode, jupyter_cell) + # Skip raw cells as they're not supported in our model + if cell.cell_type == "raw": + continue + + # Get the source content + content = cell.source + + # Generate a name based on content or cell index + cell_name = cls._generate_cell_name(cell) + + # Map cell types (jupyter uses "code"/"markdown", we use same) + cell_type = "code" if cell.cell_type == "code" else "markdown" + + cells.append(NotebookCell(id=uuid4(), type=cell_type, name=cell_name, content=content)) + + # Extract notebook name from metadata if not provided + if name is None: + kernelspec = jupyter_nb.metadata.get("kernelspec", {}) + name = kernelspec.get("display_name") or kernelspec.get("name", "Imported Notebook") + + return cls(id=uuid4(), owner_id=owner_id, name=name, cells=cells, deletable=deletable) + + @staticmethod + def _generate_cell_name(jupyter_cell: NotebookNode) -> str: + """Generate a meaningful name for a notebook cell using nbformat cell.""" + if jupyter_cell.cell_type == "markdown": + # Try to extract a title from markdown headers + content = jupyter_cell.source + + lines = content.strip().split("\n") + if lines and lines[0].startswith("#"): + # Extract header text, clean it up + header = lines[0].lstrip("#").strip() + return header[:50] if len(header) > 50 else header + else: + return "Markdown Cell" + else: + return "Code Cell" diff --git a/cognee/modules/users/methods/create_user.py b/cognee/modules/users/methods/create_user.py index fd96dc374..e3f24ccad 100644 --- a/cognee/modules/users/methods/create_user.py +++ b/cognee/modules/users/methods/create_user.py @@ -1,9 +1,10 @@ -from uuid import uuid4 +from uuid import UUID, uuid4 from fastapi_users.exceptions import UserAlreadyExists +from sqlalchemy.ext.asyncio import AsyncSession from cognee.infrastructure.databases.relational import get_relational_engine -from cognee.modules.notebooks.methods import create_notebook -from cognee.modules.notebooks.models.Notebook import NotebookCell +from cognee.modules.notebooks.models.Notebook import Notebook +from cognee.modules.notebooks.methods.create_notebook import _create_tutorial_notebook from cognee.modules.users.exceptions import TenantNotFoundError from cognee.modules.users.get_user_manager import get_user_manager_context from cognee.modules.users.get_user_db import get_user_db_context @@ -60,26 +61,7 @@ async def create_user( if auto_login: await session.refresh(user) - await create_notebook( - user_id=user.id, - notebook_name="Welcome to cognee 🧠", - cells=[ - NotebookCell( - id=uuid4(), - name="Welcome", - content="Cognee is your toolkit for turning text into a structured knowledge graph, optionally enhanced by ontologies, and then querying it with advanced retrieval techniques. 
This notebook will guide you through a simple example.", - type="markdown", - ), - NotebookCell( - id=uuid4(), - name="Example", - content="", - type="code", - ), - ], - deletable=False, - session=session, - ) + await _create_tutorial_notebook(user.id, session) return user except UserAlreadyExists as error: diff --git a/cognee/shared/cache.py b/cognee/shared/cache.py new file mode 100644 index 000000000..c645b9bef --- /dev/null +++ b/cognee/shared/cache.py @@ -0,0 +1,346 @@ +""" +Storage-aware cache management utilities for Cognee. + +This module provides cache functionality that works with both local and cloud storage +backends (like S3) through the StorageManager abstraction. +""" + +import hashlib +import zipfile +import asyncio +from typing import Optional, Tuple +import aiohttp +import logging +from io import BytesIO + +from cognee.base_config import get_base_config +from cognee.infrastructure.files.storage.get_file_storage import get_file_storage +from cognee.infrastructure.files.storage.StorageManager import StorageManager +from cognee.shared.utils import create_secure_ssl_context + +logger = logging.getLogger(__name__) + + +class StorageAwareCache: + """ + A cache manager that works with different storage backends (local, S3, etc.) + """ + + def __init__(self, cache_subdir: str = "cache"): + """ + Initialize the cache manager. + + Args: + cache_subdir: Subdirectory name within the system root for caching + """ + self.base_config = get_base_config() + # Since we're using cache_root_directory, don't add extra cache prefix + self.cache_base_path = "" + self.storage_manager: StorageManager = get_file_storage( + self.base_config.cache_root_directory + ) + + # Print absolute path + storage_path = self.storage_manager.storage.storage_path + if storage_path.startswith("s3://"): + absolute_path = storage_path # S3 paths are already absolute + else: + import os + + absolute_path = os.path.abspath(storage_path) + logger.info(f"Storage manager absolute path: {absolute_path}") + + async def get_cache_dir(self) -> str: + """Get the base cache directory path.""" + cache_path = self.cache_base_path or "." # Use "." 
for root when cache_base_path is empty + await self.storage_manager.ensure_directory_exists(cache_path) + return cache_path + + async def get_cache_subdir(self, name: str) -> str: + """Get a specific cache subdirectory.""" + if self.cache_base_path: + cache_path = f"{self.cache_base_path}/{name}" + else: + cache_path = name + await self.storage_manager.ensure_directory_exists(cache_path) + + # Return the absolute path based on storage system + if self.storage_manager.storage.storage_path.startswith("s3://"): + return cache_path + elif hasattr(self.storage_manager.storage, "storage_path"): + return f"{self.storage_manager.storage.storage_path}/{cache_path}" + else: + # Fallback for other storage types + return cache_path + + async def delete_cache(self): + """Delete the entire cache directory.""" + logger.info("Deleting cache...") + try: + await self.storage_manager.remove_all(self.cache_base_path) + logger.info("✓ Cache deleted successfully!") + except Exception as e: + logger.error(f"Error deleting cache: {e}") + raise + + async def _is_cache_valid(self, cache_dir: str, version_or_hash: str) -> bool: + """Check if cached content is valid for the given version/hash.""" + version_file = f"{cache_dir}/version.txt" + + if not await self.storage_manager.file_exists(version_file): + return False + + try: + async with self.storage_manager.open(version_file, "r") as f: + cached_version = (await asyncio.to_thread(f.read)).strip() + return cached_version == version_or_hash + except Exception as e: + logger.debug(f"Error checking cache validity: {e}") + return False + + async def _clear_cache(self, cache_dir: str) -> None: + """Clear a cache directory.""" + try: + await self.storage_manager.remove_all(cache_dir) + except Exception as e: + logger.debug(f"Error clearing cache directory {cache_dir}: {e}") + + async def _check_remote_content_freshness( + self, url: str, cache_dir: str + ) -> Tuple[bool, Optional[str]]: + """ + Check if remote content is fresher than cached version using HTTP headers. 
+ + Returns: + Tuple of (is_fresh: bool, new_identifier: Optional[str]) + """ + try: + # Make a HEAD request to check headers without downloading + ssl_context = create_secure_ssl_context() + connector = aiohttp.TCPConnector(ssl=ssl_context) + async with aiohttp.ClientSession(connector=connector) as session: + async with session.head(url, timeout=aiohttp.ClientTimeout(total=30)) as response: + response.raise_for_status() + + # Try ETag first (most reliable) + etag = response.headers.get("ETag", "").strip('"') + last_modified = response.headers.get("Last-Modified", "") + + # Use ETag if available, otherwise Last-Modified + remote_identifier = etag if etag else last_modified + + if not remote_identifier: + logger.debug("No freshness headers available, cannot check for updates") + return True, None # Assume fresh if no headers + + # Check cached identifier + identifier_file = f"{cache_dir}/content_id.txt" + if await self.storage_manager.file_exists(identifier_file): + async with self.storage_manager.open(identifier_file, "r") as f: + cached_identifier = (await asyncio.to_thread(f.read)).strip() + if cached_identifier == remote_identifier: + logger.debug(f"Content is fresh (identifier: {remote_identifier[:20]}...)") + return True, None + else: + logger.info( + f"Content has changed (old: {cached_identifier[:20]}..., new: {remote_identifier[:20]}...)" + ) + return False, remote_identifier + else: + # No cached identifier, treat as stale + return False, remote_identifier + + except Exception as e: + logger.debug(f"Could not check remote freshness: {e}") + return True, None # Assume fresh if we can't check + + async def download_and_extract_zip( + self, url: str, cache_subdir_name: str, version_or_hash: str, force: bool = False + ) -> str: + """ + Download a zip file and extract it to cache directory with content freshness checking. 
+ + Args: + url: URL to download zip file from + cache_subdir_name: Name of the cache subdirectory + version_or_hash: Version string or content hash for cache validation + force: If True, re-download even if already cached + + Returns: + Path to the cached directory + """ + cache_dir = await self.get_cache_subdir(cache_subdir_name) + + # Check if already cached and valid + if not force and await self._is_cache_valid(cache_dir, version_or_hash): + # Also check if remote content has changed + is_fresh, new_identifier = await self._check_remote_content_freshness(url, cache_dir) + if is_fresh: + logger.debug(f"Content already cached and fresh for version {version_or_hash}") + return cache_dir + else: + logger.info("Cached content is stale, updating...") + + # Clear old cache if it exists + await self._clear_cache(cache_dir) + + logger.info(f"Downloading content from {url}...") + + # Download the zip file + zip_content = BytesIO() + etag = "" + last_modified = "" + ssl_context = create_secure_ssl_context() + connector = aiohttp.TCPConnector(ssl=ssl_context) + async with aiohttp.ClientSession(connector=connector) as session: + async with session.get(url, timeout=aiohttp.ClientTimeout(total=60)) as response: + response.raise_for_status() + + # Extract headers before consuming response + etag = response.headers.get("ETag", "").strip('"') + last_modified = response.headers.get("Last-Modified", "") + + # Read the response content + async for chunk in response.content.iter_chunked(8192): + zip_content.write(chunk) + zip_content.seek(0) + + # Extract the archive + await self.storage_manager.ensure_directory_exists(cache_dir) + + # Extract files and store them using StorageManager + with zipfile.ZipFile(zip_content, "r") as zip_file: + for file_info in zip_file.infolist(): + if file_info.is_dir(): + # Create directory + dir_path = f"{cache_dir}/{file_info.filename}" + await self.storage_manager.ensure_directory_exists(dir_path) + else: + # Extract and store file + file_data = zip_file.read(file_info.filename) + file_path = f"{cache_dir}/{file_info.filename}" + await self.storage_manager.store(file_path, BytesIO(file_data), overwrite=True) + + # Write version info for future cache validation + version_file = f"{cache_dir}/version.txt" + await self.storage_manager.store(version_file, version_or_hash, overwrite=True) + + # Store content identifier from response headers for freshness checking + content_identifier = etag if etag else last_modified + + if content_identifier: + identifier_file = f"{cache_dir}/content_id.txt" + await self.storage_manager.store(identifier_file, content_identifier, overwrite=True) + logger.debug(f"Stored content identifier: {content_identifier[:20]}...") + + logger.info("✓ Content downloaded and cached successfully!") + return cache_dir + + async def file_exists(self, file_path: str) -> bool: + """Check if a file exists in cache storage.""" + return await self.storage_manager.file_exists(file_path) + + async def read_file(self, file_path: str, encoding: str = "utf-8"): + """Read a file from cache storage.""" + return self.storage_manager.open(file_path, encoding=encoding) + + async def list_files(self, directory_path: str): + """List files in a cache directory.""" + try: + file_list = await self.storage_manager.list_files(directory_path) + + # For S3 storage, convert relative paths to full S3 URLs + if self.storage_manager.storage.storage_path.startswith("s3://"): + full_paths = [] + for file_path in file_list: + full_s3_path = 
f"{self.storage_manager.storage.storage_path}/{file_path}" + full_paths.append(full_s3_path) + return full_paths + else: + # For local storage, return absolute paths + storage_path = self.storage_manager.storage.storage_path + if not storage_path.startswith("/"): + import os + + storage_path = os.path.abspath(storage_path) + + full_paths = [] + for file_path in file_list: + if file_path.startswith("/"): + full_paths.append(file_path) # Already absolute + else: + full_paths.append(f"{storage_path}/{file_path}") + return full_paths + + except Exception as e: + logger.debug(f"Error listing files in {directory_path}: {e}") + return [] + + +# Convenience functions that maintain API compatibility +_cache_manager = None + + +def get_cache_manager() -> StorageAwareCache: + """Get a singleton cache manager instance.""" + global _cache_manager + if _cache_manager is None: + _cache_manager = StorageAwareCache() + return _cache_manager + + +def generate_content_hash(url: str, additional_data: str = "") -> str: + """Generate a content hash from URL and optional additional data.""" + content = f"{url}:{additional_data}" + return hashlib.md5(content.encode()).hexdigest()[:12] # Short hash for readability + + +# Async wrapper functions for backward compatibility +async def delete_cache(): + """Delete the Cognee cache directory.""" + cache_manager = get_cache_manager() + await cache_manager.delete_cache() + + +async def get_cognee_cache_dir() -> str: + """Get the base Cognee cache directory.""" + cache_manager = get_cache_manager() + return await cache_manager.get_cache_dir() + + +async def get_cache_subdir(name: str) -> str: + """Get a specific cache subdirectory.""" + cache_manager = get_cache_manager() + return await cache_manager.get_cache_subdir(name) + + +async def download_and_extract_zip( + url: str, cache_dir_name: str, version_or_hash: str, force: bool = False +) -> str: + """Download a zip file and extract it to cache directory.""" + cache_manager = get_cache_manager() + return await cache_manager.download_and_extract_zip(url, cache_dir_name, version_or_hash, force) + + +async def get_tutorial_data_dir() -> str: + """Get the tutorial data cache directory.""" + return await get_cache_subdir("tutorial_data") + + +# Cache file operations +async def cache_file_exists(file_path: str) -> bool: + """Check if a file exists in cache storage.""" + cache_manager = get_cache_manager() + return await cache_manager.file_exists(file_path) + + +async def read_cache_file(file_path: str, encoding: str = "utf-8"): + """Read a file from cache storage.""" + cache_manager = get_cache_manager() + return await cache_manager.read_file(file_path, encoding) + + +async def list_cache_files(directory_path: str): + """List files in a cache directory.""" + cache_manager = get_cache_manager() + return await cache_manager.list_files(directory_path) diff --git a/cognee/shared/utils.py b/cognee/shared/utils.py index fb4193a8c..6ecdbc8f1 100644 --- a/cognee/shared/utils.py +++ b/cognee/shared/utils.py @@ -1,6 +1,7 @@ """This module contains utility functions for the cognee.""" import os +import ssl import requests from datetime import datetime, timezone import matplotlib.pyplot as plt @@ -18,6 +19,17 @@ from cognee.infrastructure.databases.graph import get_graph_engine proxy_url = "https://test.prometh.ai" +def create_secure_ssl_context() -> ssl.SSLContext: + """ + Create a secure SSL context. + + By default, use the system's certificate store. 
+ If users report SSL issues, I'm keeping this open in case we need to switch to: + ssl.create_default_context(cafile=certifi.where()) + """ + return ssl.create_default_context() + + def get_entities(tagged_tokens): import nltk diff --git a/cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py b/cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py new file mode 100644 index 000000000..e89b7d4a7 --- /dev/null +++ b/cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py @@ -0,0 +1,399 @@ +import json +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +import hashlib +import time +from uuid import uuid4 +from sqlalchemy.ext.asyncio import AsyncSession +from pathlib import Path +import zipfile +from cognee.shared.cache import get_tutorial_data_dir + +from cognee.modules.notebooks.methods.create_notebook import _create_tutorial_notebook +from cognee.modules.notebooks.models.Notebook import Notebook +import cognee +from cognee.shared.logging_utils import get_logger + +logger = get_logger() + + +# Module-level fixtures available to all test classes +@pytest.fixture +def mock_session(): + """Mock database session.""" + session = AsyncMock(spec=AsyncSession) + session.add = MagicMock() + session.commit = AsyncMock() + return session + + +@pytest.fixture +def sample_jupyter_notebook(): + """Sample Jupyter notebook content for testing.""" + return { + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": ["# Tutorial Introduction\n", "\n", "This is a tutorial notebook."], + }, + { + "cell_type": "code", + "execution_count": None, + "metadata": {}, + "outputs": [], + "source": ["import cognee\n", "print('Hello, Cognee!')"], + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Step 1: Data Ingestion\n", "\n", "Let's add some data."], + }, + { + "cell_type": "code", + "execution_count": None, + "metadata": {}, + "outputs": [], + "source": ["# Add your data here\n", "# await cognee.add('data.txt')"], + }, + { + "cell_type": "raw", + "metadata": {}, + "source": ["This is a raw cell that should be skipped"], + }, + ], + "metadata": { + "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"} + }, + "nbformat": 4, + "nbformat_minor": 4, + } + + +class TestTutorialNotebookCreation: + """Test cases for tutorial notebook creation functionality.""" + + @pytest.mark.asyncio + async def test_notebook_from_ipynb_string_success(self, sample_jupyter_notebook): + """Test successful creation of notebook from JSON string.""" + notebook_json = json.dumps(sample_jupyter_notebook) + user_id = uuid4() + + notebook = Notebook.from_ipynb_string( + notebook_content=notebook_json, owner_id=user_id, name="String Test Notebook" + ) + + assert notebook.owner_id == user_id + assert notebook.name == "String Test Notebook" + assert len(notebook.cells) == 4 # Should skip the raw cell + assert notebook.cells[0].type == "markdown" + assert notebook.cells[1].type == "code" + + @pytest.mark.asyncio + async def test_notebook_cell_name_generation(self, sample_jupyter_notebook): + """Test that cell names are generated correctly from markdown headers.""" + user_id = uuid4() + notebook_json = json.dumps(sample_jupyter_notebook) + + notebook = Notebook.from_ipynb_string(notebook_content=notebook_json, owner_id=user_id) + + # Check markdown header extraction + assert notebook.cells[0].name == "Tutorial Introduction" + assert notebook.cells[2].name == "Step 1: Data Ingestion" + + # Check code cell naming + assert 
notebook.cells[1].name == "Code Cell" + assert notebook.cells[3].name == "Code Cell" + + @pytest.mark.asyncio + async def test_notebook_from_ipynb_string_with_default_name(self, sample_jupyter_notebook): + """Test notebook creation uses kernelspec display_name when no name provided.""" + user_id = uuid4() + notebook_json = json.dumps(sample_jupyter_notebook) + + notebook = Notebook.from_ipynb_string(notebook_content=notebook_json, owner_id=user_id) + + assert notebook.name == "Python 3" # From kernelspec.display_name + + @pytest.mark.asyncio + async def test_notebook_from_ipynb_string_fallback_name(self): + """Test fallback naming when kernelspec is missing.""" + minimal_notebook = { + "cells": [{"cell_type": "markdown", "metadata": {}, "source": ["# Test"]}], + "metadata": {}, # No kernelspec + "nbformat": 4, + "nbformat_minor": 4, + } + + user_id = uuid4() + notebook_json = json.dumps(minimal_notebook) + + notebook = Notebook.from_ipynb_string(notebook_content=notebook_json, owner_id=user_id) + + assert notebook.name == "Imported Notebook" # Fallback name + + @pytest.mark.asyncio + async def test_notebook_from_ipynb_string_invalid_json(self): + """Test error handling for invalid JSON.""" + user_id = uuid4() + invalid_json = "{ invalid json content" + + from nbformat.reader import NotJSONError + + with pytest.raises(NotJSONError): + Notebook.from_ipynb_string(notebook_content=invalid_json, owner_id=user_id) + + @pytest.mark.asyncio + @patch.object(Notebook, "from_ipynb_zip_url") + async def test_create_tutorial_notebook_error_propagated(self, mock_from_zip_url, mock_session): + """Test that errors are propagated when zip fetch fails.""" + user_id = uuid4() + mock_from_zip_url.side_effect = Exception("Network error") + + # Should raise the exception (not catch it) + with pytest.raises(Exception, match="Network error"): + await _create_tutorial_notebook(user_id, mock_session) + + # Verify error handling path was taken + mock_from_zip_url.assert_called_once() + mock_session.add.assert_not_called() + mock_session.commit.assert_not_called() + + def test_generate_cell_name_code_cell(self): + """Test cell name generation for code cells.""" + from nbformat.notebooknode import NotebookNode + + mock_cell = NotebookNode( + {"cell_type": "code", "source": 'import pandas as pd\nprint("Hello world")'} + ) + + result = Notebook._generate_cell_name(mock_cell) + assert result == "Code Cell" + + +class TestTutorialNotebookZipFunctionality: + """Test cases for zip-based tutorial functionality.""" + + @pytest.mark.asyncio + async def test_notebook_from_ipynb_zip_url_missing_notebook( + self, + ): + """Test error handling when notebook file is missing from zip.""" + user_id = uuid4() + + with pytest.raises( + FileNotFoundError, + match="Notebook file 'super_random_tutorial_name.ipynb' not found in zip", + ): + await Notebook.from_ipynb_zip_url( + zip_url="https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/starter_tutorial.zip", + owner_id=user_id, + notebook_filename="super_random_tutorial_name.ipynb", + ) + + @pytest.mark.asyncio + async def test_notebook_from_ipynb_zip_url_download_failure(self): + """Test error handling when zip download fails.""" + user_id = uuid4() + with pytest.raises(RuntimeError, match="Failed to download tutorial zip"): + await Notebook.from_ipynb_zip_url( + zip_url="https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/nonexistent_tutorial_name.zip", + owner_id=user_id, + ) + + @pytest.mark.asyncio + async def 
test_create_tutorial_notebook_zip_success(self, mock_session): + """Test successful tutorial notebook creation with zip.""" + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + user_id = uuid4() + + # Check that tutorial data directory is empty using storage-aware method + tutorial_data_dir_path = await get_tutorial_data_dir() + tutorial_data_dir = Path(tutorial_data_dir_path) + if tutorial_data_dir.exists(): + assert not any(tutorial_data_dir.iterdir()), "Tutorial data directory should be empty" + + await _create_tutorial_notebook(user_id, mock_session) + + items = list(tutorial_data_dir.iterdir()) + assert len(items) == 1, "Tutorial data directory should contain exactly one item" + assert items[0].is_dir(), "Tutorial data directory item should be a directory" + + # Verify the structure inside the tutorial directory + tutorial_dir = items[0] + + # Check for tutorial.ipynb file + notebook_file = tutorial_dir / "tutorial.ipynb" + assert notebook_file.exists(), f"tutorial.ipynb should exist in {tutorial_dir}" + assert notebook_file.is_file(), "tutorial.ipynb should be a file" + + # Check for data subfolder with contents + data_folder = tutorial_dir / "data" + assert data_folder.exists(), f"data subfolder should exist in {tutorial_dir}" + assert data_folder.is_dir(), "data should be a directory" + + data_items = list(data_folder.iterdir()) + assert len(data_items) > 0, ( + f"data folder should contain files, but found {len(data_items)} items" + ) + + @pytest.mark.asyncio + async def test_create_tutorial_notebook_with_force_refresh(self, mock_session): + """Test tutorial notebook creation with force refresh.""" + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + user_id = uuid4() + + # Check that tutorial data directory is empty using storage-aware method + tutorial_data_dir_path = await get_tutorial_data_dir() + tutorial_data_dir = Path(tutorial_data_dir_path) + if tutorial_data_dir.exists(): + assert not any(tutorial_data_dir.iterdir()), "Tutorial data directory should be empty" + + # First creation (without force refresh) + await _create_tutorial_notebook(user_id, mock_session, force_refresh=False) + + items_first = list(tutorial_data_dir.iterdir()) + assert len(items_first) == 1, ( + "Tutorial data directory should contain exactly one item after first creation" + ) + first_dir = items_first[0] + assert first_dir.is_dir(), "Tutorial data directory item should be a directory" + + # Verify the structure inside the tutorial directory (first creation) + notebook_file = first_dir / "tutorial.ipynb" + assert notebook_file.exists(), f"tutorial.ipynb should exist in {first_dir}" + assert notebook_file.is_file(), "tutorial.ipynb should be a file" + + data_folder = first_dir / "data" + assert data_folder.exists(), f"data subfolder should exist in {first_dir}" + assert data_folder.is_dir(), "data should be a directory" + + data_items = list(data_folder.iterdir()) + assert len(data_items) > 0, ( + f"data folder should contain files, but found {len(data_items)} items" + ) + + # Capture metadata from first creation + + first_creation_metadata = {} + + for file_path in first_dir.rglob("*"): + if file_path.is_file(): + relative_path = file_path.relative_to(first_dir) + stat = file_path.stat() + + # Store multiple metadata points + with open(file_path, "rb") as f: + content = f.read() + + first_creation_metadata[str(relative_path)] = { + "mtime": stat.st_mtime, + "size": stat.st_size, + "hash": hashlib.md5(content).hexdigest(), + 
"first_bytes": content[:100] + if content + else b"", # First 100 bytes as fingerprint + } + + # Wait a moment to ensure different timestamps + time.sleep(0.1) + + # Force refresh - should create new files with different metadata + await _create_tutorial_notebook(user_id, mock_session, force_refresh=True) + + items_second = list(tutorial_data_dir.iterdir()) + assert len(items_second) == 1, ( + "Tutorial data directory should contain exactly one item after force refresh" + ) + second_dir = items_second[0] + + # Verify the structure is maintained after force refresh + notebook_file_second = second_dir / "tutorial.ipynb" + assert notebook_file_second.exists(), ( + f"tutorial.ipynb should exist in {second_dir} after force refresh" + ) + assert notebook_file_second.is_file(), "tutorial.ipynb should be a file after force refresh" + + data_folder_second = second_dir / "data" + assert data_folder_second.exists(), ( + f"data subfolder should exist in {second_dir} after force refresh" + ) + assert data_folder_second.is_dir(), "data should be a directory after force refresh" + + data_items_second = list(data_folder_second.iterdir()) + assert len(data_items_second) > 0, ( + f"data folder should still contain files after force refresh, but found {len(data_items_second)} items" + ) + + # Compare metadata to ensure files are actually different + files_with_changed_metadata = 0 + + for file_path in second_dir.rglob("*"): + if file_path.is_file(): + relative_path = file_path.relative_to(second_dir) + relative_path_str = str(relative_path) + + # File should exist from first creation + assert relative_path_str in first_creation_metadata, ( + f"File {relative_path_str} missing from first creation" + ) + + old_metadata = first_creation_metadata[relative_path_str] + + # Get new metadata + stat = file_path.stat() + with open(file_path, "rb") as f: + new_content = f.read() + + new_metadata = { + "mtime": stat.st_mtime, + "size": stat.st_size, + "hash": hashlib.md5(new_content).hexdigest(), + "first_bytes": new_content[:100] if new_content else b"", + } + + # Check if any metadata changed (indicating file was refreshed) + metadata_changed = ( + new_metadata["mtime"] > old_metadata["mtime"] # Newer modification time + or new_metadata["hash"] != old_metadata["hash"] # Different content hash + or new_metadata["size"] != old_metadata["size"] # Different file size + or new_metadata["first_bytes"] + != old_metadata["first_bytes"] # Different content + ) + + if metadata_changed: + files_with_changed_metadata += 1 + + # Assert that force refresh actually updated files + assert files_with_changed_metadata > 0, ( + f"Force refresh should have updated at least some files, but all {len(first_creation_metadata)} " + f"files appear to have identical metadata. This suggests force refresh didn't work." 
+ ) + + mock_session.commit.assert_called() + + @pytest.mark.asyncio + async def test_tutorial_zip_url_accessibility(self): + """Test that the actual tutorial zip URL is accessible (integration test).""" + try: + import requests + + response = requests.get( + "https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/starter_tutorial.zip", + timeout=10, + ) + response.raise_for_status() + + # Verify it's a valid zip file by checking headers + assert response.headers.get("content-type") in [ + "application/zip", + "application/octet-stream", + "application/x-zip-compressed", + ] or response.content.startswith(b"PK") # Zip file signature + + except Exception: + pytest.skip("Network request failed or zip not available - skipping integration test") diff --git a/pyproject.toml b/pyproject.toml index 89773a74d..2722e8944 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ dependencies = [ "matplotlib>=3.8.3,<4", "networkx>=3.4.2,<4", "lancedb>=0.24.0,<1.0.0", + "nbformat>=5.7.0,<6.0.0", "alembic>=1.13.3,<2", "pre-commit>=4.0.1,<5", "scikit-learn>=1.6.1,<2", diff --git a/uv.lock b/uv.lock index 670653862..6a187e0a4 100644 --- a/uv.lock +++ b/uv.lock @@ -831,6 +831,7 @@ dependencies = [ { name = "limits" }, { name = "litellm" }, { name = "matplotlib" }, + { name = "nbformat" }, { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "nltk" }, @@ -1012,6 +1013,7 @@ requires-dist = [ { name = "mkdocstrings", extras = ["python"], marker = "extra == 'dev'", specifier = ">=0.26.2,<0.27" }, { name = "modal", marker = "extra == 'distributed'", specifier = ">=1.0.5,<2.0.0" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.7.1,<2" }, + { name = "nbformat", specifier = ">=5.7.0,<6.0.0" }, { name = "neo4j", marker = "extra == 'neo4j'", specifier = ">=5.28.0,<6" }, { name = "networkx", specifier = ">=3.4.2,<4" }, { name = "nltk", specifier = ">=3.9.1,<4.0.0" },