cognee/cognee/shared/cache.py
Daulet Amirkhanov f58ba86e7c
feat: add welcome tutorial notebook for new users (#1425)
<!-- .github/pull_request_template.md -->

## Description
<!-- 
Please provide a clear, human-generated description of the changes in
this PR.
DO NOT use AI-generated descriptions. We want to understand your thought
process and reasoning.
-->

Update default tutorial:
1. Use the tutorial from the [notebook_tutorial
branch](https://github.com/topoteretes/cognee/blob/notebook_tutorial/notebooks/tutorial.ipynb) —
specifically, its .zip version with all necessary data files
2. Use Jupyter Notebook `Notebook` abstractions to read, and map `ipynb`
into our Notebook model
3. Dynamically update starter notebook code blocks that reference
starter data files, and swap them with local paths to downloaded copies
4. Test coverage



| Before | After (storage backend = local) | After (s3) |
|--------|---------------------------------|------------|
| <img width="613" height="546" alt="Screenshot 2025-09-17 at 01 00 58"
src="https://github.com/user-attachments/assets/20b59021-96c1-4a83-977f-e064324bd758"
/> | <img width="1480" height="262" alt="Screenshot 2025-09-18 at 13 01
57"
src="https://github.com/user-attachments/assets/bd56ea78-7c6a-42e3-ae3f-4157da231b2d"
/> | <img width="1485" height="307" alt="Screenshot 2025-09-18 at 12 56
08"
src="https://github.com/user-attachments/assets/248ae720-4c78-445a-ba8b-8a2991ed3f80"
/> |



## File Replacements

### S3 Demo  

https://github.com/user-attachments/assets/bd46eec9-ef77-4f69-9ef0-e7d1612ff9b3

---

### Local FS Demo  

https://github.com/user-attachments/assets/8251cea0-81b3-4cac-a968-9576c358f334


## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Changes Made
<!-- List the specific changes made in this PR -->
- 
- 
- 

## Testing
<!-- Describe how you tested your changes -->

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## Related Issues
<!-- Link any related issues using "Fixes #issue_number" or "Relates to
#issue_number" -->

## Additional Notes
<!-- Add any additional notes, concerns, or context for reviewers -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
2025-09-18 18:07:05 +02:00

346 lines
14 KiB
Python

"""
Storage-aware cache management utilities for Cognee.
This module provides cache functionality that works with both local and cloud storage
backends (like S3) through the StorageManager abstraction.
"""
import hashlib
import zipfile
import asyncio
from typing import Optional, Tuple
import aiohttp
import logging
from io import BytesIO
from cognee.base_config import get_base_config
from cognee.infrastructure.files.storage.get_file_storage import get_file_storage
from cognee.infrastructure.files.storage.StorageManager import StorageManager
from cognee.shared.utils import create_secure_ssl_context
logger = logging.getLogger(__name__)
class StorageAwareCache:
    """
    A cache manager that works with different storage backends (local, S3, etc.).

    All paths handed to the underlying ``StorageManager`` are relative to the
    configured ``cache_root_directory``; helper methods translate them back to
    absolute local paths or full ``s3://`` URLs for callers where needed.
    """

    def __init__(self, cache_subdir: str = "cache"):
        """
        Initialize the cache manager.

        Args:
            cache_subdir: Subdirectory name within the system root for caching.
                NOTE(review): currently unused — the storage manager is already
                rooted at ``cache_root_directory``; kept for backward compatibility.
        """
        self.base_config = get_base_config()
        # Since we're using cache_root_directory, don't add extra cache prefix
        self.cache_base_path = ""
        self.storage_manager: StorageManager = get_file_storage(
            self.base_config.cache_root_directory
        )
        logger.info(f"Storage manager absolute path: {self._absolute_storage_root()}")

    def _absolute_storage_root(self) -> str:
        """Return the storage root as an absolute path (s3:// paths are already absolute)."""
        storage_path = self.storage_manager.storage.storage_path
        if storage_path.startswith("s3://"):
            return storage_path
        import os

        return os.path.abspath(storage_path)

    async def get_cache_dir(self) -> str:
        """Get the base cache directory path, creating it if necessary."""
        # Use "." for root when cache_base_path is empty
        cache_path = self.cache_base_path or "."
        await self.storage_manager.ensure_directory_exists(cache_path)
        return cache_path

    async def get_cache_subdir(self, name: str) -> str:
        """
        Get a specific cache subdirectory, creating it if necessary.

        Returns:
            The subdirectory path: relative (bucket key) for S3 backends,
            prefixed with the storage root for local ones.
        """
        if self.cache_base_path:
            cache_path = f"{self.cache_base_path}/{name}"
        else:
            cache_path = name
        await self.storage_manager.ensure_directory_exists(cache_path)
        # Return the path in the form the storage backend expects
        storage_path = self.storage_manager.storage.storage_path
        if storage_path.startswith("s3://"):
            return cache_path
        return f"{storage_path}/{cache_path}"

    async def delete_cache(self):
        """Delete the entire cache directory, re-raising on failure."""
        logger.info("Deleting cache...")
        try:
            await self.storage_manager.remove_all(self.cache_base_path)
            logger.info("✓ Cache deleted successfully!")
        except Exception as e:
            logger.error(f"Error deleting cache: {e}")
            raise

    async def _is_cache_valid(self, cache_dir: str, version_or_hash: str) -> bool:
        """Check if cached content matches the given version/hash marker file."""
        version_file = f"{cache_dir}/version.txt"
        if not await self.storage_manager.file_exists(version_file):
            return False
        try:
            async with self.storage_manager.open(version_file, "r") as f:
                # Storage handles are synchronous; read them off the event loop.
                cached_version = (await asyncio.to_thread(f.read)).strip()
            return cached_version == version_or_hash
        except Exception as e:
            logger.debug(f"Error checking cache validity: {e}")
            return False

    async def _clear_cache(self, cache_dir: str) -> None:
        """Best-effort removal of a cache directory; failures are only logged."""
        try:
            await self.storage_manager.remove_all(cache_dir)
        except Exception as e:
            logger.debug(f"Error clearing cache directory {cache_dir}: {e}")

    async def _check_remote_content_freshness(
        self, url: str, cache_dir: str
    ) -> Tuple[bool, Optional[str]]:
        """
        Check if remote content is fresher than cached version using HTTP headers.

        Issues a HEAD request and compares the remote ETag (preferred) or
        Last-Modified header against the identifier stored alongside the cache.

        Returns:
            Tuple of (is_fresh: bool, new_identifier: Optional[str]); the
            identifier is None when the cache is fresh or freshness is unknown.
        """
        try:
            # Make a HEAD request to check headers without downloading
            ssl_context = create_secure_ssl_context()
            connector = aiohttp.TCPConnector(ssl=ssl_context)
            async with aiohttp.ClientSession(connector=connector) as session:
                async with session.head(url, timeout=aiohttp.ClientTimeout(total=30)) as response:
                    response.raise_for_status()
                    # Try ETag first (most reliable), otherwise Last-Modified
                    etag = response.headers.get("ETag", "").strip('"')
                    last_modified = response.headers.get("Last-Modified", "")
                    remote_identifier = etag if etag else last_modified
                    if not remote_identifier:
                        logger.debug("No freshness headers available, cannot check for updates")
                        return True, None  # Assume fresh if no headers

                    identifier_file = f"{cache_dir}/content_id.txt"
                    if not await self.storage_manager.file_exists(identifier_file):
                        # No cached identifier, treat as stale
                        return False, remote_identifier

                    async with self.storage_manager.open(identifier_file, "r") as f:
                        cached_identifier = (await asyncio.to_thread(f.read)).strip()
                    if cached_identifier == remote_identifier:
                        logger.debug(f"Content is fresh (identifier: {remote_identifier[:20]}...)")
                        return True, None
                    logger.info(
                        f"Content has changed (old: {cached_identifier[:20]}..., new: {remote_identifier[:20]}...)"
                    )
                    return False, remote_identifier
        except Exception as e:
            logger.debug(f"Could not check remote freshness: {e}")
            return True, None  # Assume fresh if we can't check

    async def download_and_extract_zip(
        self, url: str, cache_subdir_name: str, version_or_hash: str, force: bool = False
    ) -> str:
        """
        Download a zip file and extract it to cache directory with content freshness checking.

        Args:
            url: URL to download zip file from
            cache_subdir_name: Name of the cache subdirectory
            version_or_hash: Version string or content hash for cache validation
            force: If True, re-download even if already cached

        Returns:
            Path to the cached directory
        """
        cache_dir = await self.get_cache_subdir(cache_subdir_name)

        # Check if already cached and valid
        if not force and await self._is_cache_valid(cache_dir, version_or_hash):
            # Also check if remote content has changed
            is_fresh, _new_identifier = await self._check_remote_content_freshness(url, cache_dir)
            if is_fresh:
                logger.debug(f"Content already cached and fresh for version {version_or_hash}")
                return cache_dir
            logger.info("Cached content is stale, updating...")

        # Clear old cache if it exists
        await self._clear_cache(cache_dir)

        logger.info(f"Downloading content from {url}...")
        zip_content = BytesIO()
        etag = ""
        last_modified = ""
        ssl_context = create_secure_ssl_context()
        connector = aiohttp.TCPConnector(ssl=ssl_context)
        async with aiohttp.ClientSession(connector=connector) as session:
            async with session.get(url, timeout=aiohttp.ClientTimeout(total=60)) as response:
                response.raise_for_status()
                # Extract headers before consuming response
                etag = response.headers.get("ETag", "").strip('"')
                last_modified = response.headers.get("Last-Modified", "")
                # Stream the body into memory in 8 KiB chunks
                async for chunk in response.content.iter_chunked(8192):
                    zip_content.write(chunk)
        zip_content.seek(0)

        await self.storage_manager.ensure_directory_exists(cache_dir)

        # Extract files and store them using StorageManager
        with zipfile.ZipFile(zip_content, "r") as zip_file:
            for file_info in zip_file.infolist():
                member_name = file_info.filename
                # Guard against "zip slip": skip entries with absolute paths or
                # parent-directory components that would escape cache_dir.
                parts = member_name.replace("\\", "/").split("/")
                if member_name.startswith(("/", "\\")) or ".." in parts:
                    logger.warning(f"Skipping unsafe zip entry: {member_name}")
                    continue
                if file_info.is_dir():
                    await self.storage_manager.ensure_directory_exists(
                        f"{cache_dir}/{member_name}"
                    )
                else:
                    file_data = zip_file.read(member_name)
                    await self.storage_manager.store(
                        f"{cache_dir}/{member_name}", BytesIO(file_data), overwrite=True
                    )

        # Write version info for future cache validation
        version_file = f"{cache_dir}/version.txt"
        await self.storage_manager.store(version_file, version_or_hash, overwrite=True)

        # Store content identifier from response headers for freshness checking
        content_identifier = etag if etag else last_modified
        if content_identifier:
            identifier_file = f"{cache_dir}/content_id.txt"
            await self.storage_manager.store(identifier_file, content_identifier, overwrite=True)
            logger.debug(f"Stored content identifier: {content_identifier[:20]}...")

        logger.info("✓ Content downloaded and cached successfully!")
        return cache_dir

    async def file_exists(self, file_path: str) -> bool:
        """Check if a file exists in cache storage."""
        return await self.storage_manager.file_exists(file_path)

    async def read_file(self, file_path: str, encoding: str = "utf-8"):
        """Read a file from cache storage (returns the backend's open handle/context)."""
        return self.storage_manager.open(file_path, encoding=encoding)

    async def list_files(self, directory_path: str):
        """
        List files in a cache directory.

        Returns full s3:// URLs for S3 backends, absolute paths for local
        storage, and an empty list on error.
        """
        try:
            file_list = await self.storage_manager.list_files(directory_path)
            storage_path = self.storage_manager.storage.storage_path
            if storage_path.startswith("s3://"):
                # For S3 storage, convert relative paths to full S3 URLs
                return [f"{storage_path}/{file_path}" for file_path in file_list]
            # For local storage, return absolute paths
            if not storage_path.startswith("/"):
                import os

                storage_path = os.path.abspath(storage_path)
            return [
                file_path if file_path.startswith("/") else f"{storage_path}/{file_path}"
                for file_path in file_list
            ]
        except Exception as e:
            logger.debug(f"Error listing files in {directory_path}: {e}")
            return []
# Convenience functions that maintain API compatibility
_cache_manager = None


def get_cache_manager() -> StorageAwareCache:
    """Return the process-wide cache manager, creating it lazily on first call."""
    global _cache_manager
    manager = _cache_manager
    if manager is None:
        manager = StorageAwareCache()
        _cache_manager = manager
    return manager
def generate_content_hash(url: str, additional_data: str = "") -> str:
    """
    Generate a short, stable content hash from a URL plus optional extra data.

    MD5 is used purely as a fast fingerprint (not for security); the digest is
    truncated to 12 hex characters for readability.
    """
    fingerprint_source = f"{url}:{additional_data}".encode()
    return hashlib.md5(fingerprint_source).hexdigest()[:12]
# Async wrapper functions for backward compatibility
async def delete_cache():
    """Delete the Cognee cache directory."""
    await get_cache_manager().delete_cache()
async def get_cognee_cache_dir() -> str:
    """Get the base Cognee cache directory."""
    return await get_cache_manager().get_cache_dir()
async def get_cache_subdir(name: str) -> str:
    """Get a specific cache subdirectory."""
    return await get_cache_manager().get_cache_subdir(name)
async def download_and_extract_zip(
    url: str, cache_dir_name: str, version_or_hash: str, force: bool = False
) -> str:
    """Download a zip file and extract it to cache directory."""
    manager = get_cache_manager()
    return await manager.download_and_extract_zip(url, cache_dir_name, version_or_hash, force)
async def get_tutorial_data_dir() -> str:
    """Get the tutorial data cache directory."""
    tutorial_dir = await get_cache_subdir("tutorial_data")
    return tutorial_dir
# Cache file operations
async def cache_file_exists(file_path: str) -> bool:
    """Check if a file exists in cache storage."""
    return await get_cache_manager().file_exists(file_path)
async def read_cache_file(file_path: str, encoding: str = "utf-8"):
    """Read a file from cache storage."""
    return await get_cache_manager().read_file(file_path, encoding)
async def list_cache_files(directory_path: str):
    """List files in a cache directory."""
    return await get_cache_manager().list_files(directory_path)