feat: add welcome tutorial notebook for new users (#1425)

<!-- .github/pull_request_template.md -->

## Description
<!-- 
Please provide a clear, human-generated description of the changes in
this PR.
DO NOT use AI-generated descriptions. We want to understand your thought
process and reasoning.
-->

Update default tutorial:
1. Use tutorial from [notebook_tutorial
branch](https://github.com/topoteretes/cognee/blob/notebook_tutorial/notebooks/tutorial.ipynb),
specifically - it's .zip version with all necessary data files
2. Use Jupyter Notebook `Notebook` abstractions to read, and map `ipynb`
into our Notebook model
3. Dynamically update starter notebook code blocks that reference
starter data files, and swap them with local paths to downloaded copies
4. Test coverage



| Before | After (storage backend = local) | After (s3) |
|--------|---------------------------------|------------|
| <img width="613" height="546" alt="Screenshot 2025-09-17 at 01 00 58"
src="https://github.com/user-attachments/assets/20b59021-96c1-4a83-977f-e064324bd758"
/> | <img width="1480" height="262" alt="Screenshot 2025-09-18 at 13 01
57"
src="https://github.com/user-attachments/assets/bd56ea78-7c6a-42e3-ae3f-4157da231b2d"
/> | <img width="1485" height="307" alt="Screenshot 2025-09-18 at 12 56
08"
src="https://github.com/user-attachments/assets/248ae720-4c78-445a-ba8b-8a2991ed3f80"
/> |



## File Replacements

### S3 Demo  

https://github.com/user-attachments/assets/bd46eec9-ef77-4f69-9ef0-e7d1612ff9b3

---

### Local FS Demo  

https://github.com/user-attachments/assets/8251cea0-81b3-4cac-a968-9576c358f334


## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Changes Made
<!-- List the specific changes made in this PR -->
- 
- 
- 

## Testing
<!-- Describe how you tested your changes -->

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## Related Issues
<!-- Link any related issues using "Fixes #issue_number" or "Relates to
#issue_number" -->

## Additional Notes
<!-- Add any additional notes, concerns, or context for reviewers -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
Daulet Amirkhanov 2025-09-18 17:07:05 +01:00 committed by GitHub
parent bb124494c1
commit f58ba86e7c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 1200 additions and 44 deletions

View file

@ -47,6 +47,28 @@ BAML_LLM_API_VERSION=""
# DATA_ROOT_DIRECTORY='/Users/<user>/Desktop/cognee/.cognee_data/' # DATA_ROOT_DIRECTORY='/Users/<user>/Desktop/cognee/.cognee_data/'
# SYSTEM_ROOT_DIRECTORY='/Users/<user>/Desktop/cognee/.cognee_system/' # SYSTEM_ROOT_DIRECTORY='/Users/<user>/Desktop/cognee/.cognee_system/'
################################################################################
# ☁️ Storage Backend Settings
################################################################################
# Configure storage backend (local filesystem or S3)
# STORAGE_BACKEND="local" # Default: uses local filesystem
#
# -- To switch to S3 storage, uncomment and fill these: ---------------------
# STORAGE_BACKEND="s3"
# STORAGE_BUCKET_NAME="your-bucket-name"
# AWS_REGION="us-east-1"
# AWS_ACCESS_KEY_ID="your-access-key"
# AWS_SECRET_ACCESS_KEY="your-secret-key"
#
# -- S3 Root Directories (optional) -----------------------------------------
# DATA_ROOT_DIRECTORY="s3://your-bucket/cognee/data"
# SYSTEM_ROOT_DIRECTORY="s3://your-bucket/cognee/system"
#
# -- Cache Directory (auto-configured for S3) -------------------------------
# When STORAGE_BACKEND=s3, cache automatically uses S3: s3://BUCKET/cognee/cache
# To override the automatic S3 cache location, uncomment:
# CACHE_ROOT_DIRECTORY="s3://your-bucket/cognee/cache"
################################################################################ ################################################################################
# 🗄️ Relational database settings # 🗄️ Relational database settings
################################################################################ ################################################################################

1
.gitignore vendored
View file

@ -186,6 +186,7 @@ cognee/cache/
# Default cognee system directory, used in development # Default cognee system directory, used in development
.cognee_system/ .cognee_system/
.data_storage/ .data_storage/
.cognee_cache/
.artifacts/ .artifacts/
.anon_id .anon_id

View file

@ -7,8 +7,8 @@ class prune:
await _prune_data() await _prune_data()
@staticmethod @staticmethod
async def prune_system(graph=True, vector=True, metadata=False): async def prune_system(graph=True, vector=True, metadata=False, cache=True):
await _prune_system(graph, vector, metadata) await _prune_system(graph, vector, metadata, cache)
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -23,6 +23,7 @@ from cognee.modules.sync.methods import (
mark_sync_completed, mark_sync_completed,
mark_sync_failed, mark_sync_failed,
) )
from cognee.shared.utils import create_secure_ssl_context
logger = get_logger("sync") logger = get_logger("sync")
@ -583,7 +584,9 @@ async def _check_hashes_diff(
logger.info(f"Checking missing hashes on cloud for dataset {dataset.id}") logger.info(f"Checking missing hashes on cloud for dataset {dataset.id}")
try: try:
async with aiohttp.ClientSession() as session: ssl_context = create_secure_ssl_context()
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(connector=connector) as session:
async with session.post(url, json=payload.dict(), headers=headers) as response: async with session.post(url, json=payload.dict(), headers=headers) as response:
if response.status == 200: if response.status == 200:
data = await response.json() data = await response.json()
@ -630,7 +633,9 @@ async def _download_missing_files(
headers = {"X-Api-Key": auth_token} headers = {"X-Api-Key": auth_token}
async with aiohttp.ClientSession() as session: ssl_context = create_secure_ssl_context()
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(connector=connector) as session:
for file_hash in hashes_missing_on_local: for file_hash in hashes_missing_on_local:
try: try:
# Download file from cloud by hash # Download file from cloud by hash
@ -749,7 +754,9 @@ async def _upload_missing_files(
headers = {"X-Api-Key": auth_token} headers = {"X-Api-Key": auth_token}
async with aiohttp.ClientSession() as session: ssl_context = create_secure_ssl_context()
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(connector=connector) as session:
for file_info in files_to_upload: for file_info in files_to_upload:
try: try:
file_dir = os.path.dirname(file_info.raw_data_location) file_dir = os.path.dirname(file_info.raw_data_location)
@ -809,7 +816,9 @@ async def _prune_cloud_dataset(
logger.info("Pruning cloud dataset to match local state") logger.info("Pruning cloud dataset to match local state")
try: try:
async with aiohttp.ClientSession() as session: ssl_context = create_secure_ssl_context()
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(connector=connector) as session:
async with session.put(url, json=payload.dict(), headers=headers) as response: async with session.put(url, json=payload.dict(), headers=headers) as response:
if response.status == 200: if response.status == 200:
data = await response.json() data = await response.json()
@ -852,7 +861,9 @@ async def _trigger_remote_cognify(
logger.info(f"Triggering cognify processing for dataset {dataset_id}") logger.info(f"Triggering cognify processing for dataset {dataset_id}")
try: try:
async with aiohttp.ClientSession() as session: ssl_context = create_secure_ssl_context()
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(connector=connector) as session:
async with session.post(url, json=payload, headers=headers) as response: async with session.post(url, json=payload, headers=headers) as response:
if response.status == 200: if response.status == 200:
data = await response.json() data = await response.json()

View file

@ -10,13 +10,27 @@ import pydantic
class BaseConfig(BaseSettings): class BaseConfig(BaseSettings):
data_root_directory: str = get_absolute_path(".data_storage") data_root_directory: str = get_absolute_path(".data_storage")
system_root_directory: str = get_absolute_path(".cognee_system") system_root_directory: str = get_absolute_path(".cognee_system")
cache_root_directory: str = get_absolute_path(".cognee_cache")
monitoring_tool: object = Observer.LANGFUSE monitoring_tool: object = Observer.LANGFUSE
@pydantic.model_validator(mode="after") @pydantic.model_validator(mode="after")
def validate_paths(self): def validate_paths(self):
# Adding this here temporarily to ensure that the cache root directory is set correctly for S3 storage automatically
# I'll remove this after we update documentation for S3 storage
# Auto-configure cache root directory for S3 storage if not explicitly set
storage_backend = os.getenv("STORAGE_BACKEND", "").lower()
cache_root_env = os.getenv("CACHE_ROOT_DIRECTORY")
if storage_backend == "s3" and not cache_root_env:
# Auto-generate S3 cache path when using S3 storage
bucket_name = os.getenv("STORAGE_BUCKET_NAME")
if bucket_name:
self.cache_root_directory = f"s3://{bucket_name}/cognee/cache"
# Require absolute paths for root directories # Require absolute paths for root directories
self.data_root_directory = ensure_absolute_path(self.data_root_directory) self.data_root_directory = ensure_absolute_path(self.data_root_directory)
self.system_root_directory = ensure_absolute_path(self.system_root_directory) self.system_root_directory = ensure_absolute_path(self.system_root_directory)
self.cache_root_directory = ensure_absolute_path(self.cache_root_directory)
return self return self
langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY") langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY")
@ -31,6 +45,7 @@ class BaseConfig(BaseSettings):
"data_root_directory": self.data_root_directory, "data_root_directory": self.data_root_directory,
"system_root_directory": self.system_root_directory, "system_root_directory": self.system_root_directory,
"monitoring_tool": self.monitoring_tool, "monitoring_tool": self.monitoring_tool,
"cache_root_directory": self.cache_root_directory,
} }

View file

@ -7,6 +7,7 @@ import aiohttp
from uuid import UUID from uuid import UUID
from cognee.infrastructure.databases.graph.kuzu.adapter import KuzuAdapter from cognee.infrastructure.databases.graph.kuzu.adapter import KuzuAdapter
from cognee.shared.utils import create_secure_ssl_context
logger = get_logger() logger = get_logger()
@ -42,7 +43,9 @@ class RemoteKuzuAdapter(KuzuAdapter):
async def _get_session(self) -> aiohttp.ClientSession: async def _get_session(self) -> aiohttp.ClientSession:
"""Get or create an aiohttp session.""" """Get or create an aiohttp session."""
if self._session is None or self._session.closed: if self._session is None or self._session.closed:
self._session = aiohttp.ClientSession() ssl_context = create_secure_ssl_context()
connector = aiohttp.TCPConnector(ssl=ssl_context)
self._session = aiohttp.ClientSession(connector=connector)
return self._session return self._session
async def close(self): async def close(self):

View file

@ -14,6 +14,7 @@ from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter im
embedding_rate_limit_async, embedding_rate_limit_async,
embedding_sleep_and_retry_async, embedding_sleep_and_retry_async,
) )
from cognee.shared.utils import create_secure_ssl_context
logger = get_logger("OllamaEmbeddingEngine") logger = get_logger("OllamaEmbeddingEngine")
@ -101,7 +102,9 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
if api_key: if api_key:
headers["Authorization"] = f"Bearer {api_key}" headers["Authorization"] = f"Bearer {api_key}"
async with aiohttp.ClientSession() as session: ssl_context = create_secure_ssl_context()
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(connector=connector) as session:
async with session.post( async with session.post(
self.endpoint, json=payload, headers=headers, timeout=60.0 self.endpoint, json=payload, headers=headers, timeout=60.0
) as response: ) as response:

View file

@ -253,6 +253,56 @@ class LocalFileStorage(Storage):
if os.path.exists(full_file_path): if os.path.exists(full_file_path):
os.remove(full_file_path) os.remove(full_file_path)
def list_files(self, directory_path: str, recursive: bool = False) -> list[str]:
"""
List all files in the specified directory.
Parameters:
-----------
- directory_path (str): The directory path to list files from
- recursive (bool): If True, list files recursively in subdirectories
Returns:
--------
- list[str]: List of file paths relative to the storage root
"""
from pathlib import Path
parsed_storage_path = get_parsed_path(self.storage_path)
if directory_path:
full_directory_path = os.path.join(parsed_storage_path, directory_path)
else:
full_directory_path = parsed_storage_path
directory_pathlib = Path(full_directory_path)
if not directory_pathlib.exists() or not directory_pathlib.is_dir():
return []
files = []
if recursive:
# Use rglob for recursive search
for file_path in directory_pathlib.rglob("*"):
if file_path.is_file():
# Get relative path from storage root
relative_path = os.path.relpath(str(file_path), parsed_storage_path)
# Normalize path separators for consistency
relative_path = relative_path.replace(os.sep, "/")
files.append(relative_path)
else:
# Use iterdir for just immediate directory
for file_path in directory_pathlib.iterdir():
if file_path.is_file():
# Get relative path from storage root
relative_path = os.path.relpath(str(file_path), parsed_storage_path)
# Normalize path separators for consistency
relative_path = relative_path.replace(os.sep, "/")
files.append(relative_path)
return files
def remove_all(self, tree_path: str = None): def remove_all(self, tree_path: str = None):
""" """
Remove an entire directory tree at the specified path, including all files and Remove an entire directory tree at the specified path, including all files and

View file

@ -155,21 +155,19 @@ class S3FileStorage(Storage):
""" """
Ensure that the specified directory exists, creating it if necessary. Ensure that the specified directory exists, creating it if necessary.
If the directory already exists, no action is taken. For S3 storage, this is a no-op since directories are created implicitly
when files are written to paths. S3 doesn't have actual directories,
just object keys with prefixes that appear as directories.
Parameters: Parameters:
----------- -----------
- directory_path (str): The path of the directory to check or create. - directory_path (str): The path of the directory to check or create.
""" """
if not directory_path.strip(): # In S3, directories don't exist as separate entities - they're just prefixes
directory_path = self.storage_path.replace("s3://", "") # When you write a file to s3://bucket/path/to/file.txt, the "directories"
# path/ and path/to/ are implicitly created. No explicit action needed.
def ensure_directory(): pass
if not self.s3.exists(directory_path):
self.s3.makedirs(directory_path, exist_ok=True)
await run_async(ensure_directory)
async def copy_file(self, source_file_path: str, destination_file_path: str): async def copy_file(self, source_file_path: str, destination_file_path: str):
""" """
@ -213,6 +211,55 @@ class S3FileStorage(Storage):
await run_async(remove_file) await run_async(remove_file)
async def list_files(self, directory_path: str, recursive: bool = False) -> list[str]:
"""
List all files in the specified directory.
Parameters:
-----------
- directory_path (str): The directory path to list files from
- recursive (bool): If True, list files recursively in subdirectories
Returns:
--------
- list[str]: List of file paths relative to the storage root
"""
def list_files_sync():
if directory_path:
# Combine storage path with directory path
full_path = os.path.join(self.storage_path.replace("s3://", ""), directory_path)
else:
full_path = self.storage_path.replace("s3://", "")
if recursive:
# Use ** for recursive search
pattern = f"{full_path}/**"
else:
# Just files in the immediate directory
pattern = f"{full_path}/*"
# Use s3fs glob to find files
try:
all_paths = self.s3.glob(pattern)
# Filter to only files (not directories)
files = [path for path in all_paths if self.s3.isfile(path)]
# Convert back to relative paths from storage root
storage_prefix = self.storage_path.replace("s3://", "")
relative_files = []
for file_path in files:
if file_path.startswith(storage_prefix):
relative_path = file_path[len(storage_prefix) :].lstrip("/")
relative_files.append(relative_path)
return relative_files
except Exception:
# If directory doesn't exist or other error, return empty list
return []
return await run_async(list_files_sync)
async def remove_all(self, tree_path: str): async def remove_all(self, tree_path: str):
""" """
Remove an entire directory tree at the specified path, including all files and Remove an entire directory tree at the specified path, including all files and

View file

@ -135,6 +135,24 @@ class StorageManager:
else: else:
return self.storage.remove(file_path) return self.storage.remove(file_path)
async def list_files(self, directory_path: str, recursive: bool = False) -> list[str]:
"""
List all files in the specified directory.
Parameters:
-----------
- directory_path (str): The directory path to list files from
- recursive (bool): If True, list files recursively in subdirectories
Returns:
--------
- list[str]: List of file paths relative to the storage root
"""
if inspect.iscoroutinefunction(self.storage.list_files):
return await self.storage.list_files(directory_path, recursive)
else:
return self.storage.list_files(directory_path, recursive)
async def remove_all(self, tree_path: str = None): async def remove_all(self, tree_path: str = None):
""" """
Remove an entire directory tree at the specified path, including all files and Remove an entire directory tree at the specified path, including all files and

View file

@ -1,6 +1,7 @@
import aiohttp import aiohttp
from cognee.modules.cloud.exceptions import CloudConnectionError from cognee.modules.cloud.exceptions import CloudConnectionError
from cognee.shared.utils import create_secure_ssl_context
async def check_api_key(auth_token: str): async def check_api_key(auth_token: str):
@ -10,7 +11,9 @@ async def check_api_key(auth_token: str):
headers = {"X-Api-Key": auth_token} headers = {"X-Api-Key": auth_token}
try: try:
async with aiohttp.ClientSession() as session: ssl_context = create_secure_ssl_context()
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(connector=connector) as session:
async with session.post(url, headers=headers) as response: async with session.post(url, headers=headers) as response:
if response.status == 200: if response.status == 200:
return return

View file

@ -1,9 +1,10 @@
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.shared.cache import delete_cache
async def prune_system(graph=True, vector=True, metadata=True): async def prune_system(graph=True, vector=True, metadata=True, cache=True):
if graph: if graph:
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
await graph_engine.delete_graph() await graph_engine.delete_graph()
@ -15,3 +16,6 @@ async def prune_system(graph=True, vector=True, metadata=True):
if metadata: if metadata:
db_engine = get_relational_engine() db_engine = get_relational_engine()
await db_engine.delete_database() await db_engine.delete_database()
if cache:
await delete_cache()

View file

@ -7,6 +7,38 @@ from cognee.infrastructure.databases.relational import with_async_session
from ..models.Notebook import Notebook, NotebookCell from ..models.Notebook import Notebook, NotebookCell
async def _create_tutorial_notebook(
user_id: UUID, session: AsyncSession, force_refresh: bool = False
) -> None:
"""
Create the default tutorial notebook for new users.
Dynamically fetches from: https://github.com/topoteretes/cognee/blob/notebook_tutorial/notebooks/starter_tutorial.zip
"""
TUTORIAL_ZIP_URL = (
"https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/starter_tutorial.zip"
)
try:
# Create notebook from remote zip file (includes notebook + data files)
notebook = await Notebook.from_ipynb_zip_url(
zip_url=TUTORIAL_ZIP_URL,
owner_id=user_id,
notebook_filename="tutorial.ipynb",
name="Python Development with Cognee Tutorial 🧠",
deletable=False,
force=force_refresh,
)
# Add to session and commit
session.add(notebook)
await session.commit()
except Exception as e:
print(f"Failed to fetch tutorial notebook from {TUTORIAL_ZIP_URL}: {e}")
raise e
@with_async_session @with_async_session
async def create_notebook( async def create_notebook(
user_id: UUID, user_id: UUID,

View file

@ -1,13 +1,24 @@
import json import json
from typing import List, Literal import nbformat
import asyncio
from nbformat.notebooknode import NotebookNode
from typing import List, Literal, Optional, cast, Tuple
from uuid import uuid4, UUID as UUID_t from uuid import uuid4, UUID as UUID_t
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from datetime import datetime, timezone from datetime import datetime, timezone
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from sqlalchemy import Boolean, Column, DateTime, JSON, UUID, String, TypeDecorator from sqlalchemy import Boolean, Column, DateTime, JSON, UUID, String, TypeDecorator
from sqlalchemy.orm import mapped_column, Mapped from sqlalchemy.orm import mapped_column, Mapped
from pathlib import Path
from cognee.infrastructure.databases.relational import Base from cognee.infrastructure.databases.relational import Base
from cognee.shared.cache import (
download_and_extract_zip,
get_tutorial_data_dir,
generate_content_hash,
)
from cognee.infrastructure.files.storage.get_file_storage import get_file_storage
from cognee.base_config import get_base_config
class NotebookCell(BaseModel): class NotebookCell(BaseModel):
@ -51,3 +62,197 @@ class Notebook(Base):
deletable: Mapped[bool] = mapped_column(Boolean, default=True) deletable: Mapped[bool] = mapped_column(Boolean, default=True)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
@classmethod
async def from_ipynb_zip_url(
cls,
zip_url: str,
owner_id: UUID_t,
notebook_filename: str = "tutorial.ipynb",
name: Optional[str] = None,
deletable: bool = True,
force: bool = False,
) -> "Notebook":
"""
Create a Notebook instance from a remote zip file containing notebook + data files.
Args:
zip_url: Remote URL to fetch the .zip file from
owner_id: UUID of the notebook owner
notebook_filename: Name of the .ipynb file within the zip
name: Optional custom name for the notebook
deletable: Whether the notebook can be deleted
force: If True, re-download even if already cached
Returns:
Notebook instance
"""
# Generate a cache key based on the zip URL
content_hash = generate_content_hash(zip_url, notebook_filename)
# Download and extract the zip file to tutorial_data/{content_hash}
try:
extracted_cache_dir = await download_and_extract_zip(
url=zip_url,
cache_dir_name=f"tutorial_data/{content_hash}",
version_or_hash=content_hash,
force=force,
)
except Exception as e:
raise RuntimeError(f"Failed to download tutorial zip from {zip_url}") from e
# Use cache system to access the notebook file
from cognee.shared.cache import cache_file_exists, read_cache_file
notebook_file_path = f"{extracted_cache_dir}/{notebook_filename}"
# Check if the notebook file exists in cache
if not await cache_file_exists(notebook_file_path):
raise FileNotFoundError(f"Notebook file '{notebook_filename}' not found in zip")
# Read and parse the notebook using cache system
async with await read_cache_file(notebook_file_path, encoding="utf-8") as f:
notebook_content = await asyncio.to_thread(f.read)
notebook = cls.from_ipynb_string(notebook_content, owner_id, name, deletable)
# Update file paths in notebook cells to point to actual cached data files
await cls._update_file_paths_in_cells(notebook, extracted_cache_dir)
return notebook
@staticmethod
async def _update_file_paths_in_cells(notebook: "Notebook", cache_dir: str) -> None:
"""
Update file paths in code cells to use actual cached data files.
Works with both local filesystem and S3 storage.
Args:
notebook: Parsed Notebook instance with cells to update
cache_dir: Path to the cached tutorial directory containing data files
"""
import re
from cognee.shared.cache import list_cache_files, cache_file_exists
from cognee.shared.logging_utils import get_logger
logger = get_logger()
# Look for data files in the data subdirectory
data_dir = f"{cache_dir}/data"
try:
# Get all data files in the cache directory using cache system
data_files = {}
if await cache_file_exists(data_dir):
file_list = await list_cache_files(data_dir)
else:
file_list = []
for file_path in file_list:
# Extract just the filename
filename = file_path.split("/")[-1]
# Use the file path as provided by cache system
data_files[filename] = file_path
except Exception as e:
# If we can't list files, skip updating paths
logger.error(f"Error listing data files in {data_dir}: {e}")
return
# Pattern to match file://data/filename patterns in code cells
file_pattern = r'"file://data/([^"]+)"'
def replace_path(match):
filename = match.group(1)
if filename in data_files:
file_path = data_files[filename]
# For local filesystem, preserve file:// prefix
if not file_path.startswith("s3://"):
return f'"file://{file_path}"'
else:
# For S3, return the S3 URL as-is
return f'"{file_path}"'
return match.group(0) # Keep original if file not found
# Update only code cells
updated_cells = 0
for cell in notebook.cells:
if cell.type == "code":
original_content = cell.content
# Update file paths in the cell content
cell.content = re.sub(file_pattern, replace_path, cell.content)
if original_content != cell.content:
updated_cells += 1
# Log summary of updates (useful for monitoring)
if updated_cells > 0:
logger.info(f"Updated file paths in {updated_cells} notebook cells")
@classmethod
def from_ipynb_string(
cls,
notebook_content: str,
owner_id: UUID_t,
name: Optional[str] = None,
deletable: bool = True,
) -> "Notebook":
"""
Create a Notebook instance from Jupyter notebook string content.
Args:
notebook_content: Raw Jupyter notebook content as string
owner_id: UUID of the notebook owner
name: Optional custom name for the notebook
deletable: Whether the notebook can be deleted
Returns:
Notebook instance ready to be saved to database
"""
# Parse and validate the Jupyter notebook using nbformat
# Note: nbformat.reads() has loose typing, so we cast to NotebookNode
jupyter_nb = cast(
NotebookNode, nbformat.reads(notebook_content, as_version=nbformat.NO_CONVERT)
)
# Convert Jupyter cells to NotebookCell objects
cells = []
for jupyter_cell in jupyter_nb.cells:
# Each cell is also a NotebookNode with dynamic attributes
cell = cast(NotebookNode, jupyter_cell)
# Skip raw cells as they're not supported in our model
if cell.cell_type == "raw":
continue
# Get the source content
content = cell.source
# Generate a name based on content or cell index
cell_name = cls._generate_cell_name(cell)
# Map cell types (jupyter uses "code"/"markdown", we use same)
cell_type = "code" if cell.cell_type == "code" else "markdown"
cells.append(NotebookCell(id=uuid4(), type=cell_type, name=cell_name, content=content))
# Extract notebook name from metadata if not provided
if name is None:
kernelspec = jupyter_nb.metadata.get("kernelspec", {})
name = kernelspec.get("display_name") or kernelspec.get("name", "Imported Notebook")
return cls(id=uuid4(), owner_id=owner_id, name=name, cells=cells, deletable=deletable)
@staticmethod
def _generate_cell_name(jupyter_cell: NotebookNode) -> str:
"""Generate a meaningful name for a notebook cell using nbformat cell."""
if jupyter_cell.cell_type == "markdown":
# Try to extract a title from markdown headers
content = jupyter_cell.source
lines = content.strip().split("\n")
if lines and lines[0].startswith("#"):
# Extract header text, clean it up
header = lines[0].lstrip("#").strip()
return header[:50] if len(header) > 50 else header
else:
return "Markdown Cell"
else:
return "Code Cell"

View file

@ -1,9 +1,10 @@
from uuid import uuid4 from uuid import UUID, uuid4
from fastapi_users.exceptions import UserAlreadyExists from fastapi_users.exceptions import UserAlreadyExists
from sqlalchemy.ext.asyncio import AsyncSession
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.notebooks.methods import create_notebook from cognee.modules.notebooks.models.Notebook import Notebook
from cognee.modules.notebooks.models.Notebook import NotebookCell from cognee.modules.notebooks.methods.create_notebook import _create_tutorial_notebook
from cognee.modules.users.exceptions import TenantNotFoundError from cognee.modules.users.exceptions import TenantNotFoundError
from cognee.modules.users.get_user_manager import get_user_manager_context from cognee.modules.users.get_user_manager import get_user_manager_context
from cognee.modules.users.get_user_db import get_user_db_context from cognee.modules.users.get_user_db import get_user_db_context
@ -60,26 +61,7 @@ async def create_user(
if auto_login: if auto_login:
await session.refresh(user) await session.refresh(user)
await create_notebook( await _create_tutorial_notebook(user.id, session)
user_id=user.id,
notebook_name="Welcome to cognee 🧠",
cells=[
NotebookCell(
id=uuid4(),
name="Welcome",
content="Cognee is your toolkit for turning text into a structured knowledge graph, optionally enhanced by ontologies, and then querying it with advanced retrieval techniques. This notebook will guide you through a simple example.",
type="markdown",
),
NotebookCell(
id=uuid4(),
name="Example",
content="",
type="code",
),
],
deletable=False,
session=session,
)
return user return user
except UserAlreadyExists as error: except UserAlreadyExists as error:

346
cognee/shared/cache.py Normal file
View file

@ -0,0 +1,346 @@
"""
Storage-aware cache management utilities for Cognee.
This module provides cache functionality that works with both local and cloud storage
backends (like S3) through the StorageManager abstraction.
"""
import hashlib
import zipfile
import asyncio
from typing import Optional, Tuple
import aiohttp
import logging
from io import BytesIO
from cognee.base_config import get_base_config
from cognee.infrastructure.files.storage.get_file_storage import get_file_storage
from cognee.infrastructure.files.storage.StorageManager import StorageManager
from cognee.shared.utils import create_secure_ssl_context
logger = logging.getLogger(__name__)
class StorageAwareCache:
"""
A cache manager that works with different storage backends (local, S3, etc.)
"""
def __init__(self, cache_subdir: str = "cache"):
"""
Initialize the cache manager.
Args:
cache_subdir: Subdirectory name within the system root for caching
"""
self.base_config = get_base_config()
# Since we're using cache_root_directory, don't add extra cache prefix
self.cache_base_path = ""
self.storage_manager: StorageManager = get_file_storage(
self.base_config.cache_root_directory
)
# Print absolute path
storage_path = self.storage_manager.storage.storage_path
if storage_path.startswith("s3://"):
absolute_path = storage_path # S3 paths are already absolute
else:
import os
absolute_path = os.path.abspath(storage_path)
logger.info(f"Storage manager absolute path: {absolute_path}")
async def get_cache_dir(self) -> str:
"""Get the base cache directory path."""
cache_path = self.cache_base_path or "." # Use "." for root when cache_base_path is empty
await self.storage_manager.ensure_directory_exists(cache_path)
return cache_path
async def get_cache_subdir(self, name: str) -> str:
"""Get a specific cache subdirectory."""
if self.cache_base_path:
cache_path = f"{self.cache_base_path}/{name}"
else:
cache_path = name
await self.storage_manager.ensure_directory_exists(cache_path)
# Return the absolute path based on storage system
if self.storage_manager.storage.storage_path.startswith("s3://"):
return cache_path
elif hasattr(self.storage_manager.storage, "storage_path"):
return f"{self.storage_manager.storage.storage_path}/{cache_path}"
else:
# Fallback for other storage types
return cache_path
async def delete_cache(self):
"""Delete the entire cache directory."""
logger.info("Deleting cache...")
try:
await self.storage_manager.remove_all(self.cache_base_path)
logger.info("✓ Cache deleted successfully!")
except Exception as e:
logger.error(f"Error deleting cache: {e}")
raise
async def _is_cache_valid(self, cache_dir: str, version_or_hash: str) -> bool:
"""Check if cached content is valid for the given version/hash."""
version_file = f"{cache_dir}/version.txt"
if not await self.storage_manager.file_exists(version_file):
return False
try:
async with self.storage_manager.open(version_file, "r") as f:
cached_version = (await asyncio.to_thread(f.read)).strip()
return cached_version == version_or_hash
except Exception as e:
logger.debug(f"Error checking cache validity: {e}")
return False
async def _clear_cache(self, cache_dir: str) -> None:
"""Clear a cache directory."""
try:
await self.storage_manager.remove_all(cache_dir)
except Exception as e:
logger.debug(f"Error clearing cache directory {cache_dir}: {e}")
async def _check_remote_content_freshness(
self, url: str, cache_dir: str
) -> Tuple[bool, Optional[str]]:
"""
Check if remote content is fresher than cached version using HTTP headers.
Returns:
Tuple of (is_fresh: bool, new_identifier: Optional[str])
"""
try:
# Make a HEAD request to check headers without downloading
ssl_context = create_secure_ssl_context()
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(connector=connector) as session:
async with session.head(url, timeout=aiohttp.ClientTimeout(total=30)) as response:
response.raise_for_status()
# Try ETag first (most reliable)
etag = response.headers.get("ETag", "").strip('"')
last_modified = response.headers.get("Last-Modified", "")
# Use ETag if available, otherwise Last-Modified
remote_identifier = etag if etag else last_modified
if not remote_identifier:
logger.debug("No freshness headers available, cannot check for updates")
return True, None # Assume fresh if no headers
# Check cached identifier
identifier_file = f"{cache_dir}/content_id.txt"
if await self.storage_manager.file_exists(identifier_file):
async with self.storage_manager.open(identifier_file, "r") as f:
cached_identifier = (await asyncio.to_thread(f.read)).strip()
if cached_identifier == remote_identifier:
logger.debug(f"Content is fresh (identifier: {remote_identifier[:20]}...)")
return True, None
else:
logger.info(
f"Content has changed (old: {cached_identifier[:20]}..., new: {remote_identifier[:20]}...)"
)
return False, remote_identifier
else:
# No cached identifier, treat as stale
return False, remote_identifier
except Exception as e:
logger.debug(f"Could not check remote freshness: {e}")
return True, None # Assume fresh if we can't check
async def download_and_extract_zip(
self, url: str, cache_subdir_name: str, version_or_hash: str, force: bool = False
) -> str:
"""
Download a zip file and extract it to cache directory with content freshness checking.
Args:
url: URL to download zip file from
cache_subdir_name: Name of the cache subdirectory
version_or_hash: Version string or content hash for cache validation
force: If True, re-download even if already cached
Returns:
Path to the cached directory
"""
cache_dir = await self.get_cache_subdir(cache_subdir_name)
# Check if already cached and valid
if not force and await self._is_cache_valid(cache_dir, version_or_hash):
# Also check if remote content has changed
is_fresh, new_identifier = await self._check_remote_content_freshness(url, cache_dir)
if is_fresh:
logger.debug(f"Content already cached and fresh for version {version_or_hash}")
return cache_dir
else:
logger.info("Cached content is stale, updating...")
# Clear old cache if it exists
await self._clear_cache(cache_dir)
logger.info(f"Downloading content from {url}...")
# Download the zip file
zip_content = BytesIO()
etag = ""
last_modified = ""
ssl_context = create_secure_ssl_context()
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(connector=connector) as session:
async with session.get(url, timeout=aiohttp.ClientTimeout(total=60)) as response:
response.raise_for_status()
# Extract headers before consuming response
etag = response.headers.get("ETag", "").strip('"')
last_modified = response.headers.get("Last-Modified", "")
# Read the response content
async for chunk in response.content.iter_chunked(8192):
zip_content.write(chunk)
zip_content.seek(0)
# Extract the archive
await self.storage_manager.ensure_directory_exists(cache_dir)
# Extract files and store them using StorageManager
with zipfile.ZipFile(zip_content, "r") as zip_file:
for file_info in zip_file.infolist():
if file_info.is_dir():
# Create directory
dir_path = f"{cache_dir}/{file_info.filename}"
await self.storage_manager.ensure_directory_exists(dir_path)
else:
# Extract and store file
file_data = zip_file.read(file_info.filename)
file_path = f"{cache_dir}/{file_info.filename}"
await self.storage_manager.store(file_path, BytesIO(file_data), overwrite=True)
# Write version info for future cache validation
version_file = f"{cache_dir}/version.txt"
await self.storage_manager.store(version_file, version_or_hash, overwrite=True)
# Store content identifier from response headers for freshness checking
content_identifier = etag if etag else last_modified
if content_identifier:
identifier_file = f"{cache_dir}/content_id.txt"
await self.storage_manager.store(identifier_file, content_identifier, overwrite=True)
logger.debug(f"Stored content identifier: {content_identifier[:20]}...")
logger.info("✓ Content downloaded and cached successfully!")
return cache_dir
async def file_exists(self, file_path: str) -> bool:
"""Check if a file exists in cache storage."""
return await self.storage_manager.file_exists(file_path)
async def read_file(self, file_path: str, encoding: str = "utf-8"):
"""Read a file from cache storage."""
return self.storage_manager.open(file_path, encoding=encoding)
async def list_files(self, directory_path: str):
"""List files in a cache directory."""
try:
file_list = await self.storage_manager.list_files(directory_path)
# For S3 storage, convert relative paths to full S3 URLs
if self.storage_manager.storage.storage_path.startswith("s3://"):
full_paths = []
for file_path in file_list:
full_s3_path = f"{self.storage_manager.storage.storage_path}/{file_path}"
full_paths.append(full_s3_path)
return full_paths
else:
# For local storage, return absolute paths
storage_path = self.storage_manager.storage.storage_path
if not storage_path.startswith("/"):
import os
storage_path = os.path.abspath(storage_path)
full_paths = []
for file_path in file_list:
if file_path.startswith("/"):
full_paths.append(file_path) # Already absolute
else:
full_paths.append(f"{storage_path}/{file_path}")
return full_paths
except Exception as e:
logger.debug(f"Error listing files in {directory_path}: {e}")
return []
# Convenience functions that maintain API compatibility
_cache_manager = None
def get_cache_manager() -> StorageAwareCache:
"""Get a singleton cache manager instance."""
global _cache_manager
if _cache_manager is None:
_cache_manager = StorageAwareCache()
return _cache_manager
def generate_content_hash(url: str, additional_data: str = "") -> str:
"""Generate a content hash from URL and optional additional data."""
content = f"{url}:{additional_data}"
return hashlib.md5(content.encode()).hexdigest()[:12] # Short hash for readability
# Async wrapper functions for backward compatibility
async def delete_cache():
"""Delete the Cognee cache directory."""
cache_manager = get_cache_manager()
await cache_manager.delete_cache()
async def get_cognee_cache_dir() -> str:
"""Get the base Cognee cache directory."""
cache_manager = get_cache_manager()
return await cache_manager.get_cache_dir()
async def get_cache_subdir(name: str) -> str:
"""Get a specific cache subdirectory."""
cache_manager = get_cache_manager()
return await cache_manager.get_cache_subdir(name)
async def download_and_extract_zip(
url: str, cache_dir_name: str, version_or_hash: str, force: bool = False
) -> str:
"""Download a zip file and extract it to cache directory."""
cache_manager = get_cache_manager()
return await cache_manager.download_and_extract_zip(url, cache_dir_name, version_or_hash, force)
async def get_tutorial_data_dir() -> str:
"""Get the tutorial data cache directory."""
return await get_cache_subdir("tutorial_data")
# Cache file operations
async def cache_file_exists(file_path: str) -> bool:
"""Check if a file exists in cache storage."""
cache_manager = get_cache_manager()
return await cache_manager.file_exists(file_path)
async def read_cache_file(file_path: str, encoding: str = "utf-8"):
"""Read a file from cache storage."""
cache_manager = get_cache_manager()
return await cache_manager.read_file(file_path, encoding)
async def list_cache_files(directory_path: str):
"""List files in a cache directory."""
cache_manager = get_cache_manager()
return await cache_manager.list_files(directory_path)

View file

@ -1,6 +1,7 @@
"""This module contains utility functions for the cognee.""" """This module contains utility functions for the cognee."""
import os import os
import ssl
import requests import requests
from datetime import datetime, timezone from datetime import datetime, timezone
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -18,6 +19,17 @@ from cognee.infrastructure.databases.graph import get_graph_engine
proxy_url = "https://test.prometh.ai" proxy_url = "https://test.prometh.ai"
def create_secure_ssl_context() -> ssl.SSLContext:
"""
Create a secure SSL context.
By default, use the system's certificate store.
If users report SSL issues, I'm keeping this open in case we need to switch to:
ssl.create_default_context(cafile=certifi.where())
"""
return ssl.create_default_context()
def get_entities(tagged_tokens): def get_entities(tagged_tokens):
import nltk import nltk

View file

@ -0,0 +1,399 @@
import json
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
import hashlib
import time
from uuid import uuid4
from sqlalchemy.ext.asyncio import AsyncSession
from pathlib import Path
import zipfile
from cognee.shared.cache import get_tutorial_data_dir
from cognee.modules.notebooks.methods.create_notebook import _create_tutorial_notebook
from cognee.modules.notebooks.models.Notebook import Notebook
import cognee
from cognee.shared.logging_utils import get_logger
logger = get_logger()
# Module-level fixtures available to all test classes
@pytest.fixture
def mock_session():
"""Mock database session."""
session = AsyncMock(spec=AsyncSession)
session.add = MagicMock()
session.commit = AsyncMock()
return session
@pytest.fixture
def sample_jupyter_notebook():
"""Sample Jupyter notebook content for testing."""
return {
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": ["# Tutorial Introduction\n", "\n", "This is a tutorial notebook."],
},
{
"cell_type": "code",
"execution_count": None,
"metadata": {},
"outputs": [],
"source": ["import cognee\n", "print('Hello, Cognee!')"],
},
{
"cell_type": "markdown",
"metadata": {},
"source": ["## Step 1: Data Ingestion\n", "\n", "Let's add some data."],
},
{
"cell_type": "code",
"execution_count": None,
"metadata": {},
"outputs": [],
"source": ["# Add your data here\n", "# await cognee.add('data.txt')"],
},
{
"cell_type": "raw",
"metadata": {},
"source": ["This is a raw cell that should be skipped"],
},
],
"metadata": {
"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}
},
"nbformat": 4,
"nbformat_minor": 4,
}
class TestTutorialNotebookCreation:
"""Test cases for tutorial notebook creation functionality."""
@pytest.mark.asyncio
async def test_notebook_from_ipynb_string_success(self, sample_jupyter_notebook):
"""Test successful creation of notebook from JSON string."""
notebook_json = json.dumps(sample_jupyter_notebook)
user_id = uuid4()
notebook = Notebook.from_ipynb_string(
notebook_content=notebook_json, owner_id=user_id, name="String Test Notebook"
)
assert notebook.owner_id == user_id
assert notebook.name == "String Test Notebook"
assert len(notebook.cells) == 4 # Should skip the raw cell
assert notebook.cells[0].type == "markdown"
assert notebook.cells[1].type == "code"
@pytest.mark.asyncio
async def test_notebook_cell_name_generation(self, sample_jupyter_notebook):
"""Test that cell names are generated correctly from markdown headers."""
user_id = uuid4()
notebook_json = json.dumps(sample_jupyter_notebook)
notebook = Notebook.from_ipynb_string(notebook_content=notebook_json, owner_id=user_id)
# Check markdown header extraction
assert notebook.cells[0].name == "Tutorial Introduction"
assert notebook.cells[2].name == "Step 1: Data Ingestion"
# Check code cell naming
assert notebook.cells[1].name == "Code Cell"
assert notebook.cells[3].name == "Code Cell"
@pytest.mark.asyncio
async def test_notebook_from_ipynb_string_with_default_name(self, sample_jupyter_notebook):
"""Test notebook creation uses kernelspec display_name when no name provided."""
user_id = uuid4()
notebook_json = json.dumps(sample_jupyter_notebook)
notebook = Notebook.from_ipynb_string(notebook_content=notebook_json, owner_id=user_id)
assert notebook.name == "Python 3" # From kernelspec.display_name
@pytest.mark.asyncio
async def test_notebook_from_ipynb_string_fallback_name(self):
"""Test fallback naming when kernelspec is missing."""
minimal_notebook = {
"cells": [{"cell_type": "markdown", "metadata": {}, "source": ["# Test"]}],
"metadata": {}, # No kernelspec
"nbformat": 4,
"nbformat_minor": 4,
}
user_id = uuid4()
notebook_json = json.dumps(minimal_notebook)
notebook = Notebook.from_ipynb_string(notebook_content=notebook_json, owner_id=user_id)
assert notebook.name == "Imported Notebook" # Fallback name
@pytest.mark.asyncio
async def test_notebook_from_ipynb_string_invalid_json(self):
"""Test error handling for invalid JSON."""
user_id = uuid4()
invalid_json = "{ invalid json content"
from nbformat.reader import NotJSONError
with pytest.raises(NotJSONError):
Notebook.from_ipynb_string(notebook_content=invalid_json, owner_id=user_id)
@pytest.mark.asyncio
@patch.object(Notebook, "from_ipynb_zip_url")
async def test_create_tutorial_notebook_error_propagated(self, mock_from_zip_url, mock_session):
"""Test that errors are propagated when zip fetch fails."""
user_id = uuid4()
mock_from_zip_url.side_effect = Exception("Network error")
# Should raise the exception (not catch it)
with pytest.raises(Exception, match="Network error"):
await _create_tutorial_notebook(user_id, mock_session)
# Verify error handling path was taken
mock_from_zip_url.assert_called_once()
mock_session.add.assert_not_called()
mock_session.commit.assert_not_called()
def test_generate_cell_name_code_cell(self):
"""Test cell name generation for code cells."""
from nbformat.notebooknode import NotebookNode
mock_cell = NotebookNode(
{"cell_type": "code", "source": 'import pandas as pd\nprint("Hello world")'}
)
result = Notebook._generate_cell_name(mock_cell)
assert result == "Code Cell"
class TestTutorialNotebookZipFunctionality:
"""Test cases for zip-based tutorial functionality."""
@pytest.mark.asyncio
async def test_notebook_from_ipynb_zip_url_missing_notebook(
self,
):
"""Test error handling when notebook file is missing from zip."""
user_id = uuid4()
with pytest.raises(
FileNotFoundError,
match="Notebook file 'super_random_tutorial_name.ipynb' not found in zip",
):
await Notebook.from_ipynb_zip_url(
zip_url="https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/starter_tutorial.zip",
owner_id=user_id,
notebook_filename="super_random_tutorial_name.ipynb",
)
@pytest.mark.asyncio
async def test_notebook_from_ipynb_zip_url_download_failure(self):
"""Test error handling when zip download fails."""
user_id = uuid4()
with pytest.raises(RuntimeError, match="Failed to download tutorial zip"):
await Notebook.from_ipynb_zip_url(
zip_url="https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/nonexistent_tutorial_name.zip",
owner_id=user_id,
)
@pytest.mark.asyncio
async def test_create_tutorial_notebook_zip_success(self, mock_session):
"""Test successful tutorial notebook creation with zip."""
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
user_id = uuid4()
# Check that tutorial data directory is empty using storage-aware method
tutorial_data_dir_path = await get_tutorial_data_dir()
tutorial_data_dir = Path(tutorial_data_dir_path)
if tutorial_data_dir.exists():
assert not any(tutorial_data_dir.iterdir()), "Tutorial data directory should be empty"
await _create_tutorial_notebook(user_id, mock_session)
items = list(tutorial_data_dir.iterdir())
assert len(items) == 1, "Tutorial data directory should contain exactly one item"
assert items[0].is_dir(), "Tutorial data directory item should be a directory"
# Verify the structure inside the tutorial directory
tutorial_dir = items[0]
# Check for tutorial.ipynb file
notebook_file = tutorial_dir / "tutorial.ipynb"
assert notebook_file.exists(), f"tutorial.ipynb should exist in {tutorial_dir}"
assert notebook_file.is_file(), "tutorial.ipynb should be a file"
# Check for data subfolder with contents
data_folder = tutorial_dir / "data"
assert data_folder.exists(), f"data subfolder should exist in {tutorial_dir}"
assert data_folder.is_dir(), "data should be a directory"
data_items = list(data_folder.iterdir())
assert len(data_items) > 0, (
f"data folder should contain files, but found {len(data_items)} items"
)
@pytest.mark.asyncio
async def test_create_tutorial_notebook_with_force_refresh(self, mock_session):
"""Test tutorial notebook creation with force refresh."""
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
user_id = uuid4()
# Check that tutorial data directory is empty using storage-aware method
tutorial_data_dir_path = await get_tutorial_data_dir()
tutorial_data_dir = Path(tutorial_data_dir_path)
if tutorial_data_dir.exists():
assert not any(tutorial_data_dir.iterdir()), "Tutorial data directory should be empty"
# First creation (without force refresh)
await _create_tutorial_notebook(user_id, mock_session, force_refresh=False)
items_first = list(tutorial_data_dir.iterdir())
assert len(items_first) == 1, (
"Tutorial data directory should contain exactly one item after first creation"
)
first_dir = items_first[0]
assert first_dir.is_dir(), "Tutorial data directory item should be a directory"
# Verify the structure inside the tutorial directory (first creation)
notebook_file = first_dir / "tutorial.ipynb"
assert notebook_file.exists(), f"tutorial.ipynb should exist in {first_dir}"
assert notebook_file.is_file(), "tutorial.ipynb should be a file"
data_folder = first_dir / "data"
assert data_folder.exists(), f"data subfolder should exist in {first_dir}"
assert data_folder.is_dir(), "data should be a directory"
data_items = list(data_folder.iterdir())
assert len(data_items) > 0, (
f"data folder should contain files, but found {len(data_items)} items"
)
# Capture metadata from first creation
first_creation_metadata = {}
for file_path in first_dir.rglob("*"):
if file_path.is_file():
relative_path = file_path.relative_to(first_dir)
stat = file_path.stat()
# Store multiple metadata points
with open(file_path, "rb") as f:
content = f.read()
first_creation_metadata[str(relative_path)] = {
"mtime": stat.st_mtime,
"size": stat.st_size,
"hash": hashlib.md5(content).hexdigest(),
"first_bytes": content[:100]
if content
else b"", # First 100 bytes as fingerprint
}
# Wait a moment to ensure different timestamps
time.sleep(0.1)
# Force refresh - should create new files with different metadata
await _create_tutorial_notebook(user_id, mock_session, force_refresh=True)
items_second = list(tutorial_data_dir.iterdir())
assert len(items_second) == 1, (
"Tutorial data directory should contain exactly one item after force refresh"
)
second_dir = items_second[0]
# Verify the structure is maintained after force refresh
notebook_file_second = second_dir / "tutorial.ipynb"
assert notebook_file_second.exists(), (
f"tutorial.ipynb should exist in {second_dir} after force refresh"
)
assert notebook_file_second.is_file(), "tutorial.ipynb should be a file after force refresh"
data_folder_second = second_dir / "data"
assert data_folder_second.exists(), (
f"data subfolder should exist in {second_dir} after force refresh"
)
assert data_folder_second.is_dir(), "data should be a directory after force refresh"
data_items_second = list(data_folder_second.iterdir())
assert len(data_items_second) > 0, (
f"data folder should still contain files after force refresh, but found {len(data_items_second)} items"
)
# Compare metadata to ensure files are actually different
files_with_changed_metadata = 0
for file_path in second_dir.rglob("*"):
if file_path.is_file():
relative_path = file_path.relative_to(second_dir)
relative_path_str = str(relative_path)
# File should exist from first creation
assert relative_path_str in first_creation_metadata, (
f"File {relative_path_str} missing from first creation"
)
old_metadata = first_creation_metadata[relative_path_str]
# Get new metadata
stat = file_path.stat()
with open(file_path, "rb") as f:
new_content = f.read()
new_metadata = {
"mtime": stat.st_mtime,
"size": stat.st_size,
"hash": hashlib.md5(new_content).hexdigest(),
"first_bytes": new_content[:100] if new_content else b"",
}
# Check if any metadata changed (indicating file was refreshed)
metadata_changed = (
new_metadata["mtime"] > old_metadata["mtime"] # Newer modification time
or new_metadata["hash"] != old_metadata["hash"] # Different content hash
or new_metadata["size"] != old_metadata["size"] # Different file size
or new_metadata["first_bytes"]
!= old_metadata["first_bytes"] # Different content
)
if metadata_changed:
files_with_changed_metadata += 1
# Assert that force refresh actually updated files
assert files_with_changed_metadata > 0, (
f"Force refresh should have updated at least some files, but all {len(first_creation_metadata)} "
f"files appear to have identical metadata. This suggests force refresh didn't work."
)
mock_session.commit.assert_called()
@pytest.mark.asyncio
async def test_tutorial_zip_url_accessibility(self):
"""Test that the actual tutorial zip URL is accessible (integration test)."""
try:
import requests
response = requests.get(
"https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/starter_tutorial.zip",
timeout=10,
)
response.raise_for_status()
# Verify it's a valid zip file by checking headers
assert response.headers.get("content-type") in [
"application/zip",
"application/octet-stream",
"application/x-zip-compressed",
] or response.content.startswith(b"PK") # Zip file signature
except Exception:
pytest.skip("Network request failed or zip not available - skipping integration test")

View file

@ -46,6 +46,7 @@ dependencies = [
"matplotlib>=3.8.3,<4", "matplotlib>=3.8.3,<4",
"networkx>=3.4.2,<4", "networkx>=3.4.2,<4",
"lancedb>=0.24.0,<1.0.0", "lancedb>=0.24.0,<1.0.0",
"nbformat>=5.7.0,<6.0.0",
"alembic>=1.13.3,<2", "alembic>=1.13.3,<2",
"pre-commit>=4.0.1,<5", "pre-commit>=4.0.1,<5",
"scikit-learn>=1.6.1,<2", "scikit-learn>=1.6.1,<2",

2
uv.lock generated
View file

@ -831,6 +831,7 @@ dependencies = [
{ name = "limits" }, { name = "limits" },
{ name = "litellm" }, { name = "litellm" },
{ name = "matplotlib" }, { name = "matplotlib" },
{ name = "nbformat" },
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
{ name = "nltk" }, { name = "nltk" },
@ -1012,6 +1013,7 @@ requires-dist = [
{ name = "mkdocstrings", extras = ["python"], marker = "extra == 'dev'", specifier = ">=0.26.2,<0.27" }, { name = "mkdocstrings", extras = ["python"], marker = "extra == 'dev'", specifier = ">=0.26.2,<0.27" },
{ name = "modal", marker = "extra == 'distributed'", specifier = ">=1.0.5,<2.0.0" }, { name = "modal", marker = "extra == 'distributed'", specifier = ">=1.0.5,<2.0.0" },
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.7.1,<2" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.7.1,<2" },
{ name = "nbformat", specifier = ">=5.7.0,<6.0.0" },
{ name = "neo4j", marker = "extra == 'neo4j'", specifier = ">=5.28.0,<6" }, { name = "neo4j", marker = "extra == 'neo4j'", specifier = ">=5.28.0,<6" },
{ name = "networkx", specifier = ">=3.4.2,<4" }, { name = "networkx", specifier = ">=3.4.2,<4" },
{ name = "nltk", specifier = ">=3.9.1,<4.0.0" }, { name = "nltk", specifier = ">=3.9.1,<4.0.0" },