feat: add welcome tutorial notebook for new users (#1425)
<!-- .github/pull_request_template.md -->

## Description
<!-- Please provide a clear, human-generated description of the changes in this PR. DO NOT use AI-generated descriptions. We want to understand your thought process and reasoning. -->

Update the default tutorial:

1. Use the tutorial from the [notebook_tutorial branch](https://github.com/topoteretes/cognee/blob/notebook_tutorial/notebooks/tutorial.ipynb), specifically its .zip version, which bundles all required data files.
2. Use nbformat's Jupyter notebook abstractions to read `.ipynb` files and map them into our `Notebook` model.
3. Dynamically rewrite starter-notebook code cells that reference the starter data files, swapping them for local paths to the downloaded copies.
4. Add test coverage.

| Before | After (storage backend = local) | After (s3) |
|--------|---------------------------------|------------|
| <img width="613" height="546" alt="Screenshot 2025-09-17 at 01 00 58" src="https://github.com/user-attachments/assets/20b59021-96c1-4a83-977f-e064324bd758" /> | <img width="1480" height="262" alt="Screenshot 2025-09-18 at 13 01 57" src="https://github.com/user-attachments/assets/bd56ea78-7c6a-42e3-ae3f-4157da231b2d" /> | <img width="1485" height="307" alt="Screenshot 2025-09-18 at 12 56 08" src="https://github.com/user-attachments/assets/248ae720-4c78-445a-ba8b-8a2991ed3f80" /> |

## File Replacements

### S3 Demo

https://github.com/user-attachments/assets/bd46eec9-ef77-4f69-9ef0-e7d1612ff9b3

---

### Local FS Demo

https://github.com/user-attachments/assets/8251cea0-81b3-4cac-a968-9576c358f334

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Changes Made
<!-- List the specific changes made in this PR -->
-
-
-

## Testing
<!-- Describe how you tested your changes -->

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the issue/feature**
- [ ] My code follows the project's coding standards and style guidelines
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## Related Issues
<!-- Link any related issues using "Fixes #issue_number" or "Relates to #issue_number" -->

## Additional Notes
<!-- Add any additional notes, concerns, or context for reviewers -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
Parent: bb124494c1
Commit: f58ba86e7c
20 changed files with 1,200 additions and 44 deletions
@@ -47,6 +47,28 @@ BAML_LLM_API_VERSION=""
 # DATA_ROOT_DIRECTORY='/Users/<user>/Desktop/cognee/.cognee_data/'
 # SYSTEM_ROOT_DIRECTORY='/Users/<user>/Desktop/cognee/.cognee_system/'
+
+################################################################################
+# ☁️ Storage Backend Settings
+################################################################################
+# Configure storage backend (local filesystem or S3)
+# STORAGE_BACKEND="local"  # Default: uses local filesystem
+#
+# -- To switch to S3 storage, uncomment and fill these: ---------------------
+# STORAGE_BACKEND="s3"
+# STORAGE_BUCKET_NAME="your-bucket-name"
+# AWS_REGION="us-east-1"
+# AWS_ACCESS_KEY_ID="your-access-key"
+# AWS_SECRET_ACCESS_KEY="your-secret-key"
+#
+# -- S3 Root Directories (optional) -----------------------------------------
+# DATA_ROOT_DIRECTORY="s3://your-bucket/cognee/data"
+# SYSTEM_ROOT_DIRECTORY="s3://your-bucket/cognee/system"
+#
+# -- Cache Directory (auto-configured for S3) -------------------------------
+# When STORAGE_BACKEND=s3, cache automatically uses S3: s3://BUCKET/cognee/cache
+# To override the automatic S3 cache location, uncomment:
+# CACHE_ROOT_DIRECTORY="s3://your-bucket/cognee/cache"
+
 ################################################################################
 # 🗄️ Relational database settings
 ################################################################################
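For reviewers who want to try the S3 path locally, here is a minimal, hedged sketch of setting the same variables from Python instead of the template file. It assumes the variables are set before cognee loads its configuration; the bucket name, region, and credentials are placeholders.

```python
# Sketch only: mirror the .env.template keys above from Python.
import os

os.environ["STORAGE_BACKEND"] = "s3"
os.environ["STORAGE_BUCKET_NAME"] = "your-bucket-name"  # placeholder
os.environ["AWS_REGION"] = "us-east-1"
os.environ["AWS_ACCESS_KEY_ID"] = "your-access-key"  # placeholder
os.environ["AWS_SECRET_ACCESS_KEY"] = "your-secret-key"  # placeholder

# With these set, BaseConfig auto-derives the cache location as
# s3://your-bucket-name/cognee/cache unless CACHE_ROOT_DIRECTORY overrides it.
```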
.gitignore (vendored): 1 addition

@@ -186,6 +186,7 @@ cognee/cache/
 # Default cognee system directory, used in development
 .cognee_system/
 .data_storage/
+.cognee_cache/
 .artifacts/
 .anon_id
@@ -7,8 +7,8 @@ class prune:
         await _prune_data()

     @staticmethod
-    async def prune_system(graph=True, vector=True, metadata=False):
-        await _prune_system(graph, vector, metadata)
+    async def prune_system(graph=True, vector=True, metadata=False, cache=True):
+        await _prune_system(graph, vector, metadata, cache)


 if __name__ == "__main__":
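A usage sketch of the extended prune API, for anyone resetting local state while testing the tutorial. It assumes the public `cognee.prune` entry point mirrors the class above (including a `prune_data` static method); treat it as illustrative rather than canonical.

```python
# Hedged sketch: wipe local state, including the new cache, before re-running.
import asyncio

import cognee


async def reset_everything():
    await cognee.prune.prune_data()  # assumption: public wrapper for _prune_data()
    # cache=True (the new default) also clears .cognee_cache/ or the S3 cache prefix
    await cognee.prune.prune_system(graph=True, vector=True, metadata=False, cache=True)


if __name__ == "__main__":
    asyncio.run(reset_everything())
```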
@@ -23,6 +23,7 @@ from cognee.modules.sync.methods import (
     mark_sync_completed,
     mark_sync_failed,
 )
+from cognee.shared.utils import create_secure_ssl_context

 logger = get_logger("sync")

@@ -583,7 +584,9 @@ async def _check_hashes_diff(
     logger.info(f"Checking missing hashes on cloud for dataset {dataset.id}")

     try:
-        async with aiohttp.ClientSession() as session:
+        ssl_context = create_secure_ssl_context()
+        connector = aiohttp.TCPConnector(ssl=ssl_context)
+        async with aiohttp.ClientSession(connector=connector) as session:
             async with session.post(url, json=payload.dict(), headers=headers) as response:
                 if response.status == 200:
                     data = await response.json()

@@ -630,7 +633,9 @@ async def _download_missing_files(
     headers = {"X-Api-Key": auth_token}

-    async with aiohttp.ClientSession() as session:
+    ssl_context = create_secure_ssl_context()
+    connector = aiohttp.TCPConnector(ssl=ssl_context)
+    async with aiohttp.ClientSession(connector=connector) as session:
         for file_hash in hashes_missing_on_local:
             try:
                 # Download file from cloud by hash

@@ -749,7 +754,9 @@ async def _upload_missing_files(
     headers = {"X-Api-Key": auth_token}

-    async with aiohttp.ClientSession() as session:
+    ssl_context = create_secure_ssl_context()
+    connector = aiohttp.TCPConnector(ssl=ssl_context)
+    async with aiohttp.ClientSession(connector=connector) as session:
         for file_info in files_to_upload:
             try:
                 file_dir = os.path.dirname(file_info.raw_data_location)

@@ -809,7 +816,9 @@ async def _prune_cloud_dataset(
     logger.info("Pruning cloud dataset to match local state")

     try:
-        async with aiohttp.ClientSession() as session:
+        ssl_context = create_secure_ssl_context()
+        connector = aiohttp.TCPConnector(ssl=ssl_context)
+        async with aiohttp.ClientSession(connector=connector) as session:
             async with session.put(url, json=payload.dict(), headers=headers) as response:
                 if response.status == 200:
                     data = await response.json()

@@ -852,7 +861,9 @@ async def _trigger_remote_cognify(
     logger.info(f"Triggering cognify processing for dataset {dataset_id}")

     try:
-        async with aiohttp.ClientSession() as session:
+        ssl_context = create_secure_ssl_context()
+        connector = aiohttp.TCPConnector(ssl=ssl_context)
+        async with aiohttp.ClientSession(connector=connector) as session:
             async with session.post(url, json=payload, headers=headers) as response:
                 if response.status == 200:
                     data = await response.json()
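The same session pattern is repeated across every hunk in this file. Here is a standalone, runnable sketch of that pattern in isolation; the URL is a placeholder and the local `create_secure_ssl_context` copy simply mirrors the helper added to `cognee.shared.utils`.

```python
# Standalone sketch of the aiohttp + secure SSL context pattern used above.
import asyncio
import ssl

import aiohttp


def create_secure_ssl_context() -> ssl.SSLContext:
    # Same behaviour as cognee.shared.utils.create_secure_ssl_context:
    # validate against the system certificate store.
    return ssl.create_default_context()


async def fetch_status(url: str = "https://example.com/health") -> int:
    ssl_context = create_secure_ssl_context()
    connector = aiohttp.TCPConnector(ssl=ssl_context)
    async with aiohttp.ClientSession(connector=connector) as session:
        async with session.get(url) as response:
            return response.status


if __name__ == "__main__":
    print(asyncio.run(fetch_status()))
```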
@@ -10,13 +10,27 @@ import pydantic
 class BaseConfig(BaseSettings):
     data_root_directory: str = get_absolute_path(".data_storage")
     system_root_directory: str = get_absolute_path(".cognee_system")
+    cache_root_directory: str = get_absolute_path(".cognee_cache")
     monitoring_tool: object = Observer.LANGFUSE

     @pydantic.model_validator(mode="after")
     def validate_paths(self):
+        # Adding this here temporarily to ensure that the cache root directory is set correctly for S3 storage automatically
+        # I'll remove this after we update documentation for S3 storage
+        # Auto-configure cache root directory for S3 storage if not explicitly set
+        storage_backend = os.getenv("STORAGE_BACKEND", "").lower()
+        cache_root_env = os.getenv("CACHE_ROOT_DIRECTORY")
+
+        if storage_backend == "s3" and not cache_root_env:
+            # Auto-generate S3 cache path when using S3 storage
+            bucket_name = os.getenv("STORAGE_BUCKET_NAME")
+            if bucket_name:
+                self.cache_root_directory = f"s3://{bucket_name}/cognee/cache"
+
         # Require absolute paths for root directories
         self.data_root_directory = ensure_absolute_path(self.data_root_directory)
         self.system_root_directory = ensure_absolute_path(self.system_root_directory)
+        self.cache_root_directory = ensure_absolute_path(self.cache_root_directory)
         return self

     langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY")

@@ -31,6 +45,7 @@ class BaseConfig(BaseSettings):
         "data_root_directory": self.data_root_directory,
         "system_root_directory": self.system_root_directory,
         "monitoring_tool": self.monitoring_tool,
+        "cache_root_directory": self.cache_root_directory,
     }
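To make the cache auto-configuration rule easier to review, here is a self-contained sketch of the same decision logic outside pydantic. It is an illustration of the rule, not the library's exact code path; the default local directory name is taken from the config above.

```python
# Illustration of the cache-root resolution rule added to BaseConfig.validate_paths.
import os


def resolve_cache_root(default_local: str = ".cognee_cache") -> str:
    storage_backend = os.getenv("STORAGE_BACKEND", "").lower()
    cache_root_env = os.getenv("CACHE_ROOT_DIRECTORY")

    if cache_root_env:
        return cache_root_env  # explicit override always wins

    if storage_backend == "s3":
        bucket_name = os.getenv("STORAGE_BUCKET_NAME")
        if bucket_name:
            # Derive the cache prefix from the data bucket
            return f"s3://{bucket_name}/cognee/cache"

    # Local backend: fall back to an absolute local cache directory
    return os.path.abspath(default_local)


if __name__ == "__main__":
    print(resolve_cache_root())
```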
@@ -7,6 +7,7 @@ import aiohttp
 from uuid import UUID

 from cognee.infrastructure.databases.graph.kuzu.adapter import KuzuAdapter
+from cognee.shared.utils import create_secure_ssl_context

 logger = get_logger()

@@ -42,7 +43,9 @@ class RemoteKuzuAdapter(KuzuAdapter):
     async def _get_session(self) -> aiohttp.ClientSession:
         """Get or create an aiohttp session."""
         if self._session is None or self._session.closed:
-            self._session = aiohttp.ClientSession()
+            ssl_context = create_secure_ssl_context()
+            connector = aiohttp.TCPConnector(ssl=ssl_context)
+            self._session = aiohttp.ClientSession(connector=connector)
         return self._session

     async def close(self):
@@ -14,6 +14,7 @@ from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import (
     embedding_rate_limit_async,
     embedding_sleep_and_retry_async,
 )
+from cognee.shared.utils import create_secure_ssl_context

 logger = get_logger("OllamaEmbeddingEngine")

@@ -101,7 +102,9 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
         if api_key:
             headers["Authorization"] = f"Bearer {api_key}"

-        async with aiohttp.ClientSession() as session:
+        ssl_context = create_secure_ssl_context()
+        connector = aiohttp.TCPConnector(ssl=ssl_context)
+        async with aiohttp.ClientSession(connector=connector) as session:
             async with session.post(
                 self.endpoint, json=payload, headers=headers, timeout=60.0
             ) as response:
@@ -253,6 +253,56 @@ class LocalFileStorage(Storage):
         if os.path.exists(full_file_path):
             os.remove(full_file_path)

+    def list_files(self, directory_path: str, recursive: bool = False) -> list[str]:
+        """
+        List all files in the specified directory.
+
+        Parameters:
+        -----------
+        - directory_path (str): The directory path to list files from
+        - recursive (bool): If True, list files recursively in subdirectories
+
+        Returns:
+        --------
+        - list[str]: List of file paths relative to the storage root
+        """
+        from pathlib import Path
+
+        parsed_storage_path = get_parsed_path(self.storage_path)
+
+        if directory_path:
+            full_directory_path = os.path.join(parsed_storage_path, directory_path)
+        else:
+            full_directory_path = parsed_storage_path
+
+        directory_pathlib = Path(full_directory_path)
+
+        if not directory_pathlib.exists() or not directory_pathlib.is_dir():
+            return []
+
+        files = []
+
+        if recursive:
+            # Use rglob for recursive search
+            for file_path in directory_pathlib.rglob("*"):
+                if file_path.is_file():
+                    # Get relative path from storage root
+                    relative_path = os.path.relpath(str(file_path), parsed_storage_path)
+                    # Normalize path separators for consistency
+                    relative_path = relative_path.replace(os.sep, "/")
+                    files.append(relative_path)
+        else:
+            # Use iterdir for just immediate directory
+            for file_path in directory_pathlib.iterdir():
+                if file_path.is_file():
+                    # Get relative path from storage root
+                    relative_path = os.path.relpath(str(file_path), parsed_storage_path)
+                    # Normalize path separators for consistency
+                    relative_path = relative_path.replace(os.sep, "/")
+                    files.append(relative_path)
+
+        return files
+
     def remove_all(self, tree_path: str = None):
         """
         Remove an entire directory tree at the specified path, including all files and
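A small, runnable demo of the relative-path listing behaviour the new `LocalFileStorage.list_files` aims for, using a temporary directory; the file names are invented for the demo.

```python
# Demo of listing files relative to a storage root, normalized to "/" separators.
import os
import tempfile
from pathlib import Path


def list_relative_files(root: str, recursive: bool = True) -> list[str]:
    base = Path(root)
    iterator = base.rglob("*") if recursive else base.iterdir()
    return sorted(
        os.path.relpath(str(p), root).replace(os.sep, "/") for p in iterator if p.is_file()
    )


with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "data").mkdir()
    (Path(tmp) / "data" / "example.txt").write_text("hello")
    (Path(tmp) / "tutorial.ipynb").write_text("{}")
    print(list_relative_files(tmp))  # ['data/example.txt', 'tutorial.ipynb']
```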
@@ -155,21 +155,19 @@ class S3FileStorage(Storage):
         """
         Ensure that the specified directory exists, creating it if necessary.

-        If the directory already exists, no action is taken.
+        For S3 storage, this is a no-op since directories are created implicitly
+        when files are written to paths. S3 doesn't have actual directories,
+        just object keys with prefixes that appear as directories.

         Parameters:
         -----------

         - directory_path (str): The path of the directory to check or create.
         """
-        if not directory_path.strip():
-            directory_path = self.storage_path.replace("s3://", "")
-
-        def ensure_directory():
-            if not self.s3.exists(directory_path):
-                self.s3.makedirs(directory_path, exist_ok=True)
-
-        await run_async(ensure_directory)
+        # In S3, directories don't exist as separate entities - they're just prefixes
+        # When you write a file to s3://bucket/path/to/file.txt, the "directories"
+        # path/ and path/to/ are implicitly created. No explicit action needed.
+        pass

     async def copy_file(self, source_file_path: str, destination_file_path: str):
         """

@@ -213,6 +211,55 @@ class S3FileStorage(Storage):
         await run_async(remove_file)

+    async def list_files(self, directory_path: str, recursive: bool = False) -> list[str]:
+        """
+        List all files in the specified directory.
+
+        Parameters:
+        -----------
+        - directory_path (str): The directory path to list files from
+        - recursive (bool): If True, list files recursively in subdirectories
+
+        Returns:
+        --------
+        - list[str]: List of file paths relative to the storage root
+        """
+
+        def list_files_sync():
+            if directory_path:
+                # Combine storage path with directory path
+                full_path = os.path.join(self.storage_path.replace("s3://", ""), directory_path)
+            else:
+                full_path = self.storage_path.replace("s3://", "")
+
+            if recursive:
+                # Use ** for recursive search
+                pattern = f"{full_path}/**"
+            else:
+                # Just files in the immediate directory
+                pattern = f"{full_path}/*"
+
+            # Use s3fs glob to find files
+            try:
+                all_paths = self.s3.glob(pattern)
+                # Filter to only files (not directories)
+                files = [path for path in all_paths if self.s3.isfile(path)]
+
+                # Convert back to relative paths from storage root
+                storage_prefix = self.storage_path.replace("s3://", "")
+                relative_files = []
+                for file_path in files:
+                    if file_path.startswith(storage_prefix):
+                        relative_path = file_path[len(storage_prefix) :].lstrip("/")
+                        relative_files.append(relative_path)
+
+                return relative_files
+            except Exception:
+                # If directory doesn't exist or other error, return empty list
+                return []
+
+        return await run_async(list_files_sync)
+
     async def remove_all(self, tree_path: str):
         """
         Remove an entire directory tree at the specified path, including all files and
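For context on the S3 side, here is a hedged sketch of the s3fs calls the new `list_files` relies on, stripped of the storage-manager plumbing. It requires the `s3fs` package and valid AWS credentials; the bucket and prefix names are placeholders.

```python
# Hedged sketch of listing S3 objects with s3fs glob/isfile, as used above.
import s3fs


def list_s3_files(bucket: str, prefix: str, recursive: bool = False) -> list[str]:
    s3 = s3fs.S3FileSystem()
    pattern = f"{bucket}/{prefix}/**" if recursive else f"{bucket}/{prefix}/*"
    # glob returns keys without the s3:// scheme; isfile filters out "directory" prefixes
    return [path for path in s3.glob(pattern) if s3.isfile(path)]


# Example (assumption: the bucket exists and is readable with your credentials):
# print(list_s3_files("your-bucket-name", "cognee/cache/tutorial_data"))
```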
@@ -135,6 +135,24 @@ class StorageManager:
         else:
             return self.storage.remove(file_path)

+    async def list_files(self, directory_path: str, recursive: bool = False) -> list[str]:
+        """
+        List all files in the specified directory.
+
+        Parameters:
+        -----------
+        - directory_path (str): The directory path to list files from
+        - recursive (bool): If True, list files recursively in subdirectories
+
+        Returns:
+        --------
+        - list[str]: List of file paths relative to the storage root
+        """
+        if inspect.iscoroutinefunction(self.storage.list_files):
+            return await self.storage.list_files(directory_path, recursive)
+        else:
+            return self.storage.list_files(directory_path, recursive)
+
     async def remove_all(self, tree_path: str = None):
         """
         Remove an entire directory tree at the specified path, including all files and
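The dispatch trick above lets callers always `await manager.list_files(...)` whether the backend method is synchronous (local filesystem) or asynchronous (S3). A minimal standalone sketch with toy backends:

```python
# Sketch of sync/async backend dispatch via inspect.iscoroutinefunction.
import asyncio
import inspect


class SyncBackend:
    def list_files(self, directory_path, recursive=False):
        return ["a.txt"]


class AsyncBackend:
    async def list_files(self, directory_path, recursive=False):
        return ["b.txt"]


async def list_files(storage, directory_path, recursive=False):
    # Bound async methods report True here; plain methods report False.
    if inspect.iscoroutinefunction(storage.list_files):
        return await storage.list_files(directory_path, recursive)
    return storage.list_files(directory_path, recursive)


async def main():
    print(await list_files(SyncBackend(), "data"))   # ['a.txt']
    print(await list_files(AsyncBackend(), "data"))  # ['b.txt']


asyncio.run(main())
```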
@@ -1,6 +1,7 @@
 import aiohttp

 from cognee.modules.cloud.exceptions import CloudConnectionError
+from cognee.shared.utils import create_secure_ssl_context


 async def check_api_key(auth_token: str):

@@ -10,7 +11,9 @@ async def check_api_key(auth_token: str):
     headers = {"X-Api-Key": auth_token}

     try:
-        async with aiohttp.ClientSession() as session:
+        ssl_context = create_secure_ssl_context()
+        connector = aiohttp.TCPConnector(ssl=ssl_context)
+        async with aiohttp.ClientSession(connector=connector) as session:
             async with session.post(url, headers=headers) as response:
                 if response.status == 200:
                     return
@@ -1,9 +1,10 @@
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
 from cognee.infrastructure.databases.relational import get_relational_engine
+from cognee.shared.cache import delete_cache


-async def prune_system(graph=True, vector=True, metadata=True):
+async def prune_system(graph=True, vector=True, metadata=True, cache=True):
     if graph:
         graph_engine = await get_graph_engine()
         await graph_engine.delete_graph()

@@ -15,3 +16,6 @@ async def prune_system(graph=True, vector=True, metadata=True):
     if metadata:
         db_engine = get_relational_engine()
         await db_engine.delete_database()
+
+    if cache:
+        await delete_cache()
@@ -7,6 +7,38 @@ from cognee.infrastructure.databases.relational import with_async_session
 from ..models.Notebook import Notebook, NotebookCell


+async def _create_tutorial_notebook(
+    user_id: UUID, session: AsyncSession, force_refresh: bool = False
+) -> None:
+    """
+    Create the default tutorial notebook for new users.
+    Dynamically fetches from: https://github.com/topoteretes/cognee/blob/notebook_tutorial/notebooks/starter_tutorial.zip
+    """
+    TUTORIAL_ZIP_URL = (
+        "https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/starter_tutorial.zip"
+    )
+
+    try:
+        # Create notebook from remote zip file (includes notebook + data files)
+        notebook = await Notebook.from_ipynb_zip_url(
+            zip_url=TUTORIAL_ZIP_URL,
+            owner_id=user_id,
+            notebook_filename="tutorial.ipynb",
+            name="Python Development with Cognee Tutorial 🧠",
+            deletable=False,
+            force=force_refresh,
+        )
+
+        # Add to session and commit
+        session.add(notebook)
+        await session.commit()
+
+    except Exception as e:
+        print(f"Failed to fetch tutorial notebook from {TUTORIAL_ZIP_URL}: {e}")
+
+        raise e
+
+
 @with_async_session
 async def create_notebook(
     user_id: UUID,
@@ -1,13 +1,24 @@
 import json
-from typing import List, Literal
+import nbformat
+import asyncio
+from nbformat.notebooknode import NotebookNode
+from typing import List, Literal, Optional, cast, Tuple
 from uuid import uuid4, UUID as UUID_t
 from pydantic import BaseModel, ConfigDict
 from datetime import datetime, timezone
 from fastapi.encoders import jsonable_encoder
 from sqlalchemy import Boolean, Column, DateTime, JSON, UUID, String, TypeDecorator
 from sqlalchemy.orm import mapped_column, Mapped
+from pathlib import Path

 from cognee.infrastructure.databases.relational import Base
+from cognee.shared.cache import (
+    download_and_extract_zip,
+    get_tutorial_data_dir,
+    generate_content_hash,
+)
+from cognee.infrastructure.files.storage.get_file_storage import get_file_storage
+from cognee.base_config import get_base_config


 class NotebookCell(BaseModel):

@@ -51,3 +62,197 @@ class Notebook(Base):
     deletable: Mapped[bool] = mapped_column(Boolean, default=True)

     created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))

+    @classmethod
+    async def from_ipynb_zip_url(
+        cls,
+        zip_url: str,
+        owner_id: UUID_t,
+        notebook_filename: str = "tutorial.ipynb",
+        name: Optional[str] = None,
+        deletable: bool = True,
+        force: bool = False,
+    ) -> "Notebook":
+        """
+        Create a Notebook instance from a remote zip file containing notebook + data files.
+
+        Args:
+            zip_url: Remote URL to fetch the .zip file from
+            owner_id: UUID of the notebook owner
+            notebook_filename: Name of the .ipynb file within the zip
+            name: Optional custom name for the notebook
+            deletable: Whether the notebook can be deleted
+            force: If True, re-download even if already cached
+
+        Returns:
+            Notebook instance
+        """
+        # Generate a cache key based on the zip URL
+        content_hash = generate_content_hash(zip_url, notebook_filename)
+
+        # Download and extract the zip file to tutorial_data/{content_hash}
+        try:
+            extracted_cache_dir = await download_and_extract_zip(
+                url=zip_url,
+                cache_dir_name=f"tutorial_data/{content_hash}",
+                version_or_hash=content_hash,
+                force=force,
+            )
+        except Exception as e:
+            raise RuntimeError(f"Failed to download tutorial zip from {zip_url}") from e
+
+        # Use cache system to access the notebook file
+        from cognee.shared.cache import cache_file_exists, read_cache_file
+
+        notebook_file_path = f"{extracted_cache_dir}/{notebook_filename}"
+
+        # Check if the notebook file exists in cache
+        if not await cache_file_exists(notebook_file_path):
+            raise FileNotFoundError(f"Notebook file '{notebook_filename}' not found in zip")
+
+        # Read and parse the notebook using cache system
+        async with await read_cache_file(notebook_file_path, encoding="utf-8") as f:
+            notebook_content = await asyncio.to_thread(f.read)
+        notebook = cls.from_ipynb_string(notebook_content, owner_id, name, deletable)
+
+        # Update file paths in notebook cells to point to actual cached data files
+        await cls._update_file_paths_in_cells(notebook, extracted_cache_dir)
+
+        return notebook
+
+    @staticmethod
+    async def _update_file_paths_in_cells(notebook: "Notebook", cache_dir: str) -> None:
+        """
+        Update file paths in code cells to use actual cached data files.
+        Works with both local filesystem and S3 storage.
+
+        Args:
+            notebook: Parsed Notebook instance with cells to update
+            cache_dir: Path to the cached tutorial directory containing data files
+        """
+        import re
+        from cognee.shared.cache import list_cache_files, cache_file_exists
+        from cognee.shared.logging_utils import get_logger
+
+        logger = get_logger()
+
+        # Look for data files in the data subdirectory
+        data_dir = f"{cache_dir}/data"
+
+        try:
+            # Get all data files in the cache directory using cache system
+            data_files = {}
+            if await cache_file_exists(data_dir):
+                file_list = await list_cache_files(data_dir)
+            else:
+                file_list = []
+
+            for file_path in file_list:
+                # Extract just the filename
+                filename = file_path.split("/")[-1]
+                # Use the file path as provided by cache system
+                data_files[filename] = file_path
+
+        except Exception as e:
+            # If we can't list files, skip updating paths
+            logger.error(f"Error listing data files in {data_dir}: {e}")
+            return
+
+        # Pattern to match file://data/filename patterns in code cells
+        file_pattern = r'"file://data/([^"]+)"'
+
+        def replace_path(match):
+            filename = match.group(1)
+            if filename in data_files:
+                file_path = data_files[filename]
+                # For local filesystem, preserve file:// prefix
+                if not file_path.startswith("s3://"):
+                    return f'"file://{file_path}"'
+                else:
+                    # For S3, return the S3 URL as-is
+                    return f'"{file_path}"'
+            return match.group(0)  # Keep original if file not found
+
+        # Update only code cells
+        updated_cells = 0
+        for cell in notebook.cells:
+            if cell.type == "code":
+                original_content = cell.content
+                # Update file paths in the cell content
+                cell.content = re.sub(file_pattern, replace_path, cell.content)
+                if original_content != cell.content:
+                    updated_cells += 1
+
+        # Log summary of updates (useful for monitoring)
+        if updated_cells > 0:
+            logger.info(f"Updated file paths in {updated_cells} notebook cells")
+
+    @classmethod
+    def from_ipynb_string(
+        cls,
+        notebook_content: str,
+        owner_id: UUID_t,
+        name: Optional[str] = None,
+        deletable: bool = True,
+    ) -> "Notebook":
+        """
+        Create a Notebook instance from Jupyter notebook string content.
+
+        Args:
+            notebook_content: Raw Jupyter notebook content as string
+            owner_id: UUID of the notebook owner
+            name: Optional custom name for the notebook
+            deletable: Whether the notebook can be deleted
+
+        Returns:
+            Notebook instance ready to be saved to database
+        """
+        # Parse and validate the Jupyter notebook using nbformat
+        # Note: nbformat.reads() has loose typing, so we cast to NotebookNode
+        jupyter_nb = cast(
+            NotebookNode, nbformat.reads(notebook_content, as_version=nbformat.NO_CONVERT)
+        )
+
+        # Convert Jupyter cells to NotebookCell objects
+        cells = []
+        for jupyter_cell in jupyter_nb.cells:
+            # Each cell is also a NotebookNode with dynamic attributes
+            cell = cast(NotebookNode, jupyter_cell)
+            # Skip raw cells as they're not supported in our model
+            if cell.cell_type == "raw":
+                continue
+
+            # Get the source content
+            content = cell.source
+
+            # Generate a name based on content or cell index
+            cell_name = cls._generate_cell_name(cell)
+
+            # Map cell types (jupyter uses "code"/"markdown", we use same)
+            cell_type = "code" if cell.cell_type == "code" else "markdown"
+
+            cells.append(NotebookCell(id=uuid4(), type=cell_type, name=cell_name, content=content))
+
+        # Extract notebook name from metadata if not provided
+        if name is None:
+            kernelspec = jupyter_nb.metadata.get("kernelspec", {})
+            name = kernelspec.get("display_name") or kernelspec.get("name", "Imported Notebook")
+
+        return cls(id=uuid4(), owner_id=owner_id, name=name, cells=cells, deletable=deletable)
+
+    @staticmethod
+    def _generate_cell_name(jupyter_cell: NotebookNode) -> str:
+        """Generate a meaningful name for a notebook cell using nbformat cell."""
+        if jupyter_cell.cell_type == "markdown":
+            # Try to extract a title from markdown headers
+            content = jupyter_cell.source
+
+            lines = content.strip().split("\n")
+            if lines and lines[0].startswith("#"):
+                # Extract header text, clean it up
+                header = lines[0].lstrip("#").strip()
+                return header[:50] if len(header) > 50 else header
+            else:
+                return "Markdown Cell"
+        else:
+            return "Code Cell"
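To make the path-rewriting step in `_update_file_paths_in_cells` concrete, here is a self-contained illustration of the regex substitution on its own. The sample cell source and the cached/S3 paths are invented for the demo; only the `"file://data/<name>"` pattern and the local-vs-S3 replacement rule come from the code above.

```python
# Demo of swapping "file://data/<name>" references for downloaded copies.
import re

data_files = {
    "example.txt": "/home/user/.cognee_cache/tutorial_data/abc123/data/example.txt",  # made up
    "report.pdf": "s3://your-bucket/cognee/cache/tutorial_data/abc123/data/report.pdf",  # made up
}

cell_source = (
    'await cognee.add("file://data/example.txt")\n'
    'await cognee.add("file://data/report.pdf")'
)


def replace_path(match: re.Match) -> str:
    filename = match.group(1)
    if filename not in data_files:
        return match.group(0)  # leave unknown references untouched
    cached = data_files[filename]
    # Local copies keep the file:// scheme; S3 URLs are usable as-is.
    return f'"{cached}"' if cached.startswith("s3://") else f'"file://{cached}"'


print(re.sub(r'"file://data/([^"]+)"', replace_path, cell_source))
```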
@@ -1,9 +1,10 @@
-from uuid import uuid4
+from uuid import UUID, uuid4
 from fastapi_users.exceptions import UserAlreadyExists
+from sqlalchemy.ext.asyncio import AsyncSession

 from cognee.infrastructure.databases.relational import get_relational_engine
-from cognee.modules.notebooks.methods import create_notebook
-from cognee.modules.notebooks.models.Notebook import NotebookCell
+from cognee.modules.notebooks.models.Notebook import Notebook
+from cognee.modules.notebooks.methods.create_notebook import _create_tutorial_notebook
 from cognee.modules.users.exceptions import TenantNotFoundError
 from cognee.modules.users.get_user_manager import get_user_manager_context
 from cognee.modules.users.get_user_db import get_user_db_context

@@ -60,26 +61,7 @@ async def create_user(
         if auto_login:
             await session.refresh(user)

-            await create_notebook(
-                user_id=user.id,
-                notebook_name="Welcome to cognee 🧠",
-                cells=[
-                    NotebookCell(
-                        id=uuid4(),
-                        name="Welcome",
-                        content="Cognee is your toolkit for turning text into a structured knowledge graph, optionally enhanced by ontologies, and then querying it with advanced retrieval techniques. This notebook will guide you through a simple example.",
-                        type="markdown",
-                    ),
-                    NotebookCell(
-                        id=uuid4(),
-                        name="Example",
-                        content="",
-                        type="code",
-                    ),
-                ],
-                deletable=False,
-                session=session,
-            )
+            await _create_tutorial_notebook(user.id, session)

         return user
     except UserAlreadyExists as error:
cognee/shared/cache.py (new file, 346 additions)

@@ -0,0 +1,346 @@
"""
Storage-aware cache management utilities for Cognee.

This module provides cache functionality that works with both local and cloud storage
backends (like S3) through the StorageManager abstraction.
"""

import hashlib
import zipfile
import asyncio
from typing import Optional, Tuple
import aiohttp
import logging
from io import BytesIO

from cognee.base_config import get_base_config
from cognee.infrastructure.files.storage.get_file_storage import get_file_storage
from cognee.infrastructure.files.storage.StorageManager import StorageManager
from cognee.shared.utils import create_secure_ssl_context

logger = logging.getLogger(__name__)


class StorageAwareCache:
    """
    A cache manager that works with different storage backends (local, S3, etc.)
    """

    def __init__(self, cache_subdir: str = "cache"):
        """
        Initialize the cache manager.

        Args:
            cache_subdir: Subdirectory name within the system root for caching
        """
        self.base_config = get_base_config()
        # Since we're using cache_root_directory, don't add extra cache prefix
        self.cache_base_path = ""
        self.storage_manager: StorageManager = get_file_storage(
            self.base_config.cache_root_directory
        )

        # Print absolute path
        storage_path = self.storage_manager.storage.storage_path
        if storage_path.startswith("s3://"):
            absolute_path = storage_path  # S3 paths are already absolute
        else:
            import os

            absolute_path = os.path.abspath(storage_path)
        logger.info(f"Storage manager absolute path: {absolute_path}")

    async def get_cache_dir(self) -> str:
        """Get the base cache directory path."""
        cache_path = self.cache_base_path or "."  # Use "." for root when cache_base_path is empty
        await self.storage_manager.ensure_directory_exists(cache_path)
        return cache_path

    async def get_cache_subdir(self, name: str) -> str:
        """Get a specific cache subdirectory."""
        if self.cache_base_path:
            cache_path = f"{self.cache_base_path}/{name}"
        else:
            cache_path = name
        await self.storage_manager.ensure_directory_exists(cache_path)

        # Return the absolute path based on storage system
        if self.storage_manager.storage.storage_path.startswith("s3://"):
            return cache_path
        elif hasattr(self.storage_manager.storage, "storage_path"):
            return f"{self.storage_manager.storage.storage_path}/{cache_path}"
        else:
            # Fallback for other storage types
            return cache_path

    async def delete_cache(self):
        """Delete the entire cache directory."""
        logger.info("Deleting cache...")
        try:
            await self.storage_manager.remove_all(self.cache_base_path)
            logger.info("✓ Cache deleted successfully!")
        except Exception as e:
            logger.error(f"Error deleting cache: {e}")
            raise

    async def _is_cache_valid(self, cache_dir: str, version_or_hash: str) -> bool:
        """Check if cached content is valid for the given version/hash."""
        version_file = f"{cache_dir}/version.txt"

        if not await self.storage_manager.file_exists(version_file):
            return False

        try:
            async with self.storage_manager.open(version_file, "r") as f:
                cached_version = (await asyncio.to_thread(f.read)).strip()
                return cached_version == version_or_hash
        except Exception as e:
            logger.debug(f"Error checking cache validity: {e}")
            return False

    async def _clear_cache(self, cache_dir: str) -> None:
        """Clear a cache directory."""
        try:
            await self.storage_manager.remove_all(cache_dir)
        except Exception as e:
            logger.debug(f"Error clearing cache directory {cache_dir}: {e}")

    async def _check_remote_content_freshness(
        self, url: str, cache_dir: str
    ) -> Tuple[bool, Optional[str]]:
        """
        Check if remote content is fresher than cached version using HTTP headers.

        Returns:
            Tuple of (is_fresh: bool, new_identifier: Optional[str])
        """
        try:
            # Make a HEAD request to check headers without downloading
            ssl_context = create_secure_ssl_context()
            connector = aiohttp.TCPConnector(ssl=ssl_context)
            async with aiohttp.ClientSession(connector=connector) as session:
                async with session.head(url, timeout=aiohttp.ClientTimeout(total=30)) as response:
                    response.raise_for_status()

                    # Try ETag first (most reliable)
                    etag = response.headers.get("ETag", "").strip('"')
                    last_modified = response.headers.get("Last-Modified", "")

                    # Use ETag if available, otherwise Last-Modified
                    remote_identifier = etag if etag else last_modified

                    if not remote_identifier:
                        logger.debug("No freshness headers available, cannot check for updates")
                        return True, None  # Assume fresh if no headers

                    # Check cached identifier
                    identifier_file = f"{cache_dir}/content_id.txt"
                    if await self.storage_manager.file_exists(identifier_file):
                        async with self.storage_manager.open(identifier_file, "r") as f:
                            cached_identifier = (await asyncio.to_thread(f.read)).strip()
                            if cached_identifier == remote_identifier:
                                logger.debug(f"Content is fresh (identifier: {remote_identifier[:20]}...)")
                                return True, None
                            else:
                                logger.info(
                                    f"Content has changed (old: {cached_identifier[:20]}..., new: {remote_identifier[:20]}...)"
                                )
                                return False, remote_identifier
                    else:
                        # No cached identifier, treat as stale
                        return False, remote_identifier

        except Exception as e:
            logger.debug(f"Could not check remote freshness: {e}")
            return True, None  # Assume fresh if we can't check

    async def download_and_extract_zip(
        self, url: str, cache_subdir_name: str, version_or_hash: str, force: bool = False
    ) -> str:
        """
        Download a zip file and extract it to cache directory with content freshness checking.

        Args:
            url: URL to download zip file from
            cache_subdir_name: Name of the cache subdirectory
            version_or_hash: Version string or content hash for cache validation
            force: If True, re-download even if already cached

        Returns:
            Path to the cached directory
        """
        cache_dir = await self.get_cache_subdir(cache_subdir_name)

        # Check if already cached and valid
        if not force and await self._is_cache_valid(cache_dir, version_or_hash):
            # Also check if remote content has changed
            is_fresh, new_identifier = await self._check_remote_content_freshness(url, cache_dir)
            if is_fresh:
                logger.debug(f"Content already cached and fresh for version {version_or_hash}")
                return cache_dir
            else:
                logger.info("Cached content is stale, updating...")

        # Clear old cache if it exists
        await self._clear_cache(cache_dir)

        logger.info(f"Downloading content from {url}...")

        # Download the zip file
        zip_content = BytesIO()
        etag = ""
        last_modified = ""
        ssl_context = create_secure_ssl_context()
        connector = aiohttp.TCPConnector(ssl=ssl_context)
        async with aiohttp.ClientSession(connector=connector) as session:
            async with session.get(url, timeout=aiohttp.ClientTimeout(total=60)) as response:
                response.raise_for_status()

                # Extract headers before consuming response
                etag = response.headers.get("ETag", "").strip('"')
                last_modified = response.headers.get("Last-Modified", "")

                # Read the response content
                async for chunk in response.content.iter_chunked(8192):
                    zip_content.write(chunk)
        zip_content.seek(0)

        # Extract the archive
        await self.storage_manager.ensure_directory_exists(cache_dir)

        # Extract files and store them using StorageManager
        with zipfile.ZipFile(zip_content, "r") as zip_file:
            for file_info in zip_file.infolist():
                if file_info.is_dir():
                    # Create directory
                    dir_path = f"{cache_dir}/{file_info.filename}"
                    await self.storage_manager.ensure_directory_exists(dir_path)
                else:
                    # Extract and store file
                    file_data = zip_file.read(file_info.filename)
                    file_path = f"{cache_dir}/{file_info.filename}"
                    await self.storage_manager.store(file_path, BytesIO(file_data), overwrite=True)

        # Write version info for future cache validation
        version_file = f"{cache_dir}/version.txt"
        await self.storage_manager.store(version_file, version_or_hash, overwrite=True)

        # Store content identifier from response headers for freshness checking
        content_identifier = etag if etag else last_modified

        if content_identifier:
            identifier_file = f"{cache_dir}/content_id.txt"
            await self.storage_manager.store(identifier_file, content_identifier, overwrite=True)
            logger.debug(f"Stored content identifier: {content_identifier[:20]}...")

        logger.info("✓ Content downloaded and cached successfully!")
        return cache_dir

    async def file_exists(self, file_path: str) -> bool:
        """Check if a file exists in cache storage."""
        return await self.storage_manager.file_exists(file_path)

    async def read_file(self, file_path: str, encoding: str = "utf-8"):
        """Read a file from cache storage."""
        return self.storage_manager.open(file_path, encoding=encoding)

    async def list_files(self, directory_path: str):
        """List files in a cache directory."""
        try:
            file_list = await self.storage_manager.list_files(directory_path)

            # For S3 storage, convert relative paths to full S3 URLs
            if self.storage_manager.storage.storage_path.startswith("s3://"):
                full_paths = []
                for file_path in file_list:
                    full_s3_path = f"{self.storage_manager.storage.storage_path}/{file_path}"
                    full_paths.append(full_s3_path)
                return full_paths
            else:
                # For local storage, return absolute paths
                storage_path = self.storage_manager.storage.storage_path
                if not storage_path.startswith("/"):
                    import os

                    storage_path = os.path.abspath(storage_path)

                full_paths = []
                for file_path in file_list:
                    if file_path.startswith("/"):
                        full_paths.append(file_path)  # Already absolute
                    else:
                        full_paths.append(f"{storage_path}/{file_path}")
                return full_paths

        except Exception as e:
            logger.debug(f"Error listing files in {directory_path}: {e}")
            return []


# Convenience functions that maintain API compatibility
_cache_manager = None


def get_cache_manager() -> StorageAwareCache:
    """Get a singleton cache manager instance."""
    global _cache_manager
    if _cache_manager is None:
        _cache_manager = StorageAwareCache()
    return _cache_manager


def generate_content_hash(url: str, additional_data: str = "") -> str:
    """Generate a content hash from URL and optional additional data."""
    content = f"{url}:{additional_data}"
    return hashlib.md5(content.encode()).hexdigest()[:12]  # Short hash for readability


# Async wrapper functions for backward compatibility
async def delete_cache():
    """Delete the Cognee cache directory."""
    cache_manager = get_cache_manager()
    await cache_manager.delete_cache()


async def get_cognee_cache_dir() -> str:
    """Get the base Cognee cache directory."""
    cache_manager = get_cache_manager()
    return await cache_manager.get_cache_dir()


async def get_cache_subdir(name: str) -> str:
    """Get a specific cache subdirectory."""
    cache_manager = get_cache_manager()
    return await cache_manager.get_cache_subdir(name)


async def download_and_extract_zip(
    url: str, cache_dir_name: str, version_or_hash: str, force: bool = False
) -> str:
    """Download a zip file and extract it to cache directory."""
    cache_manager = get_cache_manager()
    return await cache_manager.download_and_extract_zip(url, cache_dir_name, version_or_hash, force)


async def get_tutorial_data_dir() -> str:
    """Get the tutorial data cache directory."""
    return await get_cache_subdir("tutorial_data")


# Cache file operations
async def cache_file_exists(file_path: str) -> bool:
    """Check if a file exists in cache storage."""
    cache_manager = get_cache_manager()
    return await cache_manager.file_exists(file_path)


async def read_cache_file(file_path: str, encoding: str = "utf-8"):
    """Read a file from cache storage."""
    cache_manager = get_cache_manager()
    return await cache_manager.read_file(file_path, encoding)


async def list_cache_files(directory_path: str):
    """List files in a cache directory."""
    cache_manager = get_cache_manager()
    return await cache_manager.list_files(directory_path)
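The cache key for the tutorial download is a short md5 of the zip URL plus the notebook filename, so a changed URL or filename lands in a fresh `tutorial_data/<hash>` directory. A quick runnable illustration that mirrors `generate_content_hash` above:

```python
# Derive the short cache key used for the tutorial_data/<hash> directory.
import hashlib


def generate_content_hash(url: str, additional_data: str = "") -> str:
    content = f"{url}:{additional_data}"
    return hashlib.md5(content.encode()).hexdigest()[:12]  # short hash for readability


zip_url = "https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/starter_tutorial.zip"
print(generate_content_hash(zip_url, "tutorial.ipynb"))  # prints a 12-character hex key
```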
@@ -1,6 +1,7 @@
 """This module contains utility functions for the cognee."""

 import os
+import ssl
 import requests
 from datetime import datetime, timezone
 import matplotlib.pyplot as plt

@@ -18,6 +19,17 @@ from cognee.infrastructure.databases.graph import get_graph_engine
 proxy_url = "https://test.prometh.ai"


+def create_secure_ssl_context() -> ssl.SSLContext:
+    """
+    Create a secure SSL context.
+
+    By default, use the system's certificate store.
+    If users report SSL issues, I'm keeping this open in case we need to switch to:
+    ssl.create_default_context(cafile=certifi.where())
+    """
+    return ssl.create_default_context()
+
+
 def get_entities(tagged_tokens):
     import nltk
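The docstring above leaves the door open to pinning certificates via certifi if users hit SSL verification problems. A hedged sketch of that fallback, assuming the `certifi` package is available (it is not a stated dependency of this PR):

```python
# Optional certifi-backed SSL context, falling back to the system trust store.
import ssl

try:
    import certifi

    ssl_context = ssl.create_default_context(cafile=certifi.where())
except ImportError:
    # Matches the current create_secure_ssl_context() behaviour
    ssl_context = ssl.create_default_context()

print(ssl_context.verify_mode)  # ssl.CERT_REQUIRED by default
```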
@@ -0,0 +1,399 @@
import json
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
import hashlib
import time
from uuid import uuid4
from sqlalchemy.ext.asyncio import AsyncSession
from pathlib import Path
import zipfile
from cognee.shared.cache import get_tutorial_data_dir

from cognee.modules.notebooks.methods.create_notebook import _create_tutorial_notebook
from cognee.modules.notebooks.models.Notebook import Notebook
import cognee
from cognee.shared.logging_utils import get_logger

logger = get_logger()


# Module-level fixtures available to all test classes
@pytest.fixture
def mock_session():
    """Mock database session."""
    session = AsyncMock(spec=AsyncSession)
    session.add = MagicMock()
    session.commit = AsyncMock()
    return session


@pytest.fixture
def sample_jupyter_notebook():
    """Sample Jupyter notebook content for testing."""
    return {
        "cells": [
            {
                "cell_type": "markdown",
                "metadata": {},
                "source": ["# Tutorial Introduction\n", "\n", "This is a tutorial notebook."],
            },
            {
                "cell_type": "code",
                "execution_count": None,
                "metadata": {},
                "outputs": [],
                "source": ["import cognee\n", "print('Hello, Cognee!')"],
            },
            {
                "cell_type": "markdown",
                "metadata": {},
                "source": ["## Step 1: Data Ingestion\n", "\n", "Let's add some data."],
            },
            {
                "cell_type": "code",
                "execution_count": None,
                "metadata": {},
                "outputs": [],
                "source": ["# Add your data here\n", "# await cognee.add('data.txt')"],
            },
            {
                "cell_type": "raw",
                "metadata": {},
                "source": ["This is a raw cell that should be skipped"],
            },
        ],
        "metadata": {
            "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}
        },
        "nbformat": 4,
        "nbformat_minor": 4,
    }


class TestTutorialNotebookCreation:
    """Test cases for tutorial notebook creation functionality."""

    @pytest.mark.asyncio
    async def test_notebook_from_ipynb_string_success(self, sample_jupyter_notebook):
        """Test successful creation of notebook from JSON string."""
        notebook_json = json.dumps(sample_jupyter_notebook)
        user_id = uuid4()

        notebook = Notebook.from_ipynb_string(
            notebook_content=notebook_json, owner_id=user_id, name="String Test Notebook"
        )

        assert notebook.owner_id == user_id
        assert notebook.name == "String Test Notebook"
        assert len(notebook.cells) == 4  # Should skip the raw cell
        assert notebook.cells[0].type == "markdown"
        assert notebook.cells[1].type == "code"

    @pytest.mark.asyncio
    async def test_notebook_cell_name_generation(self, sample_jupyter_notebook):
        """Test that cell names are generated correctly from markdown headers."""
        user_id = uuid4()
        notebook_json = json.dumps(sample_jupyter_notebook)

        notebook = Notebook.from_ipynb_string(notebook_content=notebook_json, owner_id=user_id)

        # Check markdown header extraction
        assert notebook.cells[0].name == "Tutorial Introduction"
        assert notebook.cells[2].name == "Step 1: Data Ingestion"

        # Check code cell naming
        assert notebook.cells[1].name == "Code Cell"
        assert notebook.cells[3].name == "Code Cell"

    @pytest.mark.asyncio
    async def test_notebook_from_ipynb_string_with_default_name(self, sample_jupyter_notebook):
        """Test notebook creation uses kernelspec display_name when no name provided."""
        user_id = uuid4()
        notebook_json = json.dumps(sample_jupyter_notebook)

        notebook = Notebook.from_ipynb_string(notebook_content=notebook_json, owner_id=user_id)

        assert notebook.name == "Python 3"  # From kernelspec.display_name

    @pytest.mark.asyncio
    async def test_notebook_from_ipynb_string_fallback_name(self):
        """Test fallback naming when kernelspec is missing."""
        minimal_notebook = {
            "cells": [{"cell_type": "markdown", "metadata": {}, "source": ["# Test"]}],
            "metadata": {},  # No kernelspec
            "nbformat": 4,
            "nbformat_minor": 4,
        }

        user_id = uuid4()
        notebook_json = json.dumps(minimal_notebook)

        notebook = Notebook.from_ipynb_string(notebook_content=notebook_json, owner_id=user_id)

        assert notebook.name == "Imported Notebook"  # Fallback name

    @pytest.mark.asyncio
    async def test_notebook_from_ipynb_string_invalid_json(self):
        """Test error handling for invalid JSON."""
        user_id = uuid4()
        invalid_json = "{ invalid json content"

        from nbformat.reader import NotJSONError

        with pytest.raises(NotJSONError):
            Notebook.from_ipynb_string(notebook_content=invalid_json, owner_id=user_id)

    @pytest.mark.asyncio
    @patch.object(Notebook, "from_ipynb_zip_url")
    async def test_create_tutorial_notebook_error_propagated(self, mock_from_zip_url, mock_session):
        """Test that errors are propagated when zip fetch fails."""
        user_id = uuid4()
        mock_from_zip_url.side_effect = Exception("Network error")

        # Should raise the exception (not catch it)
        with pytest.raises(Exception, match="Network error"):
            await _create_tutorial_notebook(user_id, mock_session)

        # Verify error handling path was taken
        mock_from_zip_url.assert_called_once()
        mock_session.add.assert_not_called()
        mock_session.commit.assert_not_called()

    def test_generate_cell_name_code_cell(self):
        """Test cell name generation for code cells."""
        from nbformat.notebooknode import NotebookNode

        mock_cell = NotebookNode(
            {"cell_type": "code", "source": 'import pandas as pd\nprint("Hello world")'}
        )

        result = Notebook._generate_cell_name(mock_cell)
        assert result == "Code Cell"


class TestTutorialNotebookZipFunctionality:
    """Test cases for zip-based tutorial functionality."""

    @pytest.mark.asyncio
    async def test_notebook_from_ipynb_zip_url_missing_notebook(
        self,
    ):
        """Test error handling when notebook file is missing from zip."""
        user_id = uuid4()

        with pytest.raises(
            FileNotFoundError,
            match="Notebook file 'super_random_tutorial_name.ipynb' not found in zip",
        ):
            await Notebook.from_ipynb_zip_url(
                zip_url="https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/starter_tutorial.zip",
                owner_id=user_id,
                notebook_filename="super_random_tutorial_name.ipynb",
            )

    @pytest.mark.asyncio
    async def test_notebook_from_ipynb_zip_url_download_failure(self):
        """Test error handling when zip download fails."""
        user_id = uuid4()
        with pytest.raises(RuntimeError, match="Failed to download tutorial zip"):
            await Notebook.from_ipynb_zip_url(
                zip_url="https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/nonexistent_tutorial_name.zip",
                owner_id=user_id,
            )

    @pytest.mark.asyncio
    async def test_create_tutorial_notebook_zip_success(self, mock_session):
        """Test successful tutorial notebook creation with zip."""
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

        user_id = uuid4()

        # Check that tutorial data directory is empty using storage-aware method
        tutorial_data_dir_path = await get_tutorial_data_dir()
        tutorial_data_dir = Path(tutorial_data_dir_path)
        if tutorial_data_dir.exists():
            assert not any(tutorial_data_dir.iterdir()), "Tutorial data directory should be empty"

        await _create_tutorial_notebook(user_id, mock_session)

        items = list(tutorial_data_dir.iterdir())
        assert len(items) == 1, "Tutorial data directory should contain exactly one item"
        assert items[0].is_dir(), "Tutorial data directory item should be a directory"

        # Verify the structure inside the tutorial directory
        tutorial_dir = items[0]

        # Check for tutorial.ipynb file
        notebook_file = tutorial_dir / "tutorial.ipynb"
        assert notebook_file.exists(), f"tutorial.ipynb should exist in {tutorial_dir}"
        assert notebook_file.is_file(), "tutorial.ipynb should be a file"

        # Check for data subfolder with contents
        data_folder = tutorial_dir / "data"
        assert data_folder.exists(), f"data subfolder should exist in {tutorial_dir}"
        assert data_folder.is_dir(), "data should be a directory"

        data_items = list(data_folder.iterdir())
        assert len(data_items) > 0, (
            f"data folder should contain files, but found {len(data_items)} items"
        )

    @pytest.mark.asyncio
    async def test_create_tutorial_notebook_with_force_refresh(self, mock_session):
        """Test tutorial notebook creation with force refresh."""
        await cognee.prune.prune_data()
        await cognee.prune.prune_system(metadata=True)

        user_id = uuid4()

        # Check that tutorial data directory is empty using storage-aware method
        tutorial_data_dir_path = await get_tutorial_data_dir()
        tutorial_data_dir = Path(tutorial_data_dir_path)
        if tutorial_data_dir.exists():
            assert not any(tutorial_data_dir.iterdir()), "Tutorial data directory should be empty"

        # First creation (without force refresh)
        await _create_tutorial_notebook(user_id, mock_session, force_refresh=False)

        items_first = list(tutorial_data_dir.iterdir())
        assert len(items_first) == 1, (
            "Tutorial data directory should contain exactly one item after first creation"
        )
        first_dir = items_first[0]
        assert first_dir.is_dir(), "Tutorial data directory item should be a directory"

        # Verify the structure inside the tutorial directory (first creation)
        notebook_file = first_dir / "tutorial.ipynb"
        assert notebook_file.exists(), f"tutorial.ipynb should exist in {first_dir}"
        assert notebook_file.is_file(), "tutorial.ipynb should be a file"

        data_folder = first_dir / "data"
        assert data_folder.exists(), f"data subfolder should exist in {first_dir}"
        assert data_folder.is_dir(), "data should be a directory"

        data_items = list(data_folder.iterdir())
        assert len(data_items) > 0, (
            f"data folder should contain files, but found {len(data_items)} items"
        )

        # Capture metadata from first creation
        first_creation_metadata = {}

        for file_path in first_dir.rglob("*"):
            if file_path.is_file():
                relative_path = file_path.relative_to(first_dir)
                stat = file_path.stat()

                # Store multiple metadata points
                with open(file_path, "rb") as f:
                    content = f.read()

                first_creation_metadata[str(relative_path)] = {
                    "mtime": stat.st_mtime,
                    "size": stat.st_size,
                    "hash": hashlib.md5(content).hexdigest(),
                    "first_bytes": content[:100]
                    if content
                    else b"",  # First 100 bytes as fingerprint
                }

        # Wait a moment to ensure different timestamps
        time.sleep(0.1)

        # Force refresh - should create new files with different metadata
        await _create_tutorial_notebook(user_id, mock_session, force_refresh=True)

        items_second = list(tutorial_data_dir.iterdir())
        assert len(items_second) == 1, (
            "Tutorial data directory should contain exactly one item after force refresh"
        )
        second_dir = items_second[0]

        # Verify the structure is maintained after force refresh
        notebook_file_second = second_dir / "tutorial.ipynb"
        assert notebook_file_second.exists(), (
            f"tutorial.ipynb should exist in {second_dir} after force refresh"
        )
        assert notebook_file_second.is_file(), "tutorial.ipynb should be a file after force refresh"

        data_folder_second = second_dir / "data"
        assert data_folder_second.exists(), (
            f"data subfolder should exist in {second_dir} after force refresh"
        )
        assert data_folder_second.is_dir(), "data should be a directory after force refresh"

        data_items_second = list(data_folder_second.iterdir())
        assert len(data_items_second) > 0, (
            f"data folder should still contain files after force refresh, but found {len(data_items_second)} items"
        )

        # Compare metadata to ensure files are actually different
        files_with_changed_metadata = 0

        for file_path in second_dir.rglob("*"):
            if file_path.is_file():
                relative_path = file_path.relative_to(second_dir)
                relative_path_str = str(relative_path)

                # File should exist from first creation
                assert relative_path_str in first_creation_metadata, (
                    f"File {relative_path_str} missing from first creation"
                )

                old_metadata = first_creation_metadata[relative_path_str]

                # Get new metadata
                stat = file_path.stat()
                with open(file_path, "rb") as f:
                    new_content = f.read()

                new_metadata = {
                    "mtime": stat.st_mtime,
                    "size": stat.st_size,
                    "hash": hashlib.md5(new_content).hexdigest(),
                    "first_bytes": new_content[:100] if new_content else b"",
                }

                # Check if any metadata changed (indicating file was refreshed)
                metadata_changed = (
                    new_metadata["mtime"] > old_metadata["mtime"]  # Newer modification time
                    or new_metadata["hash"] != old_metadata["hash"]  # Different content hash
                    or new_metadata["size"] != old_metadata["size"]  # Different file size
                    or new_metadata["first_bytes"]
                    != old_metadata["first_bytes"]  # Different content
                )

                if metadata_changed:
                    files_with_changed_metadata += 1

        # Assert that force refresh actually updated files
        assert files_with_changed_metadata > 0, (
            f"Force refresh should have updated at least some files, but all {len(first_creation_metadata)} "
            f"files appear to have identical metadata. This suggests force refresh didn't work."
        )

        mock_session.commit.assert_called()

    @pytest.mark.asyncio
    async def test_tutorial_zip_url_accessibility(self):
        """Test that the actual tutorial zip URL is accessible (integration test)."""
        try:
            import requests

            response = requests.get(
                "https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/starter_tutorial.zip",
                timeout=10,
            )
            response.raise_for_status()

            # Verify it's a valid zip file by checking headers
            assert response.headers.get("content-type") in [
                "application/zip",
                "application/octet-stream",
                "application/x-zip-compressed",
            ] or response.content.startswith(b"PK")  # Zip file signature

        except Exception:
            pytest.skip("Network request failed or zip not available - skipping integration test")
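To make the behaviour these tests assert easier to review, here is a minimal sketch of the parsing logic that `Notebook.from_ipynb_string` appears to implement: parse with `nbformat`, skip raw cells, name markdown cells from their first `#` header, label code cells "Code Cell", and fall back to the kernelspec display name (or "Imported Notebook"). It is an approximation reconstructed from the assertions above, not the actual model code; `SimpleCell`, `SimpleNotebook`, and the helper names are placeholders.

```python
# Approximation of the behaviour exercised by the tests above; the real Notebook model
# in cognee.modules.notebooks.models.Notebook is richer than these placeholder containers.
from dataclasses import dataclass, field
from typing import List, Optional
from uuid import UUID

import nbformat


@dataclass
class SimpleCell:
    type: str
    name: str
    source: str


@dataclass
class SimpleNotebook:
    owner_id: UUID
    name: str
    cells: List[SimpleCell] = field(default_factory=list)


def _source_text(cell) -> str:
    # nbformat cells may carry source as a string or a list of lines.
    return "".join(cell.source) if isinstance(cell.source, list) else cell.source


def _cell_name(cell) -> str:
    # Code cells get a generic label; markdown cells take their first "#" header.
    if cell.cell_type == "code":
        return "Code Cell"
    for line in _source_text(cell).splitlines():
        if line.startswith("#"):
            return line.lstrip("#").strip()
    return "Markdown Cell"  # placeholder fallback, not covered by the tests


def notebook_from_ipynb_string(
    notebook_content: str, owner_id: UUID, name: Optional[str] = None
) -> SimpleNotebook:
    nb = nbformat.reads(notebook_content, as_version=4)  # raises NotJSONError on invalid JSON
    default_name = nb.metadata.get("kernelspec", {}).get("display_name", "Imported Notebook")
    cells = [
        SimpleCell(type=cell.cell_type, name=_cell_name(cell), source=_source_text(cell))
        for cell in nb.cells
        if cell.cell_type in ("markdown", "code")  # raw cells are skipped
    ]
    return SimpleNotebook(owner_id=owner_id, name=name or default_name, cells=cells)
```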
@@ -46,6 +46,7 @@ dependencies = [
    "matplotlib>=3.8.3,<4",
    "networkx>=3.4.2,<4",
    "lancedb>=0.24.0,<1.0.0",
    "nbformat>=5.7.0,<6.0.0",
    "alembic>=1.13.3,<2",
    "pre-commit>=4.0.1,<5",
    "scikit-learn>=1.6.1,<2",
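Since `nbformat` becomes a new runtime dependency here, a small hedged sanity check (assumed workflow, hypothetical local file path, not part of this PR) can confirm the pinned range parses the tutorial notebook shipped in the zip:

```python
# Hedged sanity check: parse and validate a local copy of the tutorial notebook
# with the pinned nbformat range.
import nbformat

with open("tutorial.ipynb", encoding="utf-8") as handle:  # hypothetical local copy
    tutorial_nb = nbformat.read(handle, as_version=4)

nbformat.validate(tutorial_nb)  # raises ValidationError if the notebook is malformed
print(f"{len(tutorial_nb.cells)} cells, nbformat {tutorial_nb.nbformat}.{tutorial_nb.nbformat_minor}")
```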
uv.lock (generated)

@@ -831,6 +831,7 @@ dependencies = [
    { name = "limits" },
    { name = "litellm" },
    { name = "matplotlib" },
    { name = "nbformat" },
    { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
    { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
    { name = "nltk" },

@@ -1012,6 +1013,7 @@ requires-dist = [
    { name = "mkdocstrings", extras = ["python"], marker = "extra == 'dev'", specifier = ">=0.26.2,<0.27" },
    { name = "modal", marker = "extra == 'distributed'", specifier = ">=1.0.5,<2.0.0" },
    { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.7.1,<2" },
    { name = "nbformat", specifier = ">=5.7.0,<6.0.0" },
    { name = "neo4j", marker = "extra == 'neo4j'", specifier = ">=5.28.0,<6" },
    { name = "networkx", specifier = ">=3.4.2,<4" },
    { name = "nltk", specifier = ">=3.9.1,<4.0.0" },