feat: add welcome tutorial notebook for new users (#1425)
## Description

Update default tutorial:

1. Use the tutorial from the [notebook_tutorial branch](https://github.com/topoteretes/cognee/blob/notebook_tutorial/notebooks/tutorial.ipynb), specifically its .zip version with all the necessary data files
2. Use Jupyter notebook `Notebook` abstractions to read `.ipynb` files and map them into our Notebook model
3. Dynamically update starter notebook code blocks that reference starter data files, swapping them with local paths to the downloaded copies (a minimal sketch of this rewrite follows the template below)
4. Test coverage

| Before | After (storage backend = local) | After (s3) |
|--------|---------------------------------|------------|
| <img width="613" height="546" alt="Screenshot 2025-09-17 at 01 00 58" src="https://github.com/user-attachments/assets/20b59021-96c1-4a83-977f-e064324bd758" /> | <img width="1480" height="262" alt="Screenshot 2025-09-18 at 13 01 57" src="https://github.com/user-attachments/assets/bd56ea78-7c6a-42e3-ae3f-4157da231b2d" /> | <img width="1485" height="307" alt="Screenshot 2025-09-18 at 12 56 08" src="https://github.com/user-attachments/assets/248ae720-4c78-445a-ba8b-8a2991ed3f80" /> |

## File Replacements

### S3 Demo

https://github.com/user-attachments/assets/bd46eec9-ef77-4f69-9ef0-e7d1612ff9b3

---

### Local FS Demo

https://github.com/user-attachments/assets/8251cea0-81b3-4cac-a968-9576c358f334

## Type of Change

- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Changes Made

## Testing

## Screenshots/Videos (if applicable)

## Pre-submission Checklist

- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the issue/feature**
- [ ] My code follows the project's coding standards and style guidelines
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## Related Issues

## Additional Notes

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
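Reviewer note: the following is a minimal, self-contained sketch of the cell path rewriting described in point 3 of the description. It mirrors the regex substitution added in `Notebook._update_file_paths_in_cells` in this PR; the filename and cache directory used below are hypothetical placeholders, not real tutorial files.

```python
import re

# Hypothetical filename-to-cached-path mapping; the real one is built from the
# extracted tutorial zip by Notebook._update_file_paths_in_cells.
data_files = {
    "example.txt": "/home/user/.cognee_cache/tutorial_data/abc123/data/example.txt",
}

# Matches "file://data/<name>" references inside starter notebook code cells.
file_pattern = r'"file://data/([^"]+)"'


def replace_path(match: re.Match) -> str:
    filename = match.group(1)
    if filename not in data_files:
        return match.group(0)  # keep the original reference if no local copy exists
    cached = data_files[filename]
    # S3 paths are used as-is; local paths keep the file:// prefix.
    return f'"{cached}"' if cached.startswith("s3://") else f'"file://{cached}"'


cell_source = 'await cognee.add("file://data/example.txt")'
print(re.sub(file_pattern, replace_path, cell_source))
# await cognee.add("file:///home/user/.cognee_cache/tutorial_data/abc123/data/example.txt")
```

S3-backed caches keep their `s3://` URLs unchanged, while local caches get the `file://` prefix, matching the replacement logic in the diff below.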
parent bb124494c1
commit f58ba86e7c
20 changed files with 1200 additions and 44 deletions
|
|
@@ -47,6 +47,28 @@ BAML_LLM_API_VERSION=""
|
|||
# DATA_ROOT_DIRECTORY='/Users/<user>/Desktop/cognee/.cognee_data/'
|
||||
# SYSTEM_ROOT_DIRECTORY='/Users/<user>/Desktop/cognee/.cognee_system/'
|
||||
|
||||
################################################################################
|
||||
# ☁️ Storage Backend Settings
|
||||
################################################################################
|
||||
# Configure storage backend (local filesystem or S3)
|
||||
# STORAGE_BACKEND="local" # Default: uses local filesystem
|
||||
#
|
||||
# -- To switch to S3 storage, uncomment and fill these: ---------------------
|
||||
# STORAGE_BACKEND="s3"
|
||||
# STORAGE_BUCKET_NAME="your-bucket-name"
|
||||
# AWS_REGION="us-east-1"
|
||||
# AWS_ACCESS_KEY_ID="your-access-key"
|
||||
# AWS_SECRET_ACCESS_KEY="your-secret-key"
|
||||
#
|
||||
# -- S3 Root Directories (optional) -----------------------------------------
|
||||
# DATA_ROOT_DIRECTORY="s3://your-bucket/cognee/data"
|
||||
# SYSTEM_ROOT_DIRECTORY="s3://your-bucket/cognee/system"
|
||||
#
|
||||
# -- Cache Directory (auto-configured for S3) -------------------------------
|
||||
# When STORAGE_BACKEND=s3, cache automatically uses S3: s3://BUCKET/cognee/cache
|
||||
# To override the automatic S3 cache location, uncomment:
|
||||
# CACHE_ROOT_DIRECTORY="s3://your-bucket/cognee/cache"
|
||||
|
||||
################################################################################
|
||||
# 🗄️ Relational database settings
|
||||
################################################################################
|
||||
|
|
|
|||
1
.gitignore
vendored
|
|
@@ -186,6 +186,7 @@ cognee/cache/
|
|||
# Default cognee system directory, used in development
|
||||
.cognee_system/
|
||||
.data_storage/
|
||||
.cognee_cache/
|
||||
.artifacts/
|
||||
.anon_id
|
||||
|
||||
|
|
|
|||
|
|
@@ -7,8 +7,8 @@ class prune:
|
|||
await _prune_data()
|
||||
|
||||
@staticmethod
|
||||
async def prune_system(graph=True, vector=True, metadata=False):
|
||||
await _prune_system(graph, vector, metadata)
|
||||
async def prune_system(graph=True, vector=True, metadata=False, cache=True):
|
||||
await _prune_system(graph, vector, metadata, cache)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@@ -23,6 +23,7 @@ from cognee.modules.sync.methods import (
|
|||
mark_sync_completed,
|
||||
mark_sync_failed,
|
||||
)
|
||||
from cognee.shared.utils import create_secure_ssl_context
|
||||
|
||||
logger = get_logger("sync")
|
||||
|
||||
|
|
@@ -583,7 +584,9 @@ async def _check_hashes_diff(
|
|||
logger.info(f"Checking missing hashes on cloud for dataset {dataset.id}")
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
ssl_context = create_secure_ssl_context()
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
async with aiohttp.ClientSession(connector=connector) as session:
|
||||
async with session.post(url, json=payload.dict(), headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
|
|
@@ -630,7 +633,9 @@ async def _download_missing_files(
|
|||
|
||||
headers = {"X-Api-Key": auth_token}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
ssl_context = create_secure_ssl_context()
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
async with aiohttp.ClientSession(connector=connector) as session:
|
||||
for file_hash in hashes_missing_on_local:
|
||||
try:
|
||||
# Download file from cloud by hash
|
||||
|
|
@@ -749,7 +754,9 @@ async def _upload_missing_files(
|
|||
|
||||
headers = {"X-Api-Key": auth_token}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
ssl_context = create_secure_ssl_context()
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
async with aiohttp.ClientSession(connector=connector) as session:
|
||||
for file_info in files_to_upload:
|
||||
try:
|
||||
file_dir = os.path.dirname(file_info.raw_data_location)
|
||||
|
|
@@ -809,7 +816,9 @@ async def _prune_cloud_dataset(
|
|||
logger.info("Pruning cloud dataset to match local state")
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
ssl_context = create_secure_ssl_context()
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
async with aiohttp.ClientSession(connector=connector) as session:
|
||||
async with session.put(url, json=payload.dict(), headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
|
|
@@ -852,7 +861,9 @@ async def _trigger_remote_cognify(
|
|||
logger.info(f"Triggering cognify processing for dataset {dataset_id}")
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
ssl_context = create_secure_ssl_context()
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
async with aiohttp.ClientSession(connector=connector) as session:
|
||||
async with session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
|
|
|
|||
|
|
@@ -10,13 +10,27 @@ import pydantic
|
|||
class BaseConfig(BaseSettings):
|
||||
data_root_directory: str = get_absolute_path(".data_storage")
|
||||
system_root_directory: str = get_absolute_path(".cognee_system")
|
||||
cache_root_directory: str = get_absolute_path(".cognee_cache")
|
||||
monitoring_tool: object = Observer.LANGFUSE
|
||||
|
||||
@pydantic.model_validator(mode="after")
|
||||
def validate_paths(self):
|
||||
# Adding this here temporarily to ensure that the cache root directory is set correctly for S3 storage automatically
|
||||
# I'll remove this after we update documentation for S3 storage
|
||||
# Auto-configure cache root directory for S3 storage if not explicitly set
|
||||
storage_backend = os.getenv("STORAGE_BACKEND", "").lower()
|
||||
cache_root_env = os.getenv("CACHE_ROOT_DIRECTORY")
|
||||
|
||||
if storage_backend == "s3" and not cache_root_env:
|
||||
# Auto-generate S3 cache path when using S3 storage
|
||||
bucket_name = os.getenv("STORAGE_BUCKET_NAME")
|
||||
if bucket_name:
|
||||
self.cache_root_directory = f"s3://{bucket_name}/cognee/cache"
|
||||
|
||||
# Require absolute paths for root directories
|
||||
self.data_root_directory = ensure_absolute_path(self.data_root_directory)
|
||||
self.system_root_directory = ensure_absolute_path(self.system_root_directory)
|
||||
self.cache_root_directory = ensure_absolute_path(self.cache_root_directory)
|
||||
return self
|
||||
|
||||
langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY")
|
||||
|
|
@@ -31,6 +45,7 @@ class BaseConfig(BaseSettings):
|
|||
"data_root_directory": self.data_root_directory,
|
||||
"system_root_directory": self.system_root_directory,
|
||||
"monitoring_tool": self.monitoring_tool,
|
||||
"cache_root_directory": self.cache_root_directory,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@@ -7,6 +7,7 @@ import aiohttp
|
|||
from uuid import UUID
|
||||
|
||||
from cognee.infrastructure.databases.graph.kuzu.adapter import KuzuAdapter
|
||||
from cognee.shared.utils import create_secure_ssl_context
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@@ -42,7 +43,9 @@ class RemoteKuzuAdapter(KuzuAdapter):
|
|||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create an aiohttp session."""
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = aiohttp.ClientSession()
|
||||
ssl_context = create_secure_ssl_context()
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
self._session = aiohttp.ClientSession(connector=connector)
|
||||
return self._session
|
||||
|
||||
async def close(self):
|
||||
|
|
|
|||
|
|
@@ -14,6 +14,7 @@ from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter im
|
|||
embedding_rate_limit_async,
|
||||
embedding_sleep_and_retry_async,
|
||||
)
|
||||
from cognee.shared.utils import create_secure_ssl_context
|
||||
|
||||
logger = get_logger("OllamaEmbeddingEngine")
|
||||
|
||||
|
|
@@ -101,7 +102,9 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
|||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
ssl_context = create_secure_ssl_context()
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
async with aiohttp.ClientSession(connector=connector) as session:
|
||||
async with session.post(
|
||||
self.endpoint, json=payload, headers=headers, timeout=60.0
|
||||
) as response:
|
||||
|
|
|
|||
|
|
@@ -253,6 +253,56 @@ class LocalFileStorage(Storage):
|
|||
if os.path.exists(full_file_path):
|
||||
os.remove(full_file_path)
|
||||
|
||||
def list_files(self, directory_path: str, recursive: bool = False) -> list[str]:
|
||||
"""
|
||||
List all files in the specified directory.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
- directory_path (str): The directory path to list files from
|
||||
- recursive (bool): If True, list files recursively in subdirectories
|
||||
|
||||
Returns:
|
||||
--------
|
||||
- list[str]: List of file paths relative to the storage root
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
parsed_storage_path = get_parsed_path(self.storage_path)
|
||||
|
||||
if directory_path:
|
||||
full_directory_path = os.path.join(parsed_storage_path, directory_path)
|
||||
else:
|
||||
full_directory_path = parsed_storage_path
|
||||
|
||||
directory_pathlib = Path(full_directory_path)
|
||||
|
||||
if not directory_pathlib.exists() or not directory_pathlib.is_dir():
|
||||
return []
|
||||
|
||||
files = []
|
||||
|
||||
if recursive:
|
||||
# Use rglob for recursive search
|
||||
for file_path in directory_pathlib.rglob("*"):
|
||||
if file_path.is_file():
|
||||
# Get relative path from storage root
|
||||
relative_path = os.path.relpath(str(file_path), parsed_storage_path)
|
||||
# Normalize path separators for consistency
|
||||
relative_path = relative_path.replace(os.sep, "/")
|
||||
files.append(relative_path)
|
||||
else:
|
||||
# Use iterdir for just immediate directory
|
||||
for file_path in directory_pathlib.iterdir():
|
||||
if file_path.is_file():
|
||||
# Get relative path from storage root
|
||||
relative_path = os.path.relpath(str(file_path), parsed_storage_path)
|
||||
# Normalize path separators for consistency
|
||||
relative_path = relative_path.replace(os.sep, "/")
|
||||
files.append(relative_path)
|
||||
|
||||
return files
|
||||
|
||||
def remove_all(self, tree_path: str = None):
|
||||
"""
|
||||
Remove an entire directory tree at the specified path, including all files and
|
||||
|
|
|
|||
|
|
@@ -155,21 +155,19 @@ class S3FileStorage(Storage):
|
|||
"""
|
||||
Ensure that the specified directory exists, creating it if necessary.
|
||||
|
||||
If the directory already exists, no action is taken.
|
||||
For S3 storage, this is a no-op since directories are created implicitly
|
||||
when files are written to paths. S3 doesn't have actual directories,
|
||||
just object keys with prefixes that appear as directories.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- directory_path (str): The path of the directory to check or create.
|
||||
"""
|
||||
if not directory_path.strip():
|
||||
directory_path = self.storage_path.replace("s3://", "")
|
||||
|
||||
def ensure_directory():
|
||||
if not self.s3.exists(directory_path):
|
||||
self.s3.makedirs(directory_path, exist_ok=True)
|
||||
|
||||
await run_async(ensure_directory)
|
||||
# In S3, directories don't exist as separate entities - they're just prefixes
|
||||
# When you write a file to s3://bucket/path/to/file.txt, the "directories"
|
||||
# path/ and path/to/ are implicitly created. No explicit action needed.
|
||||
pass
|
||||
|
||||
async def copy_file(self, source_file_path: str, destination_file_path: str):
|
||||
"""
|
||||
|
|
@@ -213,6 +211,55 @@ class S3FileStorage(Storage):
|
|||
|
||||
await run_async(remove_file)
|
||||
|
||||
async def list_files(self, directory_path: str, recursive: bool = False) -> list[str]:
|
||||
"""
|
||||
List all files in the specified directory.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
- directory_path (str): The directory path to list files from
|
||||
- recursive (bool): If True, list files recursively in subdirectories
|
||||
|
||||
Returns:
|
||||
--------
|
||||
- list[str]: List of file paths relative to the storage root
|
||||
"""
|
||||
|
||||
def list_files_sync():
|
||||
if directory_path:
|
||||
# Combine storage path with directory path
|
||||
full_path = os.path.join(self.storage_path.replace("s3://", ""), directory_path)
|
||||
else:
|
||||
full_path = self.storage_path.replace("s3://", "")
|
||||
|
||||
if recursive:
|
||||
# Use ** for recursive search
|
||||
pattern = f"{full_path}/**"
|
||||
else:
|
||||
# Just files in the immediate directory
|
||||
pattern = f"{full_path}/*"
|
||||
|
||||
# Use s3fs glob to find files
|
||||
try:
|
||||
all_paths = self.s3.glob(pattern)
|
||||
# Filter to only files (not directories)
|
||||
files = [path for path in all_paths if self.s3.isfile(path)]
|
||||
|
||||
# Convert back to relative paths from storage root
|
||||
storage_prefix = self.storage_path.replace("s3://", "")
|
||||
relative_files = []
|
||||
for file_path in files:
|
||||
if file_path.startswith(storage_prefix):
|
||||
relative_path = file_path[len(storage_prefix) :].lstrip("/")
|
||||
relative_files.append(relative_path)
|
||||
|
||||
return relative_files
|
||||
except Exception:
|
||||
# If directory doesn't exist or other error, return empty list
|
||||
return []
|
||||
|
||||
return await run_async(list_files_sync)
|
||||
|
||||
async def remove_all(self, tree_path: str):
|
||||
"""
|
||||
Remove an entire directory tree at the specified path, including all files and
|
||||
|
|
|
|||
|
|
@@ -135,6 +135,24 @@ class StorageManager:
|
|||
else:
|
||||
return self.storage.remove(file_path)
|
||||
|
||||
async def list_files(self, directory_path: str, recursive: bool = False) -> list[str]:
|
||||
"""
|
||||
List all files in the specified directory.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
- directory_path (str): The directory path to list files from
|
||||
- recursive (bool): If True, list files recursively in subdirectories
|
||||
|
||||
Returns:
|
||||
--------
|
||||
- list[str]: List of file paths relative to the storage root
|
||||
"""
|
||||
if inspect.iscoroutinefunction(self.storage.list_files):
|
||||
return await self.storage.list_files(directory_path, recursive)
|
||||
else:
|
||||
return self.storage.list_files(directory_path, recursive)
|
||||
|
||||
async def remove_all(self, tree_path: str = None):
|
||||
"""
|
||||
Remove an entire directory tree at the specified path, including all files and
|
||||
|
|
|
|||
|
|
@@ -1,6 +1,7 @@
|
|||
import aiohttp
|
||||
|
||||
from cognee.modules.cloud.exceptions import CloudConnectionError
|
||||
from cognee.shared.utils import create_secure_ssl_context
|
||||
|
||||
|
||||
async def check_api_key(auth_token: str):
|
||||
|
|
@@ -10,7 +11,9 @@ async def check_api_key(auth_token: str):
|
|||
headers = {"X-Api-Key": auth_token}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
ssl_context = create_secure_ssl_context()
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
async with aiohttp.ClientSession(connector=connector) as session:
|
||||
async with session.post(url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
return
|
||||
|
|
|
|||
|
|
@@ -1,9 +1,10 @@
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.shared.cache import delete_cache
|
||||
|
||||
|
||||
async def prune_system(graph=True, vector=True, metadata=True):
|
||||
async def prune_system(graph=True, vector=True, metadata=True, cache=True):
|
||||
if graph:
|
||||
graph_engine = await get_graph_engine()
|
||||
await graph_engine.delete_graph()
|
||||
|
|
@@ -15,3 +16,6 @@ async def prune_system(graph=True, vector=True, metadata=True):
|
|||
if metadata:
|
||||
db_engine = get_relational_engine()
|
||||
await db_engine.delete_database()
|
||||
|
||||
if cache:
|
||||
await delete_cache()
|
||||
|
|
|
|||
|
|
@@ -7,6 +7,38 @@ from cognee.infrastructure.databases.relational import with_async_session
|
|||
from ..models.Notebook import Notebook, NotebookCell
|
||||
|
||||
|
||||
async def _create_tutorial_notebook(
|
||||
user_id: UUID, session: AsyncSession, force_refresh: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Create the default tutorial notebook for new users.
|
||||
Dynamically fetches from: https://github.com/topoteretes/cognee/blob/notebook_tutorial/notebooks/starter_tutorial.zip
|
||||
"""
|
||||
TUTORIAL_ZIP_URL = (
|
||||
"https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/starter_tutorial.zip"
|
||||
)
|
||||
|
||||
try:
|
||||
# Create notebook from remote zip file (includes notebook + data files)
|
||||
notebook = await Notebook.from_ipynb_zip_url(
|
||||
zip_url=TUTORIAL_ZIP_URL,
|
||||
owner_id=user_id,
|
||||
notebook_filename="tutorial.ipynb",
|
||||
name="Python Development with Cognee Tutorial 🧠",
|
||||
deletable=False,
|
||||
force=force_refresh,
|
||||
)
|
||||
|
||||
# Add to session and commit
|
||||
session.add(notebook)
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch tutorial notebook from {TUTORIAL_ZIP_URL}: {e}")
|
||||
|
||||
raise e
|
||||
|
||||
|
||||
@with_async_session
|
||||
async def create_notebook(
|
||||
user_id: UUID,
|
||||
|
|
|
|||
|
|
@@ -1,13 +1,24 @@
|
|||
import json
|
||||
from typing import List, Literal
|
||||
import nbformat
|
||||
import asyncio
|
||||
from nbformat.notebooknode import NotebookNode
|
||||
from typing import List, Literal, Optional, cast, Tuple
|
||||
from uuid import uuid4, UUID as UUID_t
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from datetime import datetime, timezone
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy import Boolean, Column, DateTime, JSON, UUID, String, TypeDecorator
|
||||
from sqlalchemy.orm import mapped_column, Mapped
|
||||
from pathlib import Path
|
||||
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
from cognee.shared.cache import (
|
||||
download_and_extract_zip,
|
||||
get_tutorial_data_dir,
|
||||
generate_content_hash,
|
||||
)
|
||||
from cognee.infrastructure.files.storage.get_file_storage import get_file_storage
|
||||
from cognee.base_config import get_base_config
|
||||
|
||||
|
||||
class NotebookCell(BaseModel):
|
||||
|
|
@@ -51,3 +62,197 @@ class Notebook(Base):
|
|||
deletable: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
@classmethod
|
||||
async def from_ipynb_zip_url(
|
||||
cls,
|
||||
zip_url: str,
|
||||
owner_id: UUID_t,
|
||||
notebook_filename: str = "tutorial.ipynb",
|
||||
name: Optional[str] = None,
|
||||
deletable: bool = True,
|
||||
force: bool = False,
|
||||
) -> "Notebook":
|
||||
"""
|
||||
Create a Notebook instance from a remote zip file containing notebook + data files.
|
||||
|
||||
Args:
|
||||
zip_url: Remote URL to fetch the .zip file from
|
||||
owner_id: UUID of the notebook owner
|
||||
notebook_filename: Name of the .ipynb file within the zip
|
||||
name: Optional custom name for the notebook
|
||||
deletable: Whether the notebook can be deleted
|
||||
force: If True, re-download even if already cached
|
||||
|
||||
Returns:
|
||||
Notebook instance
|
||||
"""
|
||||
# Generate a cache key based on the zip URL
|
||||
content_hash = generate_content_hash(zip_url, notebook_filename)
|
||||
|
||||
# Download and extract the zip file to tutorial_data/{content_hash}
|
||||
try:
|
||||
extracted_cache_dir = await download_and_extract_zip(
|
||||
url=zip_url,
|
||||
cache_dir_name=f"tutorial_data/{content_hash}",
|
||||
version_or_hash=content_hash,
|
||||
force=force,
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to download tutorial zip from {zip_url}") from e
|
||||
|
||||
# Use cache system to access the notebook file
|
||||
from cognee.shared.cache import cache_file_exists, read_cache_file
|
||||
|
||||
notebook_file_path = f"{extracted_cache_dir}/{notebook_filename}"
|
||||
|
||||
# Check if the notebook file exists in cache
|
||||
if not await cache_file_exists(notebook_file_path):
|
||||
raise FileNotFoundError(f"Notebook file '{notebook_filename}' not found in zip")
|
||||
|
||||
# Read and parse the notebook using cache system
|
||||
async with await read_cache_file(notebook_file_path, encoding="utf-8") as f:
|
||||
notebook_content = await asyncio.to_thread(f.read)
|
||||
notebook = cls.from_ipynb_string(notebook_content, owner_id, name, deletable)
|
||||
|
||||
# Update file paths in notebook cells to point to actual cached data files
|
||||
await cls._update_file_paths_in_cells(notebook, extracted_cache_dir)
|
||||
|
||||
return notebook
|
||||
|
||||
@staticmethod
|
||||
async def _update_file_paths_in_cells(notebook: "Notebook", cache_dir: str) -> None:
|
||||
"""
|
||||
Update file paths in code cells to use actual cached data files.
|
||||
Works with both local filesystem and S3 storage.
|
||||
|
||||
Args:
|
||||
notebook: Parsed Notebook instance with cells to update
|
||||
cache_dir: Path to the cached tutorial directory containing data files
|
||||
"""
|
||||
import re
|
||||
from cognee.shared.cache import list_cache_files, cache_file_exists
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
# Look for data files in the data subdirectory
|
||||
data_dir = f"{cache_dir}/data"
|
||||
|
||||
try:
|
||||
# Get all data files in the cache directory using cache system
|
||||
data_files = {}
|
||||
if await cache_file_exists(data_dir):
|
||||
file_list = await list_cache_files(data_dir)
|
||||
else:
|
||||
file_list = []
|
||||
|
||||
for file_path in file_list:
|
||||
# Extract just the filename
|
||||
filename = file_path.split("/")[-1]
|
||||
# Use the file path as provided by cache system
|
||||
data_files[filename] = file_path
|
||||
|
||||
except Exception as e:
|
||||
# If we can't list files, skip updating paths
|
||||
logger.error(f"Error listing data files in {data_dir}: {e}")
|
||||
return
|
||||
|
||||
# Pattern to match file://data/filename patterns in code cells
|
||||
file_pattern = r'"file://data/([^"]+)"'
|
||||
|
||||
def replace_path(match):
|
||||
filename = match.group(1)
|
||||
if filename in data_files:
|
||||
file_path = data_files[filename]
|
||||
# For local filesystem, preserve file:// prefix
|
||||
if not file_path.startswith("s3://"):
|
||||
return f'"file://{file_path}"'
|
||||
else:
|
||||
# For S3, return the S3 URL as-is
|
||||
return f'"{file_path}"'
|
||||
return match.group(0) # Keep original if file not found
|
||||
|
||||
# Update only code cells
|
||||
updated_cells = 0
|
||||
for cell in notebook.cells:
|
||||
if cell.type == "code":
|
||||
original_content = cell.content
|
||||
# Update file paths in the cell content
|
||||
cell.content = re.sub(file_pattern, replace_path, cell.content)
|
||||
if original_content != cell.content:
|
||||
updated_cells += 1
|
||||
|
||||
# Log summary of updates (useful for monitoring)
|
||||
if updated_cells > 0:
|
||||
logger.info(f"Updated file paths in {updated_cells} notebook cells")
|
||||
|
||||
@classmethod
|
||||
def from_ipynb_string(
|
||||
cls,
|
||||
notebook_content: str,
|
||||
owner_id: UUID_t,
|
||||
name: Optional[str] = None,
|
||||
deletable: bool = True,
|
||||
) -> "Notebook":
|
||||
"""
|
||||
Create a Notebook instance from Jupyter notebook string content.
|
||||
|
||||
Args:
|
||||
notebook_content: Raw Jupyter notebook content as string
|
||||
owner_id: UUID of the notebook owner
|
||||
name: Optional custom name for the notebook
|
||||
deletable: Whether the notebook can be deleted
|
||||
|
||||
Returns:
|
||||
Notebook instance ready to be saved to database
|
||||
"""
|
||||
# Parse and validate the Jupyter notebook using nbformat
|
||||
# Note: nbformat.reads() has loose typing, so we cast to NotebookNode
|
||||
jupyter_nb = cast(
|
||||
NotebookNode, nbformat.reads(notebook_content, as_version=nbformat.NO_CONVERT)
|
||||
)
|
||||
|
||||
# Convert Jupyter cells to NotebookCell objects
|
||||
cells = []
|
||||
for jupyter_cell in jupyter_nb.cells:
|
||||
# Each cell is also a NotebookNode with dynamic attributes
|
||||
cell = cast(NotebookNode, jupyter_cell)
|
||||
# Skip raw cells as they're not supported in our model
|
||||
if cell.cell_type == "raw":
|
||||
continue
|
||||
|
||||
# Get the source content
|
||||
content = cell.source
|
||||
|
||||
# Generate a name based on content or cell index
|
||||
cell_name = cls._generate_cell_name(cell)
|
||||
|
||||
# Map cell types (jupyter uses "code"/"markdown", we use same)
|
||||
cell_type = "code" if cell.cell_type == "code" else "markdown"
|
||||
|
||||
cells.append(NotebookCell(id=uuid4(), type=cell_type, name=cell_name, content=content))
|
||||
|
||||
# Extract notebook name from metadata if not provided
|
||||
if name is None:
|
||||
kernelspec = jupyter_nb.metadata.get("kernelspec", {})
|
||||
name = kernelspec.get("display_name") or kernelspec.get("name", "Imported Notebook")
|
||||
|
||||
return cls(id=uuid4(), owner_id=owner_id, name=name, cells=cells, deletable=deletable)
|
||||
|
||||
@staticmethod
|
||||
def _generate_cell_name(jupyter_cell: NotebookNode) -> str:
|
||||
"""Generate a meaningful name for a notebook cell using nbformat cell."""
|
||||
if jupyter_cell.cell_type == "markdown":
|
||||
# Try to extract a title from markdown headers
|
||||
content = jupyter_cell.source
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
if lines and lines[0].startswith("#"):
|
||||
# Extract header text, clean it up
|
||||
header = lines[0].lstrip("#").strip()
|
||||
return header[:50] if len(header) > 50 else header
|
||||
else:
|
||||
return "Markdown Cell"
|
||||
else:
|
||||
return "Code Cell"
|
||||
|
|
|
|||
|
|
@@ -1,9 +1,10 @@
|
|||
from uuid import uuid4
|
||||
from uuid import UUID, uuid4
|
||||
from fastapi_users.exceptions import UserAlreadyExists
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.notebooks.methods import create_notebook
|
||||
from cognee.modules.notebooks.models.Notebook import NotebookCell
|
||||
from cognee.modules.notebooks.models.Notebook import Notebook
|
||||
from cognee.modules.notebooks.methods.create_notebook import _create_tutorial_notebook
|
||||
from cognee.modules.users.exceptions import TenantNotFoundError
|
||||
from cognee.modules.users.get_user_manager import get_user_manager_context
|
||||
from cognee.modules.users.get_user_db import get_user_db_context
|
||||
|
|
@@ -60,26 +61,7 @@ async def create_user(
|
|||
if auto_login:
|
||||
await session.refresh(user)
|
||||
|
||||
await create_notebook(
|
||||
user_id=user.id,
|
||||
notebook_name="Welcome to cognee 🧠",
|
||||
cells=[
|
||||
NotebookCell(
|
||||
id=uuid4(),
|
||||
name="Welcome",
|
||||
content="Cognee is your toolkit for turning text into a structured knowledge graph, optionally enhanced by ontologies, and then querying it with advanced retrieval techniques. This notebook will guide you through a simple example.",
|
||||
type="markdown",
|
||||
),
|
||||
NotebookCell(
|
||||
id=uuid4(),
|
||||
name="Example",
|
||||
content="",
|
||||
type="code",
|
||||
),
|
||||
],
|
||||
deletable=False,
|
||||
session=session,
|
||||
)
|
||||
await _create_tutorial_notebook(user.id, session)
|
||||
|
||||
return user
|
||||
except UserAlreadyExists as error:
|
||||
|
|
|
|||
346
cognee/shared/cache.py
Normal file
|
|
@@ -0,0 +1,346 @@
|
|||
"""
|
||||
Storage-aware cache management utilities for Cognee.
|
||||
|
||||
This module provides cache functionality that works with both local and cloud storage
|
||||
backends (like S3) through the StorageManager abstraction.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import zipfile
|
||||
import asyncio
|
||||
from typing import Optional, Tuple
|
||||
import aiohttp
|
||||
import logging
|
||||
from io import BytesIO
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.files.storage.get_file_storage import get_file_storage
|
||||
from cognee.infrastructure.files.storage.StorageManager import StorageManager
|
||||
from cognee.shared.utils import create_secure_ssl_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StorageAwareCache:
|
||||
"""
|
||||
A cache manager that works with different storage backends (local, S3, etc.)
|
||||
"""
|
||||
|
||||
def __init__(self, cache_subdir: str = "cache"):
|
||||
"""
|
||||
Initialize the cache manager.
|
||||
|
||||
Args:
|
||||
cache_subdir: Subdirectory name within the system root for caching
|
||||
"""
|
||||
self.base_config = get_base_config()
|
||||
# Since we're using cache_root_directory, don't add extra cache prefix
|
||||
self.cache_base_path = ""
|
||||
self.storage_manager: StorageManager = get_file_storage(
|
||||
self.base_config.cache_root_directory
|
||||
)
|
||||
|
||||
# Print absolute path
|
||||
storage_path = self.storage_manager.storage.storage_path
|
||||
if storage_path.startswith("s3://"):
|
||||
absolute_path = storage_path # S3 paths are already absolute
|
||||
else:
|
||||
import os
|
||||
|
||||
absolute_path = os.path.abspath(storage_path)
|
||||
logger.info(f"Storage manager absolute path: {absolute_path}")
|
||||
|
||||
async def get_cache_dir(self) -> str:
|
||||
"""Get the base cache directory path."""
|
||||
cache_path = self.cache_base_path or "." # Use "." for root when cache_base_path is empty
|
||||
await self.storage_manager.ensure_directory_exists(cache_path)
|
||||
return cache_path
|
||||
|
||||
async def get_cache_subdir(self, name: str) -> str:
|
||||
"""Get a specific cache subdirectory."""
|
||||
if self.cache_base_path:
|
||||
cache_path = f"{self.cache_base_path}/{name}"
|
||||
else:
|
||||
cache_path = name
|
||||
await self.storage_manager.ensure_directory_exists(cache_path)
|
||||
|
||||
# Return the absolute path based on storage system
|
||||
if self.storage_manager.storage.storage_path.startswith("s3://"):
|
||||
return cache_path
|
||||
elif hasattr(self.storage_manager.storage, "storage_path"):
|
||||
return f"{self.storage_manager.storage.storage_path}/{cache_path}"
|
||||
else:
|
||||
# Fallback for other storage types
|
||||
return cache_path
|
||||
|
||||
async def delete_cache(self):
|
||||
"""Delete the entire cache directory."""
|
||||
logger.info("Deleting cache...")
|
||||
try:
|
||||
await self.storage_manager.remove_all(self.cache_base_path)
|
||||
logger.info("✓ Cache deleted successfully!")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting cache: {e}")
|
||||
raise
|
||||
|
||||
async def _is_cache_valid(self, cache_dir: str, version_or_hash: str) -> bool:
|
||||
"""Check if cached content is valid for the given version/hash."""
|
||||
version_file = f"{cache_dir}/version.txt"
|
||||
|
||||
if not await self.storage_manager.file_exists(version_file):
|
||||
return False
|
||||
|
||||
try:
|
||||
async with self.storage_manager.open(version_file, "r") as f:
|
||||
cached_version = (await asyncio.to_thread(f.read)).strip()
|
||||
return cached_version == version_or_hash
|
||||
except Exception as e:
|
||||
logger.debug(f"Error checking cache validity: {e}")
|
||||
return False
|
||||
|
||||
async def _clear_cache(self, cache_dir: str) -> None:
|
||||
"""Clear a cache directory."""
|
||||
try:
|
||||
await self.storage_manager.remove_all(cache_dir)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error clearing cache directory {cache_dir}: {e}")
|
||||
|
||||
async def _check_remote_content_freshness(
|
||||
self, url: str, cache_dir: str
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Check if remote content is fresher than cached version using HTTP headers.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_fresh: bool, new_identifier: Optional[str])
|
||||
"""
|
||||
try:
|
||||
# Make a HEAD request to check headers without downloading
|
||||
ssl_context = create_secure_ssl_context()
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
async with aiohttp.ClientSession(connector=connector) as session:
|
||||
async with session.head(url, timeout=aiohttp.ClientTimeout(total=30)) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
# Try ETag first (most reliable)
|
||||
etag = response.headers.get("ETag", "").strip('"')
|
||||
last_modified = response.headers.get("Last-Modified", "")
|
||||
|
||||
# Use ETag if available, otherwise Last-Modified
|
||||
remote_identifier = etag if etag else last_modified
|
||||
|
||||
if not remote_identifier:
|
||||
logger.debug("No freshness headers available, cannot check for updates")
|
||||
return True, None # Assume fresh if no headers
|
||||
|
||||
# Check cached identifier
|
||||
identifier_file = f"{cache_dir}/content_id.txt"
|
||||
if await self.storage_manager.file_exists(identifier_file):
|
||||
async with self.storage_manager.open(identifier_file, "r") as f:
|
||||
cached_identifier = (await asyncio.to_thread(f.read)).strip()
|
||||
if cached_identifier == remote_identifier:
|
||||
logger.debug(f"Content is fresh (identifier: {remote_identifier[:20]}...)")
|
||||
return True, None
|
||||
else:
|
||||
logger.info(
|
||||
f"Content has changed (old: {cached_identifier[:20]}..., new: {remote_identifier[:20]}...)"
|
||||
)
|
||||
return False, remote_identifier
|
||||
else:
|
||||
# No cached identifier, treat as stale
|
||||
return False, remote_identifier
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not check remote freshness: {e}")
|
||||
return True, None # Assume fresh if we can't check
|
||||
|
||||
async def download_and_extract_zip(
|
||||
self, url: str, cache_subdir_name: str, version_or_hash: str, force: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
Download a zip file and extract it to cache directory with content freshness checking.
|
||||
|
||||
Args:
|
||||
url: URL to download zip file from
|
||||
cache_subdir_name: Name of the cache subdirectory
|
||||
version_or_hash: Version string or content hash for cache validation
|
||||
force: If True, re-download even if already cached
|
||||
|
||||
Returns:
|
||||
Path to the cached directory
|
||||
"""
|
||||
cache_dir = await self.get_cache_subdir(cache_subdir_name)
|
||||
|
||||
# Check if already cached and valid
|
||||
if not force and await self._is_cache_valid(cache_dir, version_or_hash):
|
||||
# Also check if remote content has changed
|
||||
is_fresh, new_identifier = await self._check_remote_content_freshness(url, cache_dir)
|
||||
if is_fresh:
|
||||
logger.debug(f"Content already cached and fresh for version {version_or_hash}")
|
||||
return cache_dir
|
||||
else:
|
||||
logger.info("Cached content is stale, updating...")
|
||||
|
||||
# Clear old cache if it exists
|
||||
await self._clear_cache(cache_dir)
|
||||
|
||||
logger.info(f"Downloading content from {url}...")
|
||||
|
||||
# Download the zip file
|
||||
zip_content = BytesIO()
|
||||
etag = ""
|
||||
last_modified = ""
|
||||
ssl_context = create_secure_ssl_context()
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
async with aiohttp.ClientSession(connector=connector) as session:
|
||||
async with session.get(url, timeout=aiohttp.ClientTimeout(total=60)) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
# Extract headers before consuming response
|
||||
etag = response.headers.get("ETag", "").strip('"')
|
||||
last_modified = response.headers.get("Last-Modified", "")
|
||||
|
||||
# Read the response content
|
||||
async for chunk in response.content.iter_chunked(8192):
|
||||
zip_content.write(chunk)
|
||||
zip_content.seek(0)
|
||||
|
||||
# Extract the archive
|
||||
await self.storage_manager.ensure_directory_exists(cache_dir)
|
||||
|
||||
# Extract files and store them using StorageManager
|
||||
with zipfile.ZipFile(zip_content, "r") as zip_file:
|
||||
for file_info in zip_file.infolist():
|
||||
if file_info.is_dir():
|
||||
# Create directory
|
||||
dir_path = f"{cache_dir}/{file_info.filename}"
|
||||
await self.storage_manager.ensure_directory_exists(dir_path)
|
||||
else:
|
||||
# Extract and store file
|
||||
file_data = zip_file.read(file_info.filename)
|
||||
file_path = f"{cache_dir}/{file_info.filename}"
|
||||
await self.storage_manager.store(file_path, BytesIO(file_data), overwrite=True)
|
||||
|
||||
# Write version info for future cache validation
|
||||
version_file = f"{cache_dir}/version.txt"
|
||||
await self.storage_manager.store(version_file, version_or_hash, overwrite=True)
|
||||
|
||||
# Store content identifier from response headers for freshness checking
|
||||
content_identifier = etag if etag else last_modified
|
||||
|
||||
if content_identifier:
|
||||
identifier_file = f"{cache_dir}/content_id.txt"
|
||||
await self.storage_manager.store(identifier_file, content_identifier, overwrite=True)
|
||||
logger.debug(f"Stored content identifier: {content_identifier[:20]}...")
|
||||
|
||||
logger.info("✓ Content downloaded and cached successfully!")
|
||||
return cache_dir
|
||||
|
||||
async def file_exists(self, file_path: str) -> bool:
|
||||
"""Check if a file exists in cache storage."""
|
||||
return await self.storage_manager.file_exists(file_path)
|
||||
|
||||
async def read_file(self, file_path: str, encoding: str = "utf-8"):
|
||||
"""Read a file from cache storage."""
|
||||
return self.storage_manager.open(file_path, encoding=encoding)
|
||||
|
||||
async def list_files(self, directory_path: str):
|
||||
"""List files in a cache directory."""
|
||||
try:
|
||||
file_list = await self.storage_manager.list_files(directory_path)
|
||||
|
||||
# For S3 storage, convert relative paths to full S3 URLs
|
||||
if self.storage_manager.storage.storage_path.startswith("s3://"):
|
||||
full_paths = []
|
||||
for file_path in file_list:
|
||||
full_s3_path = f"{self.storage_manager.storage.storage_path}/{file_path}"
|
||||
full_paths.append(full_s3_path)
|
||||
return full_paths
|
||||
else:
|
||||
# For local storage, return absolute paths
|
||||
storage_path = self.storage_manager.storage.storage_path
|
||||
if not storage_path.startswith("/"):
|
||||
import os
|
||||
|
||||
storage_path = os.path.abspath(storage_path)
|
||||
|
||||
full_paths = []
|
||||
for file_path in file_list:
|
||||
if file_path.startswith("/"):
|
||||
full_paths.append(file_path) # Already absolute
|
||||
else:
|
||||
full_paths.append(f"{storage_path}/{file_path}")
|
||||
return full_paths
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error listing files in {directory_path}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# Convenience functions that maintain API compatibility
|
||||
_cache_manager = None
|
||||
|
||||
|
||||
def get_cache_manager() -> StorageAwareCache:
|
||||
"""Get a singleton cache manager instance."""
|
||||
global _cache_manager
|
||||
if _cache_manager is None:
|
||||
_cache_manager = StorageAwareCache()
|
||||
return _cache_manager
|
||||
|
||||
|
||||
def generate_content_hash(url: str, additional_data: str = "") -> str:
|
||||
"""Generate a content hash from URL and optional additional data."""
|
||||
content = f"{url}:{additional_data}"
|
||||
return hashlib.md5(content.encode()).hexdigest()[:12] # Short hash for readability
|
||||
|
||||
|
||||
# Async wrapper functions for backward compatibility
|
||||
async def delete_cache():
|
||||
"""Delete the Cognee cache directory."""
|
||||
cache_manager = get_cache_manager()
|
||||
await cache_manager.delete_cache()
|
||||
|
||||
|
||||
async def get_cognee_cache_dir() -> str:
|
||||
"""Get the base Cognee cache directory."""
|
||||
cache_manager = get_cache_manager()
|
||||
return await cache_manager.get_cache_dir()
|
||||
|
||||
|
||||
async def get_cache_subdir(name: str) -> str:
|
||||
"""Get a specific cache subdirectory."""
|
||||
cache_manager = get_cache_manager()
|
||||
return await cache_manager.get_cache_subdir(name)
|
||||
|
||||
|
||||
async def download_and_extract_zip(
|
||||
url: str, cache_dir_name: str, version_or_hash: str, force: bool = False
|
||||
) -> str:
|
||||
"""Download a zip file and extract it to cache directory."""
|
||||
cache_manager = get_cache_manager()
|
||||
return await cache_manager.download_and_extract_zip(url, cache_dir_name, version_or_hash, force)
|
||||
|
||||
|
||||
async def get_tutorial_data_dir() -> str:
|
||||
"""Get the tutorial data cache directory."""
|
||||
return await get_cache_subdir("tutorial_data")
|
||||
|
||||
|
||||
# Cache file operations
|
||||
async def cache_file_exists(file_path: str) -> bool:
|
||||
"""Check if a file exists in cache storage."""
|
||||
cache_manager = get_cache_manager()
|
||||
return await cache_manager.file_exists(file_path)
|
||||
|
||||
|
||||
async def read_cache_file(file_path: str, encoding: str = "utf-8"):
|
||||
"""Read a file from cache storage."""
|
||||
cache_manager = get_cache_manager()
|
||||
return await cache_manager.read_file(file_path, encoding)
|
||||
|
||||
|
||||
async def list_cache_files(directory_path: str):
|
||||
"""List files in a cache directory."""
|
||||
cache_manager = get_cache_manager()
|
||||
return await cache_manager.list_files(directory_path)
|
||||
|
|
@@ -1,6 +1,7 @@
|
|||
"""This module contains utility functions for the cognee."""
|
||||
|
||||
import os
|
||||
import ssl
|
||||
import requests
|
||||
from datetime import datetime, timezone
|
||||
import matplotlib.pyplot as plt
|
||||
|
|
@@ -18,6 +19,17 @@ from cognee.infrastructure.databases.graph import get_graph_engine
|
|||
proxy_url = "https://test.prometh.ai"
|
||||
|
||||
|
||||
def create_secure_ssl_context() -> ssl.SSLContext:
|
||||
"""
|
||||
Create a secure SSL context.
|
||||
|
||||
By default, use the system's certificate store.
|
||||
If users report SSL issues, I'm keeping this open in case we need to switch to:
|
||||
ssl.create_default_context(cafile=certifi.where())
|
||||
"""
|
||||
return ssl.create_default_context()
|
||||
|
||||
|
||||
def get_entities(tagged_tokens):
|
||||
import nltk
|
||||
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,399 @@
|
|||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
import hashlib
|
||||
import time
|
||||
from uuid import uuid4
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from pathlib import Path
|
||||
import zipfile
|
||||
from cognee.shared.cache import get_tutorial_data_dir
|
||||
|
||||
from cognee.modules.notebooks.methods.create_notebook import _create_tutorial_notebook
|
||||
from cognee.modules.notebooks.models.Notebook import Notebook
|
||||
import cognee
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
# Module-level fixtures available to all test classes
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Mock database session."""
|
||||
session = AsyncMock(spec=AsyncSession)
|
||||
session.add = MagicMock()
|
||||
session.commit = AsyncMock()
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_jupyter_notebook():
|
||||
"""Sample Jupyter notebook content for testing."""
|
||||
return {
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": ["# Tutorial Introduction\n", "\n", "This is a tutorial notebook."],
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": None,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": ["import cognee\n", "print('Hello, Cognee!')"],
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": ["## Step 1: Data Ingestion\n", "\n", "Let's add some data."],
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": None,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": ["# Add your data here\n", "# await cognee.add('data.txt')"],
|
||||
},
|
||||
{
|
||||
"cell_type": "raw",
|
||||
"metadata": {},
|
||||
"source": ["This is a raw cell that should be skipped"],
|
||||
},
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4,
|
||||
}
|
||||
|
||||
|
||||
class TestTutorialNotebookCreation:
|
||||
"""Test cases for tutorial notebook creation functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notebook_from_ipynb_string_success(self, sample_jupyter_notebook):
|
||||
"""Test successful creation of notebook from JSON string."""
|
||||
notebook_json = json.dumps(sample_jupyter_notebook)
|
||||
user_id = uuid4()
|
||||
|
||||
notebook = Notebook.from_ipynb_string(
|
||||
notebook_content=notebook_json, owner_id=user_id, name="String Test Notebook"
|
||||
)
|
||||
|
||||
assert notebook.owner_id == user_id
|
||||
assert notebook.name == "String Test Notebook"
|
||||
assert len(notebook.cells) == 4 # Should skip the raw cell
|
||||
assert notebook.cells[0].type == "markdown"
|
||||
assert notebook.cells[1].type == "code"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notebook_cell_name_generation(self, sample_jupyter_notebook):
|
||||
"""Test that cell names are generated correctly from markdown headers."""
|
||||
user_id = uuid4()
|
||||
notebook_json = json.dumps(sample_jupyter_notebook)
|
||||
|
||||
notebook = Notebook.from_ipynb_string(notebook_content=notebook_json, owner_id=user_id)
|
||||
|
||||
# Check markdown header extraction
|
||||
assert notebook.cells[0].name == "Tutorial Introduction"
|
||||
assert notebook.cells[2].name == "Step 1: Data Ingestion"
|
||||
|
||||
# Check code cell naming
|
||||
assert notebook.cells[1].name == "Code Cell"
|
||||
assert notebook.cells[3].name == "Code Cell"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notebook_from_ipynb_string_with_default_name(self, sample_jupyter_notebook):
|
||||
"""Test notebook creation uses kernelspec display_name when no name provided."""
|
||||
user_id = uuid4()
|
||||
notebook_json = json.dumps(sample_jupyter_notebook)
|
||||
|
||||
notebook = Notebook.from_ipynb_string(notebook_content=notebook_json, owner_id=user_id)
|
||||
|
||||
assert notebook.name == "Python 3" # From kernelspec.display_name
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notebook_from_ipynb_string_fallback_name(self):
|
||||
"""Test fallback naming when kernelspec is missing."""
|
||||
minimal_notebook = {
|
||||
"cells": [{"cell_type": "markdown", "metadata": {}, "source": ["# Test"]}],
|
||||
"metadata": {}, # No kernelspec
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4,
|
||||
}
|
||||
|
||||
user_id = uuid4()
|
||||
notebook_json = json.dumps(minimal_notebook)
|
||||
|
||||
notebook = Notebook.from_ipynb_string(notebook_content=notebook_json, owner_id=user_id)
|
||||
|
||||
assert notebook.name == "Imported Notebook" # Fallback name
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notebook_from_ipynb_string_invalid_json(self):
|
||||
"""Test error handling for invalid JSON."""
|
||||
user_id = uuid4()
|
||||
invalid_json = "{ invalid json content"
|
||||
|
||||
from nbformat.reader import NotJSONError
|
||||
|
||||
with pytest.raises(NotJSONError):
|
||||
Notebook.from_ipynb_string(notebook_content=invalid_json, owner_id=user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.object(Notebook, "from_ipynb_zip_url")
|
||||
async def test_create_tutorial_notebook_error_propagated(self, mock_from_zip_url, mock_session):
|
||||
"""Test that errors are propagated when zip fetch fails."""
|
||||
user_id = uuid4()
|
||||
mock_from_zip_url.side_effect = Exception("Network error")
|
||||
|
||||
# Should raise the exception (not catch it)
|
||||
with pytest.raises(Exception, match="Network error"):
|
||||
await _create_tutorial_notebook(user_id, mock_session)
|
||||
|
||||
# Verify error handling path was taken
|
||||
mock_from_zip_url.assert_called_once()
|
||||
mock_session.add.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
def test_generate_cell_name_code_cell(self):
|
||||
"""Test cell name generation for code cells."""
|
||||
from nbformat.notebooknode import NotebookNode
|
||||
|
||||
mock_cell = NotebookNode(
|
||||
{"cell_type": "code", "source": 'import pandas as pd\nprint("Hello world")'}
|
||||
)
|
||||
|
||||
result = Notebook._generate_cell_name(mock_cell)
|
||||
assert result == "Code Cell"
|
||||
|
||||
|
||||
class TestTutorialNotebookZipFunctionality:
|
||||
"""Test cases for zip-based tutorial functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notebook_from_ipynb_zip_url_missing_notebook(
|
||||
self,
|
||||
):
|
||||
"""Test error handling when notebook file is missing from zip."""
|
||||
user_id = uuid4()
|
||||
|
||||
with pytest.raises(
|
||||
FileNotFoundError,
|
||||
match="Notebook file 'super_random_tutorial_name.ipynb' not found in zip",
|
||||
):
|
||||
await Notebook.from_ipynb_zip_url(
|
||||
zip_url="https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/starter_tutorial.zip",
|
||||
owner_id=user_id,
|
||||
notebook_filename="super_random_tutorial_name.ipynb",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notebook_from_ipynb_zip_url_download_failure(self):
|
||||
"""Test error handling when zip download fails."""
|
||||
user_id = uuid4()
|
||||
with pytest.raises(RuntimeError, match="Failed to download tutorial zip"):
|
||||
await Notebook.from_ipynb_zip_url(
|
||||
zip_url="https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/nonexistent_tutorial_name.zip",
|
||||
owner_id=user_id,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_tutorial_notebook_zip_success(self, mock_session):
|
||||
"""Test successful tutorial notebook creation with zip."""
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
user_id = uuid4()
|
||||
|
||||
# Check that tutorial data directory is empty using storage-aware method
|
||||
tutorial_data_dir_path = await get_tutorial_data_dir()
|
||||
tutorial_data_dir = Path(tutorial_data_dir_path)
|
||||
if tutorial_data_dir.exists():
|
||||
assert not any(tutorial_data_dir.iterdir()), "Tutorial data directory should be empty"
|
||||
|
||||
await _create_tutorial_notebook(user_id, mock_session)
|
||||
|
||||
items = list(tutorial_data_dir.iterdir())
|
||||
assert len(items) == 1, "Tutorial data directory should contain exactly one item"
|
||||
assert items[0].is_dir(), "Tutorial data directory item should be a directory"
|
||||
|
||||
# Verify the structure inside the tutorial directory
|
||||
tutorial_dir = items[0]
|
||||
|
||||
# Check for tutorial.ipynb file
|
||||
notebook_file = tutorial_dir / "tutorial.ipynb"
|
||||
assert notebook_file.exists(), f"tutorial.ipynb should exist in {tutorial_dir}"
|
||||
assert notebook_file.is_file(), "tutorial.ipynb should be a file"
|
||||
|
||||
# Check for data subfolder with contents
|
||||
data_folder = tutorial_dir / "data"
|
||||
assert data_folder.exists(), f"data subfolder should exist in {tutorial_dir}"
|
||||
assert data_folder.is_dir(), "data should be a directory"
|
||||
|
||||
data_items = list(data_folder.iterdir())
|
||||
assert len(data_items) > 0, (
|
||||
f"data folder should contain files, but found {len(data_items)} items"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_tutorial_notebook_with_force_refresh(self, mock_session):
|
||||
"""Test tutorial notebook creation with force refresh."""
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
user_id = uuid4()
|
||||
|
||||
# Check that tutorial data directory is empty using storage-aware method
|
||||
tutorial_data_dir_path = await get_tutorial_data_dir()
|
||||
tutorial_data_dir = Path(tutorial_data_dir_path)
|
||||
if tutorial_data_dir.exists():
|
||||
assert not any(tutorial_data_dir.iterdir()), "Tutorial data directory should be empty"
|
||||
|
||||
# First creation (without force refresh)
|
||||
await _create_tutorial_notebook(user_id, mock_session, force_refresh=False)
|
||||
|
||||
items_first = list(tutorial_data_dir.iterdir())
|
||||
assert len(items_first) == 1, (
|
||||
"Tutorial data directory should contain exactly one item after first creation"
|
||||
)
|
||||
first_dir = items_first[0]
|
||||
assert first_dir.is_dir(), "Tutorial data directory item should be a directory"
|
||||
|
||||
# Verify the structure inside the tutorial directory (first creation)
|
||||
notebook_file = first_dir / "tutorial.ipynb"
|
||||
assert notebook_file.exists(), f"tutorial.ipynb should exist in {first_dir}"
|
||||
assert notebook_file.is_file(), "tutorial.ipynb should be a file"
|
||||
|
||||
data_folder = first_dir / "data"
|
||||
assert data_folder.exists(), f"data subfolder should exist in {first_dir}"
|
||||
assert data_folder.is_dir(), "data should be a directory"
|
||||
|
||||
data_items = list(data_folder.iterdir())
|
||||
assert len(data_items) > 0, (
|
||||
f"data folder should contain files, but found {len(data_items)} items"
|
||||
)
|
||||
|
||||
# Capture metadata from first creation
|
||||
|
||||
first_creation_metadata = {}
|
||||
|
||||
for file_path in first_dir.rglob("*"):
|
||||
if file_path.is_file():
|
||||
relative_path = file_path.relative_to(first_dir)
|
||||
stat = file_path.stat()
|
||||
|
||||
# Store multiple metadata points
|
||||
with open(file_path, "rb") as f:
|
||||
content = f.read()
|
||||
|
||||
first_creation_metadata[str(relative_path)] = {
|
||||
"mtime": stat.st_mtime,
|
||||
"size": stat.st_size,
|
||||
"hash": hashlib.md5(content).hexdigest(),
|
||||
"first_bytes": content[:100]
|
||||
if content
|
||||
else b"", # First 100 bytes as fingerprint
|
||||
}
|
||||
|
||||
# Wait a moment to ensure different timestamps
|
||||
time.sleep(0.1)
|
||||
|
||||
# Force refresh - should create new files with different metadata
|
||||
await _create_tutorial_notebook(user_id, mock_session, force_refresh=True)
|
||||
|
||||
items_second = list(tutorial_data_dir.iterdir())
|
||||
assert len(items_second) == 1, (
|
||||
"Tutorial data directory should contain exactly one item after force refresh"
|
||||
)
|
||||
second_dir = items_second[0]
|
||||
|
||||
# Verify the structure is maintained after force refresh
|
||||
notebook_file_second = second_dir / "tutorial.ipynb"
|
||||
assert notebook_file_second.exists(), (
|
||||
f"tutorial.ipynb should exist in {second_dir} after force refresh"
|
||||
)
|
||||
assert notebook_file_second.is_file(), "tutorial.ipynb should be a file after force refresh"
|
||||
|
||||
data_folder_second = second_dir / "data"
|
||||
assert data_folder_second.exists(), (
|
||||
f"data subfolder should exist in {second_dir} after force refresh"
|
||||
)
|
||||
assert data_folder_second.is_dir(), "data should be a directory after force refresh"
|
||||
|
||||
data_items_second = list(data_folder_second.iterdir())
|
||||
assert len(data_items_second) > 0, (
|
||||
f"data folder should still contain files after force refresh, but found {len(data_items_second)} items"
|
||||
)
|
||||
|
||||
# Compare metadata to ensure files are actually different
|
||||
files_with_changed_metadata = 0
|
||||
|
||||
for file_path in second_dir.rglob("*"):
|
||||
if file_path.is_file():
|
||||
relative_path = file_path.relative_to(second_dir)
|
||||
relative_path_str = str(relative_path)
|
||||
|
||||
# File should exist from first creation
|
||||
assert relative_path_str in first_creation_metadata, (
|
||||
f"File {relative_path_str} missing from first creation"
|
||||
)
|
||||
|
||||
old_metadata = first_creation_metadata[relative_path_str]
|
||||
|
||||
# Get new metadata
|
||||
stat = file_path.stat()
|
||||
with open(file_path, "rb") as f:
|
||||
new_content = f.read()
|
||||
|
||||
new_metadata = {
|
||||
"mtime": stat.st_mtime,
|
||||
"size": stat.st_size,
|
||||
"hash": hashlib.md5(new_content).hexdigest(),
|
||||
"first_bytes": new_content[:100] if new_content else b"",
|
||||
}
|
||||
|
||||
# Check if any metadata changed (indicating file was refreshed)
|
||||
metadata_changed = (
|
||||
new_metadata["mtime"] > old_metadata["mtime"] # Newer modification time
|
||||
or new_metadata["hash"] != old_metadata["hash"] # Different content hash
|
||||
or new_metadata["size"] != old_metadata["size"] # Different file size
|
||||
or new_metadata["first_bytes"]
|
||||
!= old_metadata["first_bytes"] # Different content
|
||||
)
|
||||
|
||||
if metadata_changed:
|
||||
files_with_changed_metadata += 1
|
||||
|
||||
# Assert that force refresh actually updated files
|
||||
assert files_with_changed_metadata > 0, (
|
||||
f"Force refresh should have updated at least some files, but all {len(first_creation_metadata)} "
|
||||
f"files appear to have identical metadata. This suggests force refresh didn't work."
|
||||
)
|
||||
|
||||
mock_session.commit.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tutorial_zip_url_accessibility(self):
|
||||
"""Test that the actual tutorial zip URL is accessible (integration test)."""
|
||||
try:
|
||||
import requests
|
||||
|
||||
response = requests.get(
|
||||
"https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/starter_tutorial.zip",
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
# Verify it's a valid zip file by checking headers
|
||||
assert response.headers.get("content-type") in [
|
||||
"application/zip",
|
||||
"application/octet-stream",
|
||||
"application/x-zip-compressed",
|
||||
] or response.content.startswith(b"PK") # Zip file signature
|
||||
|
||||
except Exception:
|
||||
pytest.skip("Network request failed or zip not available - skipping integration test")
|
||||
|
|
@@ -46,6 +46,7 @@ dependencies = [
|
|||
"matplotlib>=3.8.3,<4",
|
||||
"networkx>=3.4.2,<4",
|
||||
"lancedb>=0.24.0,<1.0.0",
|
||||
"nbformat>=5.7.0,<6.0.0",
|
||||
"alembic>=1.13.3,<2",
|
||||
"pre-commit>=4.0.1,<5",
|
||||
"scikit-learn>=1.6.1,<2",
|
||||
|
|
|
|||
2
uv.lock
generated
|
|
@@ -831,6 +831,7 @@ dependencies = [
|
|||
{ name = "limits" },
|
||||
{ name = "litellm" },
|
||||
{ name = "matplotlib" },
|
||||
{ name = "nbformat" },
|
||||
{ name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
||||
{ name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
||||
{ name = "nltk" },
|
||||
|
|
@@ -1012,6 +1013,7 @@ requires-dist = [
|
|||
{ name = "mkdocstrings", extras = ["python"], marker = "extra == 'dev'", specifier = ">=0.26.2,<0.27" },
|
||||
{ name = "modal", marker = "extra == 'distributed'", specifier = ">=1.0.5,<2.0.0" },
|
||||
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.7.1,<2" },
|
||||
{ name = "nbformat", specifier = ">=5.7.0,<6.0.0" },
|
||||
{ name = "neo4j", marker = "extra == 'neo4j'", specifier = ">=5.28.0,<6" },
|
||||
{ name = "networkx", specifier = ">=3.4.2,<4" },
|
||||
{ name = "nltk", specifier = ">=3.9.1,<4.0.0" },
|
||||
|
|
|
|||
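Reviewer note: a minimal usage sketch of the new storage-aware cache helpers added in `cognee/shared/cache.py`. It follows the call pattern used by `Notebook.from_ipynb_zip_url` in this PR; treat it as an illustration under those assumptions rather than documented API.

```python
import asyncio

from cognee.shared.cache import (
    download_and_extract_zip,
    generate_content_hash,
    list_cache_files,
)


async def main():
    url = "https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/starter_tutorial.zip"
    content_hash = generate_content_hash(url, "tutorial.ipynb")

    # Downloads once, then serves from cache (local or S3, depending on STORAGE_BACKEND).
    cache_dir = await download_and_extract_zip(
        url=url,
        cache_dir_name=f"tutorial_data/{content_hash}",
        version_or_hash=content_hash,
    )

    # List the data files extracted alongside the notebook.
    for path in await list_cache_files(f"{cache_dir}/data"):
        print(path)


asyncio.run(main())
```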