Compare commits
14 commits
main
...
new_error_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4324df5a8b | ||
|
|
19e22d14b8 | ||
|
|
dc03a52541 | ||
|
|
f0ba618f0c | ||
|
|
725061fbef | ||
|
|
bf191ae6d0 | ||
|
|
9d423f5e16 | ||
|
|
411e9a6205 | ||
|
|
3429af32c2 | ||
|
|
9110a2b59b | ||
|
|
1cd0fe0dcf | ||
|
|
f2e96d5c62 | ||
|
|
1c378dabdb | ||
|
|
98882ba1d1 |
53 changed files with 4249 additions and 151 deletions
|
|
@ -15,6 +15,7 @@ from fastapi.exceptions import RequestValidationError
|
|||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from cognee.exceptions import CogneeApiError
|
||||
from cognee.exceptions.enhanced_exceptions import CogneeBaseError
|
||||
from cognee.shared.logging_utils import get_logger, setup_logging
|
||||
from cognee.api.v1.permissions.routers import get_permissions_router
|
||||
from cognee.api.v1.settings.routers import get_settings_router
|
||||
|
|
@ -120,8 +121,27 @@ async def request_validation_exception_handler(request: Request, exc: RequestVal
|
|||
)
|
||||
|
||||
|
||||
@app.exception_handler(CogneeBaseError)
|
||||
async def enhanced_exception_handler(_: Request, exc: CogneeBaseError) -> JSONResponse:
|
||||
"""
|
||||
Enhanced exception handler for the new exception hierarchy.
|
||||
Provides standardized error responses with rich context and user guidance.
|
||||
"""
|
||||
# Log the full stack trace for debugging
|
||||
logger.error(f"Enhanced exception caught: {exc.__class__.__name__}", exc_info=True)
|
||||
|
||||
# Create standardized error response
|
||||
error_response = {"error": exc.to_dict()}
|
||||
|
||||
return JSONResponse(status_code=exc.status_code, content=error_response)
|
||||
|
||||
|
||||
@app.exception_handler(CogneeApiError)
|
||||
async def exception_handler(_: Request, exc: CogneeApiError) -> JSONResponse:
|
||||
async def legacy_exception_handler(_: Request, exc: CogneeApiError) -> JSONResponse:
|
||||
"""
|
||||
Legacy exception handler for backward compatibility.
|
||||
Handles old CogneeApiError instances with fallback formatting.
|
||||
"""
|
||||
detail = {}
|
||||
|
||||
if exc.name and exc.message and exc.status_code:
|
||||
|
|
@ -136,7 +156,54 @@ async def exception_handler(_: Request, exc: CogneeApiError) -> JSONResponse:
|
|||
|
||||
# log the stack trace for easier serverside debugging
|
||||
logger.error(format_exc())
|
||||
return JSONResponse(status_code=status_code, content={"detail": detail["message"]})
|
||||
|
||||
# Convert to new format for consistency
|
||||
error_response = {
|
||||
"error": {
|
||||
"type": exc.__class__.__name__,
|
||||
"message": detail["message"],
|
||||
"technical_message": detail["message"],
|
||||
"suggestions": [
|
||||
"Check the logs for more details",
|
||||
"Try again or contact support if the issue persists",
|
||||
],
|
||||
"docs_link": "https://docs.cognee.ai/troubleshooting",
|
||||
"is_retryable": False,
|
||||
"context": {},
|
||||
"operation": None,
|
||||
}
|
||||
}
|
||||
|
||||
return JSONResponse(status_code=status_code, content=error_response)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
"""
|
||||
Global exception handler for any unhandled exceptions.
|
||||
Ensures all errors return a consistent format.
|
||||
"""
|
||||
logger.error(f"Unhandled exception in {request.url.path}: {str(exc)}", exc_info=True)
|
||||
|
||||
# Create a standardized error response for unexpected errors
|
||||
error_response = {
|
||||
"error": {
|
||||
"type": "UnexpectedError",
|
||||
"message": "An unexpected error occurred. Please try again.",
|
||||
"technical_message": str(exc) if app_environment != "prod" else "Internal server error",
|
||||
"suggestions": [
|
||||
"Try your request again",
|
||||
"Check if the issue persists",
|
||||
"Contact support if the problem continues",
|
||||
],
|
||||
"docs_link": "https://docs.cognee.ai/troubleshooting",
|
||||
"is_retryable": True,
|
||||
"context": {"path": str(request.url.path), "method": request.method},
|
||||
"operation": None,
|
||||
}
|
||||
}
|
||||
|
||||
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=error_response)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
|
|
|
|||
|
|
@ -15,14 +15,19 @@ async def add(
|
|||
vector_db_config: dict = None,
|
||||
graph_db_config: dict = None,
|
||||
dataset_id: UUID = None,
|
||||
preferred_loaders: Optional[List[str]] = None,
|
||||
loader_config: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Add data to Cognee for knowledge graph processing.
|
||||
Add data to Cognee for knowledge graph processing using a plugin-based loader system.
|
||||
|
||||
This is the first step in the Cognee workflow - it ingests raw data and prepares it
|
||||
for processing. The function accepts various data formats including text, files, and
|
||||
binary streams, then stores them in a specified dataset for further processing.
|
||||
|
||||
This version supports both the original ingestion system (for backward compatibility)
|
||||
and the new plugin-based loader system (when loader parameters are provided).
|
||||
|
||||
Prerequisites:
|
||||
- **LLM_API_KEY**: Must be set in environment variables for content processing
|
||||
- **Database Setup**: Relational and vector databases must be configured
|
||||
|
|
@ -38,16 +43,38 @@ async def add(
|
|||
- **Lists**: Multiple files or text strings in a single call
|
||||
|
||||
Supported File Formats:
|
||||
- Text files (.txt, .md, .csv)
|
||||
- PDFs (.pdf)
|
||||
- Text files (.txt, .md, .csv) - processed by text_loader
|
||||
- PDFs (.pdf) - processed by pypdf_loader (if available)
|
||||
- Images (.png, .jpg, .jpeg) - extracted via OCR/vision models
|
||||
- Audio files (.mp3, .wav) - transcribed to text
|
||||
- Code files (.py, .js, .ts, etc.) - parsed for structure and content
|
||||
- Office documents (.docx, .pptx)
|
||||
- Office documents (.docx, .pptx) - processed by unstructured_loader (if available)
|
||||
- Data files (.json, .jsonl, .parquet) - processed by dlt_loader (if available)
|
||||
|
||||
Workflow:
|
||||
Plugin System:
|
||||
The function automatically uses the best available loader for each file type.
|
||||
You can customize this behavior using the loader parameters:
|
||||
|
||||
```python
|
||||
# Use specific loaders in priority order
|
||||
await cognee.add(
|
||||
"/path/to/document.pdf",
|
||||
preferred_loaders=["pypdf_loader", "text_loader"]
|
||||
)
|
||||
|
||||
# Configure loader-specific options
|
||||
await cognee.add(
|
||||
"/path/to/document.pdf",
|
||||
loader_config={
|
||||
"pypdf_loader": {"strict": False},
|
||||
"unstructured_loader": {"strategy": "hi_res"}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
Workflow:
|
||||
1. **Data Resolution**: Resolves file paths and validates accessibility
|
||||
2. **Content Extraction**: Extracts text content from various file formats
|
||||
2. **Content Extraction**: Uses plugin system or falls back to existing classification
|
||||
3. **Dataset Storage**: Stores processed content in the specified dataset
|
||||
4. **Metadata Tracking**: Records file metadata, timestamps, and user permissions
|
||||
5. **Permission Assignment**: Grants user read/write/delete/share permissions on dataset
|
||||
|
|
@ -70,6 +97,10 @@ async def add(
|
|||
vector_db_config: Optional configuration for vector database (for custom setups).
|
||||
graph_db_config: Optional configuration for graph database (for custom setups).
|
||||
dataset_id: Optional specific dataset UUID to use instead of dataset_name.
|
||||
preferred_loaders: Optional list of loader names to try first (e.g., ["pypdf_loader", "text_loader"]).
|
||||
If not provided, uses default loader priority.
|
||||
loader_config: Optional configuration for specific loaders. Dictionary mapping loader names
|
||||
to their configuration options (e.g., {"pypdf_loader": {"strict": False}}).
|
||||
|
||||
Returns:
|
||||
PipelineRunInfo: Information about the ingestion pipeline execution including:
|
||||
|
|
@ -138,10 +169,32 @@ async def add(
|
|||
UnsupportedFileTypeError: If file format cannot be processed
|
||||
InvalidValueError: If LLM_API_KEY is not set or invalid
|
||||
"""
|
||||
|
||||
# Determine which ingestion system to use
|
||||
# use_plugin_system = preferred_loaders is not None or loader_config is not None
|
||||
|
||||
# if use_plugin_system:
|
||||
# # Use new plugin-based ingestion system
|
||||
from cognee.tasks.ingestion.plugin_ingest_data import plugin_ingest_data
|
||||
|
||||
tasks = [
|
||||
Task(resolve_data_directories, include_subdirectories=True),
|
||||
Task(ingest_data, dataset_name, user, node_set, dataset_id),
|
||||
Task(
|
||||
plugin_ingest_data,
|
||||
dataset_name,
|
||||
user,
|
||||
node_set,
|
||||
dataset_id,
|
||||
preferred_loaders,
|
||||
loader_config,
|
||||
),
|
||||
]
|
||||
# else:
|
||||
# # Use existing ingestion system for backward compatibility
|
||||
# tasks = [
|
||||
# Task(resolve_data_directories, include_subdirectories=True),
|
||||
# Task(ingest_data, dataset_name, user, node_set, dataset_id),
|
||||
# ]
|
||||
|
||||
pipeline_run_info = None
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,13 @@ import requests
|
|||
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
from cognee.exceptions import (
|
||||
UnsupportedFileFormatError,
|
||||
FileAccessError,
|
||||
DatasetNotFoundError,
|
||||
CogneeValidationError,
|
||||
CogneeSystemError,
|
||||
)
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -49,49 +56,143 @@ def get_add_router() -> APIRouter:
|
|||
- Any relevant metadata from the ingestion process
|
||||
|
||||
## Error Codes
|
||||
- **400 Bad Request**: Neither datasetId nor datasetName provided
|
||||
- **409 Conflict**: Error during add operation
|
||||
- **400 Bad Request**: Missing required parameters or invalid input
|
||||
- **422 Unprocessable Entity**: Unsupported file format or validation error
|
||||
- **403 Forbidden**: User doesn't have permission to add to dataset
|
||||
- **500 Internal Server Error**: System error during processing
|
||||
|
||||
## Notes
|
||||
- To add data to datasets not owned by the user, use dataset_id (when ENABLE_BACKEND_ACCESS_CONTROL is set to True)
|
||||
- GitHub repositories are cloned and all files are processed
|
||||
- HTTP URLs are fetched and their content is processed
|
||||
- The ALLOW_HTTP_REQUESTS environment variable controls URL processing
|
||||
- Enhanced error messages provide specific guidance for fixing issues
|
||||
"""
|
||||
from cognee.api.v1.add import add as cognee_add
|
||||
|
||||
# Input validation with enhanced exceptions
|
||||
if not datasetId and not datasetName:
|
||||
raise ValueError("Either datasetId or datasetName must be provided.")
|
||||
raise CogneeValidationError(
|
||||
message="Either datasetId or datasetName must be provided",
|
||||
user_message="You must specify either a dataset name or dataset ID.",
|
||||
suggestions=[
|
||||
"Provide a datasetName parameter (e.g., 'my_dataset')",
|
||||
"Provide a datasetId parameter with a valid UUID",
|
||||
"Check the API documentation for parameter examples",
|
||||
],
|
||||
docs_link="https://docs.cognee.ai/api/add",
|
||||
context={"provided_dataset_name": datasetName, "provided_dataset_id": datasetId},
|
||||
operation="add",
|
||||
)
|
||||
|
||||
try:
|
||||
if (
|
||||
isinstance(data, str)
|
||||
and data.startswith("http")
|
||||
and (os.getenv("ALLOW_HTTP_REQUESTS", "true").lower() == "true")
|
||||
):
|
||||
if "github" in data:
|
||||
if not data or len(data) == 0:
|
||||
raise CogneeValidationError(
|
||||
message="No data provided for upload",
|
||||
user_message="You must provide data to add to the dataset.",
|
||||
suggestions=[
|
||||
"Upload one or more files",
|
||||
"Provide a valid URL (if URL processing is enabled)",
|
||||
"Check that your request includes the data parameter",
|
||||
],
|
||||
docs_link="https://docs.cognee.ai/guides/adding-data",
|
||||
operation="add",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Adding {len(data)} items to dataset",
|
||||
extra={
|
||||
"dataset_name": datasetName,
|
||||
"dataset_id": datasetId,
|
||||
"user_id": user.id,
|
||||
"item_count": len(data),
|
||||
},
|
||||
)
|
||||
|
||||
# Handle URL-based data (GitHub repos, HTTP URLs)
|
||||
if (
|
||||
len(data) == 1
|
||||
and hasattr(data[0], "filename")
|
||||
and isinstance(data[0].filename, str)
|
||||
and data[0].filename.startswith("http")
|
||||
and (os.getenv("ALLOW_HTTP_REQUESTS", "true").lower() == "true")
|
||||
):
|
||||
url = data[0].filename
|
||||
|
||||
if "github" in url:
|
||||
try:
|
||||
# Perform git clone if the URL is from GitHub
|
||||
repo_name = data.split("/")[-1].replace(".git", "")
|
||||
subprocess.run(["git", "clone", data, f".data/{repo_name}"], check=True)
|
||||
repo_name = url.split("/")[-1].replace(".git", "")
|
||||
subprocess.run(["git", "clone", url, f".data/{repo_name}"], check=True)
|
||||
# TODO: Update add call with dataset info
|
||||
await cognee_add(
|
||||
result = await cognee_add(
|
||||
"data://.data/",
|
||||
f"{repo_name}",
|
||||
)
|
||||
else:
|
||||
# Fetch and store the data from other types of URL using curl
|
||||
response = requests.get(data)
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise CogneeSystemError(
|
||||
message=f"Failed to clone GitHub repository: {e}",
|
||||
user_message=f"Could not clone the GitHub repository '{url}'.",
|
||||
suggestions=[
|
||||
"Check if the repository URL is correct",
|
||||
"Verify the repository is public or you have access",
|
||||
"Try cloning the repository manually to test access",
|
||||
],
|
||||
context={"url": url, "repo_name": repo_name, "error": str(e)},
|
||||
operation="add",
|
||||
)
|
||||
else:
|
||||
try:
|
||||
# Fetch and store the data from other types of URL
|
||||
response = requests.get(url, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
file_data = await response.content()
|
||||
file_data = response.content
|
||||
# TODO: Update add call with dataset info
|
||||
return await cognee_add(file_data)
|
||||
else:
|
||||
add_run = await cognee_add(data, datasetName, user=user, dataset_id=datasetId)
|
||||
result = await cognee_add(file_data)
|
||||
except requests.RequestException as e:
|
||||
raise CogneeSystemError(
|
||||
message=f"Failed to fetch URL: {e}",
|
||||
user_message=f"Could not fetch content from '{url}'.",
|
||||
suggestions=[
|
||||
"Check if the URL is accessible",
|
||||
"Verify your internet connection",
|
||||
"Try accessing the URL in a browser",
|
||||
"Check if the URL requires authentication",
|
||||
],
|
||||
context={"url": url, "error": str(e)},
|
||||
operation="add",
|
||||
)
|
||||
else:
|
||||
# Handle regular file uploads
|
||||
# Validate file types before processing
|
||||
supported_extensions = [
|
||||
".txt",
|
||||
".pdf",
|
||||
".docx",
|
||||
".md",
|
||||
".csv",
|
||||
".json",
|
||||
".py",
|
||||
".js",
|
||||
".ts",
|
||||
]
|
||||
|
||||
return add_run.model_dump()
|
||||
except Exception as error:
|
||||
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||
for file in data:
|
||||
if file.filename:
|
||||
file_ext = os.path.splitext(file.filename)[1].lower()
|
||||
if file_ext and file_ext not in supported_extensions:
|
||||
raise UnsupportedFileFormatError(
|
||||
file_path=file.filename, supported_formats=supported_extensions
|
||||
)
|
||||
|
||||
# Process the files
|
||||
result = await cognee_add(data, datasetName, user=user, dataset_id=datasetId)
|
||||
|
||||
logger.info(
|
||||
"Successfully added data to dataset",
|
||||
extra={"dataset_name": datasetName, "dataset_id": datasetId, "user_id": user.id},
|
||||
)
|
||||
|
||||
return result.model_dump() if hasattr(result, "model_dump") else result
|
||||
|
||||
return router
|
||||
|
|
|
|||
|
|
@ -24,6 +24,13 @@ from cognee.modules.pipelines.queues.pipeline_run_info_queues import (
|
|||
remove_queue,
|
||||
)
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.exceptions import (
|
||||
CogneeValidationError,
|
||||
EmptyDatasetError,
|
||||
DatasetNotFoundError,
|
||||
MissingAPIKeyError,
|
||||
NoDataToProcessError,
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger("api.cognify")
|
||||
|
|
@ -66,8 +73,10 @@ def get_cognify_router() -> APIRouter:
|
|||
- **Background execution**: Pipeline run metadata including pipeline_run_id for status monitoring via WebSocket subscription
|
||||
|
||||
## Error Codes
|
||||
- **400 Bad Request**: When neither datasets nor dataset_ids are provided, or when specified datasets don't exist
|
||||
- **409 Conflict**: When processing fails due to system errors, missing LLM API keys, database connection failures, or corrupted content
|
||||
- **400 Bad Request**: Missing required parameters or invalid input
|
||||
- **422 Unprocessable Entity**: No data to process or validation errors
|
||||
- **404 Not Found**: Specified datasets don't exist
|
||||
- **500 Internal Server Error**: System errors, missing API keys, database connection failures
|
||||
|
||||
## Example Request
|
||||
```json
|
||||
|
|
@ -84,23 +93,53 @@ def get_cognify_router() -> APIRouter:
|
|||
## Next Steps
|
||||
After successful processing, use the search endpoints to query the generated knowledge graph for insights, relationships, and semantic search.
|
||||
"""
|
||||
# Input validation with enhanced exceptions
|
||||
if not payload.datasets and not payload.dataset_ids:
|
||||
return JSONResponse(
|
||||
status_code=400, content={"error": "No datasets or dataset_ids provided"}
|
||||
raise CogneeValidationError(
|
||||
message="No datasets or dataset_ids provided",
|
||||
user_message="You must specify which datasets to process.",
|
||||
suggestions=[
|
||||
"Provide dataset names using the 'datasets' parameter",
|
||||
"Provide dataset UUIDs using the 'dataset_ids' parameter",
|
||||
"Use cognee.datasets() to see available datasets",
|
||||
],
|
||||
docs_link="https://docs.cognee.ai/api/cognify",
|
||||
context={
|
||||
"provided_datasets": payload.datasets,
|
||||
"provided_dataset_ids": payload.dataset_ids,
|
||||
},
|
||||
operation="cognify",
|
||||
)
|
||||
|
||||
# Check for LLM API key early to provide better error messaging
|
||||
llm_api_key = os.getenv("LLM_API_KEY")
|
||||
if not llm_api_key:
|
||||
raise MissingAPIKeyError(service="LLM", env_var="LLM_API_KEY")
|
||||
|
||||
from cognee.api.v1.cognify import cognify as cognee_cognify
|
||||
|
||||
try:
|
||||
datasets = payload.dataset_ids if payload.dataset_ids else payload.datasets
|
||||
datasets = payload.dataset_ids if payload.dataset_ids else payload.datasets
|
||||
|
||||
cognify_run = await cognee_cognify(
|
||||
datasets, user, run_in_background=payload.run_in_background
|
||||
)
|
||||
logger.info(
|
||||
f"Starting cognify process for user {user.id}",
|
||||
extra={
|
||||
"user_id": user.id,
|
||||
"datasets": datasets,
|
||||
"run_in_background": payload.run_in_background,
|
||||
},
|
||||
)
|
||||
|
||||
return cognify_run
|
||||
except Exception as error:
|
||||
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||
# The enhanced exception handler will catch and format any errors from cognee_cognify
|
||||
cognify_run = await cognee_cognify(
|
||||
datasets, user, run_in_background=payload.run_in_background
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Cognify process completed for user {user.id}",
|
||||
extra={"user_id": user.id, "datasets": datasets},
|
||||
)
|
||||
|
||||
return cognify_run
|
||||
|
||||
@router.websocket("/subscribe/{pipeline_run_id}")
|
||||
async def subscribe_to_cognify_info(websocket: WebSocket, pipeline_run_id: str):
|
||||
|
|
@ -124,6 +163,14 @@ def get_cognify_router() -> APIRouter:
|
|||
user_manager=user_manager,
|
||||
bearer=None,
|
||||
)
|
||||
logger.info(
|
||||
f"WebSocket user authenticated for pipeline {pipeline_run_id}",
|
||||
extra={
|
||||
"user_id": user.id,
|
||||
"user_email": user.email,
|
||||
"pipeline_run_id": str(pipeline_run_id),
|
||||
},
|
||||
)
|
||||
except Exception as error:
|
||||
logger.error(f"Authentication failed: {str(error)}")
|
||||
await websocket.close(code=WS_1008_POLICY_VIOLATION, reason="Unauthorized")
|
||||
|
|
@ -135,31 +182,43 @@ def get_cognify_router() -> APIRouter:
|
|||
|
||||
initialize_queue(pipeline_run_id)
|
||||
|
||||
while True:
|
||||
pipeline_run_info = get_from_queue(pipeline_run_id)
|
||||
try:
|
||||
# If the pipeline is already completed, send the completion status
|
||||
if isinstance(pipeline_run, PipelineRunCompleted):
|
||||
graph_data = await get_formatted_graph_data()
|
||||
pipeline_run.payload = {
|
||||
"nodes": graph_data.get("nodes", []),
|
||||
"edges": graph_data.get("edges", []),
|
||||
}
|
||||
|
||||
if not pipeline_run_info:
|
||||
await asyncio.sleep(2)
|
||||
continue
|
||||
await websocket.send_json(pipeline_run.model_dump())
|
||||
await websocket.close(code=WS_1000_NORMAL_CLOSURE)
|
||||
return
|
||||
|
||||
if not isinstance(pipeline_run_info, PipelineRunInfo):
|
||||
continue
|
||||
# Stream pipeline updates
|
||||
while True:
|
||||
try:
|
||||
pipeline_run_info = await asyncio.wait_for(
|
||||
get_from_queue(pipeline_run_id), timeout=10.0
|
||||
)
|
||||
|
||||
try:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"pipeline_run_id": str(pipeline_run_info.pipeline_run_id),
|
||||
"status": pipeline_run_info.status,
|
||||
"payload": await get_formatted_graph_data(pipeline_run.dataset_id, user.id),
|
||||
}
|
||||
)
|
||||
if pipeline_run_info:
|
||||
await websocket.send_json(pipeline_run_info.model_dump())
|
||||
|
||||
if isinstance(pipeline_run_info, PipelineRunCompleted):
|
||||
remove_queue(pipeline_run_id)
|
||||
await websocket.close(code=WS_1000_NORMAL_CLOSURE)
|
||||
if isinstance(pipeline_run_info, PipelineRunCompleted):
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
# Send a heartbeat to keep the connection alive
|
||||
await websocket.send_json({"type": "heartbeat"})
|
||||
except Exception as e:
|
||||
logger.error(f"Error in WebSocket communication: {str(e)}")
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
remove_queue(pipeline_run_id)
|
||||
break
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket disconnected for pipeline {pipeline_run_id}")
|
||||
except Exception as error:
|
||||
logger.error(f"WebSocket error: {str(error)}")
|
||||
finally:
|
||||
remove_queue(pipeline_run_id)
|
||||
|
||||
return router
|
||||
|
|
|
|||
|
|
@ -84,6 +84,12 @@ async def delete(
|
|||
# Get the content hash for deletion
|
||||
content_hash = data_point.content_hash
|
||||
|
||||
# Debug logging
|
||||
logger.info(
|
||||
f"🔍 Retrieved from database - data_id: {data_id}, content_hash: {content_hash}"
|
||||
)
|
||||
logger.info(f"🔍 Document name in database: {data_point.name}")
|
||||
|
||||
# Use the existing comprehensive deletion logic
|
||||
return await delete_single_document(content_hash, dataset.id, mode)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,22 +1,24 @@
|
|||
from uuid import UUID
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from fastapi import Depends, APIRouter
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from cognee.api.DTO import InDTO
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.api.DTO import InDTO, OutDTO
|
||||
from cognee.modules.users.exceptions.exceptions import PermissionDeniedError
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.search.operations import get_history
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.exceptions import PermissionDeniedError
|
||||
from cognee.modules.data.methods import get_history
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.exceptions import UnsupportedSearchTypeError, InvalidQueryError, NoDataToProcessError
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
# Note: Datasets sent by name will only map to datasets owned by the request sender
|
||||
# To search for datasets not owned by the request sender dataset UUID is needed
|
||||
class SearchPayloadDTO(InDTO):
|
||||
search_type: SearchType
|
||||
datasets: Optional[list[str]] = None
|
||||
dataset_ids: Optional[list[UUID]] = None
|
||||
datasets: Optional[List[str]] = None
|
||||
dataset_ids: Optional[List[UUID]] = None
|
||||
query: str
|
||||
top_k: Optional[int] = 10
|
||||
|
||||
|
|
@ -24,36 +26,23 @@ class SearchPayloadDTO(InDTO):
|
|||
def get_search_router() -> APIRouter:
|
||||
router = APIRouter()
|
||||
|
||||
class SearchHistoryItem(OutDTO):
|
||||
id: UUID
|
||||
text: str
|
||||
user: str
|
||||
created_at: datetime
|
||||
|
||||
@router.get("", response_model=list[SearchHistoryItem])
|
||||
@router.get("/history", response_model=list)
|
||||
async def get_search_history(user: User = Depends(get_authenticated_user)):
|
||||
"""
|
||||
Get search history for the authenticated user.
|
||||
|
||||
This endpoint retrieves the search history for the authenticated user,
|
||||
returning a list of previously executed searches with their timestamps.
|
||||
This endpoint retrieves the search history for the current user,
|
||||
showing previous queries and their results.
|
||||
|
||||
## Response
|
||||
Returns a list of search history items containing:
|
||||
- **id**: Unique identifier for the search
|
||||
- **text**: The search query text
|
||||
- **user**: User who performed the search
|
||||
- **created_at**: When the search was performed
|
||||
Returns a list of historical search queries and their metadata.
|
||||
|
||||
## Error Codes
|
||||
- **500 Internal Server Error**: Error retrieving search history
|
||||
- **500 Internal Server Error**: Database or system error while retrieving history
|
||||
"""
|
||||
try:
|
||||
history = await get_history(user.id, limit=0)
|
||||
|
||||
return history
|
||||
except Exception as error:
|
||||
return JSONResponse(status_code=500, content={"error": str(error)})
|
||||
# Remove try-catch to let enhanced exception handler deal with it
|
||||
history = await get_history(user.id, limit=0)
|
||||
return history
|
||||
|
||||
@router.post("", response_model=list)
|
||||
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
|
||||
|
|
@ -75,30 +64,59 @@ def get_search_router() -> APIRouter:
|
|||
Returns a list of search results containing relevant nodes from the graph.
|
||||
|
||||
## Error Codes
|
||||
- **409 Conflict**: Error during search operation
|
||||
- **403 Forbidden**: User doesn't have permission to search datasets (returns empty list)
|
||||
- **400 Bad Request**: Invalid query or search parameters
|
||||
- **404 Not Found**: No data found to search
|
||||
- **422 Unprocessable Entity**: Unsupported search type
|
||||
- **403 Forbidden**: User doesn't have permission to search datasets
|
||||
- **500 Internal Server Error**: System error during search
|
||||
|
||||
## Notes
|
||||
- Datasets sent by name will only map to datasets owned by the request sender
|
||||
- To search datasets not owned by the request sender, dataset UUID is needed
|
||||
- If permission is denied, returns empty list instead of error
|
||||
- Enhanced error messages provide actionable suggestions for fixing issues
|
||||
"""
|
||||
from cognee.api.v1.search import search as cognee_search
|
||||
|
||||
try:
|
||||
results = await cognee_search(
|
||||
query_text=payload.query,
|
||||
query_type=payload.search_type,
|
||||
user=user,
|
||||
datasets=payload.datasets,
|
||||
dataset_ids=payload.dataset_ids,
|
||||
top_k=payload.top_k,
|
||||
# Input validation with enhanced exceptions
|
||||
if not payload.query or not payload.query.strip():
|
||||
raise InvalidQueryError(query=payload.query or "", reason="Query cannot be empty")
|
||||
|
||||
if len(payload.query.strip()) < 2:
|
||||
raise InvalidQueryError(
|
||||
query=payload.query, reason="Query must be at least 2 characters long"
|
||||
)
|
||||
|
||||
return results
|
||||
except PermissionDeniedError:
|
||||
# Check if search type is supported
|
||||
try:
|
||||
search_type = payload.search_type
|
||||
logger.info(
|
||||
f"Search type validated: {search_type.value}",
|
||||
extra={
|
||||
"search_type": search_type.value,
|
||||
"user_id": user.id,
|
||||
"query_length": len(payload.query),
|
||||
},
|
||||
)
|
||||
except ValueError:
|
||||
raise UnsupportedSearchTypeError(
|
||||
search_type=str(payload.search_type), supported_types=[t.value for t in SearchType]
|
||||
)
|
||||
|
||||
# Permission denied errors will be caught and handled by the enhanced exception handler
|
||||
# Other exceptions will also be properly formatted by the global handler
|
||||
results = await cognee_search(
|
||||
query_text=payload.query,
|
||||
query_type=payload.search_type,
|
||||
user=user,
|
||||
datasets=payload.datasets,
|
||||
dataset_ids=payload.dataset_ids,
|
||||
top_k=payload.top_k,
|
||||
)
|
||||
|
||||
# If no results found, that's not necessarily an error, just return empty list
|
||||
if not results:
|
||||
return []
|
||||
except Exception as error:
|
||||
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||
|
||||
return results
|
||||
|
||||
return router
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
"""
|
||||
Custom exceptions for the Cognee API.
|
||||
|
||||
This module defines a set of exceptions for handling various application errors,
|
||||
such as service failures, resource conflicts, and invalid operations.
|
||||
This module defines a comprehensive set of exceptions for handling various application errors,
|
||||
with enhanced error context, user-friendly messages, and actionable suggestions.
|
||||
"""
|
||||
|
||||
# Import original exceptions for backward compatibility
|
||||
from .exceptions import (
|
||||
CogneeApiError,
|
||||
ServiceError,
|
||||
|
|
@ -12,3 +13,83 @@ from .exceptions import (
|
|||
InvalidAttributeError,
|
||||
CriticalError,
|
||||
)
|
||||
|
||||
# Import enhanced exception hierarchy
|
||||
from .enhanced_exceptions import (
|
||||
CogneeBaseError,
|
||||
CogneeUserError,
|
||||
CogneeSystemError,
|
||||
CogneeTransientError,
|
||||
CogneeConfigurationError,
|
||||
CogneeValidationError,
|
||||
CogneeAuthenticationError,
|
||||
CogneePermissionError,
|
||||
CogneeNotFoundError,
|
||||
CogneeRateLimitError,
|
||||
)
|
||||
|
||||
# Import domain-specific exceptions
|
||||
from .domain_exceptions import (
|
||||
# Data/Input Errors
|
||||
UnsupportedFileFormatError,
|
||||
EmptyDatasetError,
|
||||
DatasetNotFoundError,
|
||||
InvalidQueryError,
|
||||
FileAccessError,
|
||||
# Processing Errors
|
||||
LLMConnectionError,
|
||||
LLMRateLimitError,
|
||||
ProcessingTimeoutError,
|
||||
DatabaseConnectionError,
|
||||
InsufficientResourcesError,
|
||||
# Configuration Errors
|
||||
MissingAPIKeyError,
|
||||
InvalidDatabaseConfigError,
|
||||
UnsupportedSearchTypeError,
|
||||
# Pipeline Errors
|
||||
PipelineExecutionError,
|
||||
DataExtractionError,
|
||||
NoDataToProcessError,
|
||||
)
|
||||
|
||||
# For backward compatibility, create aliases
|
||||
# These will allow existing code to continue working while we migrate
|
||||
DatasetNotFoundError_Legacy = InvalidValueError # For existing dataset not found errors
|
||||
PermissionDeniedError_Legacy = CogneeApiError # For existing permission errors
|
||||
|
||||
__all__ = [
|
||||
# Original exceptions (backward compatibility)
|
||||
"CogneeApiError",
|
||||
"ServiceError",
|
||||
"InvalidValueError",
|
||||
"InvalidAttributeError",
|
||||
"CriticalError",
|
||||
# Enhanced base exceptions
|
||||
"CogneeBaseError",
|
||||
"CogneeUserError",
|
||||
"CogneeSystemError",
|
||||
"CogneeTransientError",
|
||||
"CogneeConfigurationError",
|
||||
"CogneeValidationError",
|
||||
"CogneeAuthenticationError",
|
||||
"CogneePermissionError",
|
||||
"CogneeNotFoundError",
|
||||
"CogneeRateLimitError",
|
||||
# Domain-specific exceptions
|
||||
"UnsupportedFileFormatError",
|
||||
"EmptyDatasetError",
|
||||
"DatasetNotFoundError",
|
||||
"InvalidQueryError",
|
||||
"FileAccessError",
|
||||
"LLMConnectionError",
|
||||
"LLMRateLimitError",
|
||||
"ProcessingTimeoutError",
|
||||
"DatabaseConnectionError",
|
||||
"InsufficientResourcesError",
|
||||
"MissingAPIKeyError",
|
||||
"InvalidDatabaseConfigError",
|
||||
"UnsupportedSearchTypeError",
|
||||
"PipelineExecutionError",
|
||||
"DataExtractionError",
|
||||
"NoDataToProcessError",
|
||||
]
|
||||
|
|
|
|||
337
cognee/exceptions/domain_exceptions.py
Normal file
337
cognee/exceptions/domain_exceptions.py
Normal file
|
|
@ -0,0 +1,337 @@
|
|||
from typing import List, Optional, Dict, Any
|
||||
from .enhanced_exceptions import (
|
||||
CogneeUserError,
|
||||
CogneeSystemError,
|
||||
CogneeTransientError,
|
||||
CogneeConfigurationError,
|
||||
CogneeValidationError,
|
||||
CogneeNotFoundError,
|
||||
CogneePermissionError,
|
||||
)
|
||||
|
||||
|
||||
# ========== DATA/INPUT ERRORS (User-fixable) ==========
|
||||
|
||||
|
||||
class UnsupportedFileFormatError(CogneeValidationError):
    """Raised during add() when a file's format is not one Cognee can ingest."""

    def __init__(self, file_path: str, supported_formats: List[str], **kwargs):
        format_list = ", ".join(supported_formats)
        guidance = [
            f"Use one of these supported formats: {format_list}",
            "Convert your file to a supported format",
            "Check our documentation for the complete list of supported formats",
        ]
        details = {"file_path": file_path, "supported_formats": supported_formats}
        super().__init__(
            message=f"File format not supported: {file_path}",
            user_message=f"The file '{file_path}' has an unsupported format.",
            suggestions=guidance,
            docs_link="https://docs.cognee.ai/guides/file-formats",
            context=details,
            operation="add",
            **kwargs,
        )
|
||||
|
||||
|
||||
class EmptyDatasetError(CogneeValidationError):
    """Raised during cognify when a dataset holds nothing that can be processed."""

    def __init__(self, dataset_name: str, **kwargs):
        guidance = [
            "Add some data to the dataset first using cognee.add()",
            "Check if your files contain readable text content",
            "Verify that your data was uploaded successfully",
        ]
        super().__init__(
            message=f"Dataset '{dataset_name}' is empty",
            user_message=f"The dataset '{dataset_name}' contains no data to process.",
            suggestions=guidance,
            docs_link="https://docs.cognee.ai/guides/adding-data",
            context={"dataset_name": dataset_name},
            operation="cognify",
            **kwargs,
        )
|
||||
|
||||
|
||||
class DatasetNotFoundError(CogneeNotFoundError):
    """Raised when a dataset lookup by name or id comes back empty."""

    def __init__(
        self, dataset_identifier: str, available_datasets: Optional[List[str]] = None, **kwargs
    ):
        # Tailor the guidance to whether we know which datasets do exist.
        hints = ["Check the dataset name for typos"]
        if not available_datasets:
            hints.append("Create the dataset first by adding data to it")
        else:
            hints.append(f"Available datasets: {', '.join(available_datasets)}")
            hints.append("Use cognee.datasets() to see all your datasets")

        details = {
            "dataset_identifier": dataset_identifier,
            "available_datasets": available_datasets,
        }
        super().__init__(
            message=f"Dataset not found: {dataset_identifier}",
            user_message=f"Could not find dataset '{dataset_identifier}'.",
            suggestions=hints,
            docs_link="https://docs.cognee.ai/guides/datasets",
            context=details,
            **kwargs,
        )
|
||||
|
||||
|
||||
class InvalidQueryError(CogneeValidationError):
    """Raised during search when the query text cannot be used as given."""

    def __init__(self, query: str, reason: str, **kwargs):
        guidance = [
            "Try rephrasing your query",
            "Use simpler, more specific terms",
            "Check our query examples in the documentation",
        ]
        super().__init__(
            message=f"Invalid query: {reason}",
            user_message=f"Your search query '{query}' is invalid: {reason}",
            suggestions=guidance,
            docs_link="https://docs.cognee.ai/guides/search",
            context={"query": query, "reason": reason},
            operation="search",
            **kwargs,
        )
|
||||
|
||||
|
||||
class FileAccessError(CogneeUserError):
    """Raised during add() when a file cannot be opened or read."""

    def __init__(self, file_path: str, reason: str, **kwargs):
        guidance = [
            "Check if the file exists at the specified path",
            "Verify you have read permissions for the file",
            "Ensure the file is not locked by another application",
        ]
        super().__init__(
            message=f"Cannot access file: {file_path} - {reason}",
            user_message=f"Unable to read the file '{file_path}': {reason}",
            suggestions=guidance,
            context={"file_path": file_path, "reason": reason},
            operation="add",
            **kwargs,
        )
|
||||
|
||||
|
||||
# ========== PROCESSING ERRORS (System/LLM errors) ==========
|
||||
|
||||
|
||||
class LLMConnectionError(CogneeTransientError):
    """Raised when the configured LLM provider cannot be reached (retryable)."""

    def __init__(self, provider: str, model: str, reason: str, **kwargs):
        guidance = [
            "Check your internet connection",
            "Verify your API key is correct and has sufficient credits",
            "Try again in a few moments",
            "Check the service status page",
        ]
        details = {"provider": provider, "model": model, "reason": reason}
        super().__init__(
            message=f"LLM connection failed: {provider}/{model} - {reason}",
            user_message=f"Cannot connect to the {provider} language model service.",
            suggestions=guidance,
            docs_link="https://docs.cognee.ai/troubleshooting/llm-connection",
            context=details,
            **kwargs,
        )
|
||||
|
||||
|
||||
class LLMRateLimitError(CogneeTransientError):
    """Raised when the LLM provider reports a rate limit (retryable)."""

    def __init__(self, provider: str, retry_after: Optional[int] = None, **kwargs):
        hints = [
            "Wait a moment before retrying",
            "Consider upgrading your API plan",
            "Use smaller batch sizes to reduce token usage",
        ]
        # When the provider told us how long to back off, surface that first.
        if retry_after:
            hints = [f"Wait {retry_after} seconds before retrying"] + hints

        super().__init__(
            message=f"Rate limit exceeded for {provider}",
            user_message=f"You've exceeded the rate limit for {provider}.",
            suggestions=hints,
            context={"provider": provider, "retry_after": retry_after},
            **kwargs,
        )
|
||||
|
||||
|
||||
class ProcessingTimeoutError(CogneeTransientError):
    """Raised when an operation exceeds its time budget and is cancelled."""

    def __init__(self, operation: str, timeout_seconds: int, **kwargs):
        guidance = [
            "Try processing smaller amounts of data at a time",
            "Check your internet connection stability",
            "Retry the operation",
            "Use background processing for large datasets",
        ]
        super().__init__(
            message=f"Operation '{operation}' timed out after {timeout_seconds}s",
            user_message=f"The {operation} operation took too long and was cancelled.",
            suggestions=guidance,
            context={"operation": operation, "timeout_seconds": timeout_seconds},
            **kwargs,
        )
|
||||
|
||||
|
||||
class DatabaseConnectionError(CogneeSystemError):
    """Raised when a backing database cannot be reached."""

    def __init__(self, db_type: str, reason: str, **kwargs):
        guidance = [
            "Check if the database service is running",
            "Verify database connection configuration",
            "Check network connectivity",
            "Contact support if the issue persists",
        ]
        super().__init__(
            message=f"{db_type} database connection failed: {reason}",
            user_message=f"Cannot connect to the {db_type} database.",
            suggestions=guidance,
            docs_link="https://docs.cognee.ai/troubleshooting/database",
            context={"db_type": db_type, "reason": reason},
            **kwargs,
        )
|
||||
|
||||
|
||||
class InsufficientResourcesError(CogneeSystemError):
    """Raised when the host lacks the resources an operation needs."""

    def __init__(self, resource_type: str, required: str, available: str, **kwargs):
        guidance = [
            "Try processing smaller amounts of data",
            "Free up system resources",
            "Wait for other operations to complete",
            "Consider upgrading your system resources",
        ]
        details = {"resource_type": resource_type, "required": required, "available": available}
        super().__init__(
            message=f"Insufficient {resource_type}: need {required}, have {available}",
            user_message=f"Not enough {resource_type} available to complete this operation.",
            suggestions=guidance,
            context=details,
            **kwargs,
        )
|
||||
|
||||
|
||||
# ========== CONFIGURATION ERRORS ==========
|
||||
|
||||
|
||||
class MissingAPIKeyError(CogneeConfigurationError):
    """Raised when a provider API key has not been configured."""

    def __init__(self, service: str, env_var: str, **kwargs):
        guidance = [
            f"Set the {env_var} environment variable",
            f"Add your {service} API key to your .env file",
            "Check the setup documentation for detailed instructions",
        ]
        super().__init__(
            message=f"Missing API key for {service}",
            user_message=f"API key for {service} is not configured.",
            suggestions=guidance,
            docs_link="https://docs.cognee.ai/setup/api-keys",
            context={"service": service, "env_var": env_var},
            **kwargs,
        )
|
||||
|
||||
|
||||
class InvalidDatabaseConfigError(CogneeConfigurationError):
    """Raised when database settings are present but unusable."""

    def __init__(self, db_type: str, config_issue: str, **kwargs):
        guidance = [
            "Check your database configuration settings",
            "Verify connection strings and credentials",
            "Review the database setup documentation",
            "Ensure the database server is accessible",
        ]
        super().__init__(
            message=f"Invalid {db_type} database configuration: {config_issue}",
            user_message=f"The {db_type} database is not properly configured: {config_issue}",
            suggestions=guidance,
            docs_link="https://docs.cognee.ai/setup/databases",
            context={"db_type": db_type, "config_issue": config_issue},
            **kwargs,
        )
|
||||
|
||||
|
||||
class UnsupportedSearchTypeError(CogneeValidationError):
    """Raised during search when the requested search type is unknown."""

    def __init__(self, search_type: str, supported_types: List[str], **kwargs):
        type_list = ", ".join(supported_types)
        guidance = [
            f"Use one of these supported search types: {type_list}",
            "Check the search documentation for available types",
            "Try using GRAPH_COMPLETION for general queries",
        ]
        details = {"search_type": search_type, "supported_types": supported_types}
        super().__init__(
            message=f"Unsupported search type: {search_type}",
            user_message=f"The search type '{search_type}' is not supported.",
            suggestions=guidance,
            docs_link="https://docs.cognee.ai/guides/search-types",
            context=details,
            operation="search",
            **kwargs,
        )
|
||||
|
||||
|
||||
# ========== PIPELINE ERRORS ==========
|
||||
|
||||
|
||||
class PipelineExecutionError(CogneeSystemError):
    """Raised when a pipeline task fails and the run cannot continue."""

    def __init__(self, pipeline_name: str, task_name: str, error_details: str, **kwargs):
        guidance = [
            "Check the logs for more detailed error information",
            "Verify your data is in a supported format",
            "Try processing smaller amounts of data",
            "Contact support if the issue persists",
        ]
        details = {
            "pipeline_name": pipeline_name,
            "task_name": task_name,
            "error_details": error_details,
        }
        super().__init__(
            message=f"Pipeline '{pipeline_name}' failed at task '{task_name}': {error_details}",
            user_message=f"Processing failed during the {task_name} step.",
            suggestions=guidance,
            context=details,
            **kwargs,
        )
|
||||
|
||||
|
||||
class DataExtractionError(CogneeSystemError):
    """Raised during add() when no readable content can be pulled from a source."""

    def __init__(self, source: str, reason: str, **kwargs):
        guidance = [
            "Verify the file is not corrupted",
            "Try converting to a different format",
            "Check if the file contains readable text",
            "Use a supported file format",
        ]
        super().__init__(
            message=f"Data extraction failed for {source}: {reason}",
            user_message=f"Could not extract readable content from '{source}'.",
            suggestions=guidance,
            context={"source": source, "reason": reason},
            operation="add",
            **kwargs,
        )
|
||||
|
||||
|
||||
class NoDataToProcessError(CogneeValidationError):
    """Raised when an operation is started but there is nothing to work on."""

    def __init__(self, operation: str, **kwargs):
        guidance = [
            "Add some data first using cognee.add()",
            "Check if your previous data upload was successful",
            "Verify the dataset contains processable content",
        ]
        super().__init__(
            message=f"No data available for {operation}",
            user_message=f"There's no data to process for the {operation} operation.",
            suggestions=guidance,
            docs_link="https://docs.cognee.ai/guides/adding-data",
            context={"operation": operation},
            **kwargs,
        )
|
||||
184
cognee/exceptions/enhanced_exceptions.py
Normal file
184
cognee/exceptions/enhanced_exceptions.py
Normal file
|
|
@ -0,0 +1,184 @@
|
|||
from typing import Dict, List, Optional, Any
|
||||
from fastapi import status
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
# Module-level logger used by CogneeBaseError.__init__ for automatic logging.
logger = get_logger()
|
||||
|
||||
|
||||
class CogneeBaseError(Exception):
    """
    Root of the Cognee exception hierarchy, carrying rich error context.

    Alongside the technical message, every error built on this class bundles
    the pieces needed for a good user experience:
    - a user-facing message
    - actionable suggestions
    - an optional documentation link
    - retry information and an HTTP status code

    Each exception logs itself on construction at the severity given by
    ``log_level`` ("ERROR", "WARNING" or "INFO"; anything else is not logged).
    """

    def __init__(
        self,
        message: str,
        user_message: Optional[str] = None,
        suggestions: Optional[List[str]] = None,
        docs_link: Optional[str] = None,
        context: Optional[Dict[str, Any]] = None,
        is_retryable: bool = False,
        status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR,
        log_level: str = "ERROR",
        operation: Optional[str] = None,
    ):
        self.message = message
        self.user_message = user_message if user_message else message
        self.suggestions = suggestions if suggestions else []
        self.docs_link = docs_link
        self.context = context if context else {}
        self.is_retryable = is_retryable
        self.status_code = status_code
        self.operation = operation

        # Automatically log the exception at the requested severity.
        where = operation or "unknown"
        dispatch = {
            "ERROR": (logger.error, "CogneeError"),
            "WARNING": (logger.warning, "CogneeWarning"),
            "INFO": (logger.info, "CogneeInfo"),
        }
        if log_level in dispatch:
            log_fn, label = dispatch[log_level]
            log_fn(f"{label} in {where}: {message}", extra=self.context)

        super().__init__(self.message)

    def __str__(self):
        return "{}: {}".format(type(self).__name__, self.message)

    def to_dict(self) -> Dict[str, Any]:
        """Convert exception to dictionary for API responses"""
        return dict(
            type=self.__class__.__name__,
            message=self.user_message,
            technical_message=self.message,
            suggestions=self.suggestions,
            docs_link=self.docs_link,
            is_retryable=self.is_retryable,
            context=self.context,
            operation=self.operation,
        )
|
||||
|
||||
|
||||
class CogneeUserError(CogneeBaseError):
    """
    User-fixable errors (4xx status codes).

    Raised for problems the caller can correct themselves, such as an
    invalid file format or a missing required field.
    """

    def __init__(self, *args, **kwargs):
        # User mistakes are expected traffic: report 400 and log as a warning.
        defaults = {
            "status_code": status.HTTP_400_BAD_REQUEST,
            "log_level": "WARNING",
        }
        for key, value in defaults.items():
            kwargs.setdefault(key, value)
        super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class CogneeSystemError(CogneeBaseError):
    """
    System/infrastructure errors (5xx status codes).

    Raised for failures that require technical intervention rather than a
    change by the caller, e.g. a database outage or an unavailable service.
    """

    def __init__(self, *args, **kwargs):
        defaults = {
            "status_code": status.HTTP_500_INTERNAL_SERVER_ERROR,
            "log_level": "ERROR",
        }
        for key, value in defaults.items():
            kwargs.setdefault(key, value)
        super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class CogneeTransientError(CogneeBaseError):
    """
    Temporary, retryable errors.

    Raised for failures that might succeed on retry, typically caused by
    short-lived resource constraints or network hiccups. Marks the error
    as retryable and maps to HTTP 503.
    """

    def __init__(self, *args, **kwargs):
        defaults = {
            "status_code": status.HTTP_503_SERVICE_UNAVAILABLE,
            "is_retryable": True,
            "log_level": "WARNING",
        }
        for key, value in defaults.items():
            kwargs.setdefault(key, value)
        super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class CogneeConfigurationError(CogneeBaseError):
    """
    Setup/configuration errors.

    Raised when missing or invalid configuration prevents the system from
    operating correctly (e.g. absent API keys, bad connection strings).
    """

    def __init__(self, *args, **kwargs):
        defaults = {
            "status_code": status.HTTP_422_UNPROCESSABLE_ENTITY,
            "log_level": "ERROR",
        }
        for key, value in defaults.items():
            kwargs.setdefault(key, value)
        super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class CogneeValidationError(CogneeUserError):
    """
    Input validation errors.

    A specific kind of user error for invalid input data; overrides the
    generic 400 from CogneeUserError with HTTP 422.
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("status_code", status.HTTP_422_UNPROCESSABLE_ENTITY)
        super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class CogneeAuthenticationError(CogneeUserError):
    """Authentication and authorization errors (HTTP 401)."""

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("status_code", status.HTTP_401_UNAUTHORIZED)
        super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class CogneePermissionError(CogneeUserError):
    """Permission denied errors (HTTP 403)."""

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("status_code", status.HTTP_403_FORBIDDEN)
        super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class CogneeNotFoundError(CogneeUserError):
    """Resource not found errors (HTTP 404)."""

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("status_code", status.HTTP_404_NOT_FOUND)
        super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class CogneeRateLimitError(CogneeTransientError):
    """Rate limiting errors (HTTP 429, retryable)."""

    def __init__(self, *args, **kwargs):
        defaults = {
            "status_code": status.HTTP_429_TOO_MANY_REQUESTS,
            "suggestions": [
                "Wait a moment before retrying",
                "Check your API rate limits",
                "Consider using smaller batch sizes",
            ],
        }
        for key, value in defaults.items():
            kwargs.setdefault(key, value)
        super().__init__(*args, **kwargs)
|
||||
237
cognee/infrastructure/loaders/LoaderEngine.py
Normal file
237
cognee/infrastructure/loaders/LoaderEngine.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
import importlib.util
import os
from typing import Any, Dict, List, Optional

from cognee.shared.logging_utils import get_logger

from .LoaderInterface import LoaderInterface
from .models.LoaderResult import LoaderResult
|
||||
|
||||
|
||||
class LoaderEngine:
    """
    Main loader engine for managing file loaders.

    Follows cognee's adapter pattern similar to database engines,
    providing a centralized system for file loading operations.
    """

    def __init__(
        self,
        loader_directories: List[str],
        default_loader_priority: List[str],
        fallback_loader: str = "text_loader",
        enable_dependency_validation: bool = True,
    ):
        """
        Initialize the loader engine.

        Args:
            loader_directories: Directories to search for loader implementations
            default_loader_priority: Priority order for loader selection
            fallback_loader: Default loader to use when no other matches
            enable_dependency_validation: Whether to validate loader dependencies
        """
        self._loaders: Dict[str, LoaderInterface] = {}
        self._extension_map: Dict[str, List[LoaderInterface]] = {}
        self._mime_type_map: Dict[str, List[LoaderInterface]] = {}
        self.loader_directories = loader_directories
        self.default_loader_priority = default_loader_priority
        self.fallback_loader = fallback_loader
        self.enable_dependency_validation = enable_dependency_validation
        self.logger = get_logger(__name__)

    def register_loader(self, loader: LoaderInterface) -> bool:
        """
        Register a loader with the engine.

        Args:
            loader: LoaderInterface implementation to register

        Returns:
            True if loader was registered successfully, False otherwise
        """
        # Refuse registration when required packages are missing, so lookup
        # never hands out a loader that would fail at load time.
        if self.enable_dependency_validation and not loader.validate_dependencies():
            self.logger.warning(
                f"Skipping loader '{loader.loader_name}' - missing dependencies: "
                f"{loader.get_dependencies()}"
            )
            return False

        self._loaders[loader.loader_name] = loader

        # Index the loader by every extension (lowercased for
        # case-insensitive lookup) and MIME type it supports.
        for ext in loader.supported_extensions:
            self._extension_map.setdefault(ext.lower(), []).append(loader)

        for mime_type in loader.supported_mime_types:
            self._mime_type_map.setdefault(mime_type, []).append(loader)

        self.logger.info(f"Registered loader: {loader.loader_name}")
        return True

    def get_loader(
        self,
        file_path: str,
        mime_type: Optional[str] = None,
        preferred_loaders: Optional[List[str]] = None,
    ) -> Optional[LoaderInterface]:
        """
        Get appropriate loader for a file.

        Selection order: caller-preferred loaders, then the configured
        priority list, then MIME-type matches, then extension matches,
        and finally the configured fallback loader.

        Args:
            file_path: Path to the file to be processed
            mime_type: Optional MIME type of the file
            preferred_loaders: List of preferred loader names to try first

        Returns:
            LoaderInterface that can handle the file, or None if not found
        """
        ext = os.path.splitext(file_path)[1].lower()

        # Try preferred loaders first
        if preferred_loaders:
            for loader_name in preferred_loaders:
                loader = self._loaders.get(loader_name)
                if loader is not None and loader.can_handle(file_path, mime_type):
                    return loader

        # Try priority order
        for loader_name in self.default_loader_priority:
            loader = self._loaders.get(loader_name)
            if loader is not None and loader.can_handle(file_path, mime_type):
                return loader

        # Try mime type mapping
        if mime_type and mime_type in self._mime_type_map:
            for loader in self._mime_type_map[mime_type]:
                if loader.can_handle(file_path, mime_type):
                    return loader

        # Try extension mapping
        if ext in self._extension_map:
            for loader in self._extension_map[ext]:
                if loader.can_handle(file_path, mime_type):
                    return loader

        # Fallback loader
        fallback = self._loaders.get(self.fallback_loader)
        if fallback is not None and fallback.can_handle(file_path, mime_type):
            return fallback

        return None

    async def load_file(
        self,
        file_path: str,
        mime_type: Optional[str] = None,
        preferred_loaders: Optional[List[str]] = None,
        **kwargs,
    ) -> LoaderResult:
        """
        Load file using appropriate loader.

        Args:
            file_path: Path to the file to be processed
            mime_type: Optional MIME type of the file
            preferred_loaders: List of preferred loader names to try first
            **kwargs: Additional loader-specific configuration

        Returns:
            LoaderResult containing processed content and metadata

        Raises:
            ValueError: If no suitable loader is found
            Exception: If file processing fails
        """
        loader = self.get_loader(file_path, mime_type, preferred_loaders)
        if not loader:
            raise ValueError(f"No loader found for file: {file_path}")

        self.logger.debug(f"Loading {file_path} with {loader.loader_name}")
        return await loader.load(file_path, **kwargs)

    def discover_loaders(self):
        """
        Auto-discover loaders from configured directories.

        Scans loader directories for Python modules containing
        LoaderInterface implementations and registers them.
        """
        for directory in self.loader_directories:
            if os.path.exists(directory):
                self._discover_in_directory(directory)

    def _discover_in_directory(self, directory: str):
        """
        Discover loaders in a specific directory.

        Each failure (unreadable directory, broken module, loader that
        cannot be instantiated) is logged and skipped so one bad plugin
        cannot block discovery of the others.

        Args:
            directory: Directory path to scan for loader implementations
        """
        try:
            for file_name in os.listdir(directory):
                if file_name.endswith(".py") and not file_name.startswith("_"):
                    module_name = file_name[:-3]
                    file_path = os.path.join(directory, file_name)

                    try:
                        spec = importlib.util.spec_from_file_location(module_name, file_path)
                        if spec and spec.loader:
                            module = importlib.util.module_from_spec(spec)
                            spec.loader.exec_module(module)

                            # Look for concrete LoaderInterface subclasses
                            for attr_name in dir(module):
                                attr = getattr(module, attr_name)
                                if (
                                    isinstance(attr, type)
                                    and issubclass(attr, LoaderInterface)
                                    and attr != LoaderInterface
                                ):
                                    # Instantiate and register the loader
                                    try:
                                        loader_instance = attr()
                                        self.register_loader(loader_instance)
                                    except Exception as e:
                                        self.logger.warning(
                                            f"Failed to instantiate loader {attr_name}: {e}"
                                        )

                    except Exception as e:
                        self.logger.warning(f"Failed to load module {module_name}: {e}")

        except OSError as e:
            self.logger.warning(f"Failed to scan directory {directory}: {e}")

    def get_available_loaders(self) -> List[str]:
        """
        Get list of available loader names.

        Returns:
            List of registered loader names
        """
        return list(self._loaders.keys())

    def get_loader_info(self, loader_name: str) -> Dict[str, Any]:
        """
        Get information about a specific loader.

        Args:
            loader_name: Name of the loader to inspect

        Returns:
            Dictionary containing loader information, or an empty dict
            when no loader with that name is registered
        """
        if loader_name not in self._loaders:
            return {}

        loader = self._loaders[loader_name]
        return {
            "name": loader.loader_name,
            "extensions": loader.supported_extensions,
            "mime_types": loader.supported_mime_types,
            "dependencies": loader.get_dependencies(),
            "available": loader.validate_dependencies(),
        }
|
||||
101
cognee/infrastructure/loaders/LoaderInterface.py
Normal file
101
cognee/infrastructure/loaders/LoaderInterface.py
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from .models.LoaderResult import LoaderResult
|
||||
|
||||
|
||||
class LoaderInterface(ABC):
    """
    Base interface for all file loaders in cognee.

    This interface follows cognee's established pattern for database adapters,
    ensuring consistent behavior across all loader implementations.
    """

    @property
    @abstractmethod
    def supported_extensions(self) -> List[str]:
        """
        List of file extensions this loader supports.

        Returns:
            List of extensions including the dot (e.g., ['.txt', '.md'])
        """
        pass

    @property
    @abstractmethod
    def supported_mime_types(self) -> List[str]:
        """
        List of MIME types this loader supports.

        Returns:
            List of MIME type strings (e.g., ['text/plain', 'application/pdf'])
        """
        pass

    @property
    @abstractmethod
    def loader_name(self) -> str:
        """
        Unique name identifier for this loader.

        Returns:
            String identifier used for registration and configuration
        """
        pass

    @abstractmethod
    def can_handle(self, file_path: str, mime_type: str = None) -> bool:
        """
        Check if this loader can handle the given file.

        Args:
            file_path: Path to the file to be processed
            mime_type: Optional MIME type of the file

        Returns:
            True if this loader can process the file, False otherwise
        """
        pass

    @abstractmethod
    async def load(self, file_path: str, **kwargs) -> LoaderResult:
        """
        Load and process the file, returning standardized result.

        Args:
            file_path: Path to the file to be processed
            **kwargs: Additional loader-specific configuration

        Returns:
            LoaderResult containing processed content and metadata

        Raises:
            Exception: If file cannot be processed
        """
        pass

    def get_dependencies(self) -> List[str]:
        """
        Optional: Return list of required dependencies for this loader.

        Returns:
            List of package names with optional version specifications
        """
        return []

    def validate_dependencies(self) -> bool:
        """
        Check if all required dependencies are available.

        Returns:
            True if all dependencies are installed, False otherwise
        """
        import re

        for dep in self.get_dependencies():
            # Dependencies may be declared as PEP 508-style requirements
            # (e.g. "pypdf>=3.0", "unstructured[md]", "pkg; marker").  Strip
            # the version specifier, extras and markers to get the bare
            # importable name.  The previous split only handled ">=", "=="
            # and "<", so ">", "~=", "!=", extras or markers leaked into the
            # import probe and made validation fail spuriously.
            package_name = re.split(r"[\s\[<>=!~;]", dep, maxsplit=1)[0]
            try:
                # NOTE(review): assumes the distribution name matches the
                # import name (not true for e.g. "pyyaml" -> "yaml").
                __import__(package_name)
            except ImportError:
                return False
        return True
|
||||
19
cognee/infrastructure/loaders/__init__.py
Normal file
19
cognee/infrastructure/loaders/__init__.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
"""
|
||||
File loader infrastructure for cognee.
|
||||
|
||||
This package provides a plugin-based system for loading different file formats
|
||||
into cognee, following the same patterns as database adapters.
|
||||
|
||||
Main exports:
|
||||
- get_loader_engine(): Factory function to get configured loader engine
|
||||
- use_loader(): Register custom loaders at runtime
|
||||
- LoaderInterface: Base interface for implementing loaders
|
||||
- LoaderResult, ContentType: Data models for loader results
|
||||
"""
|
||||
|
||||
from .get_loader_engine import get_loader_engine
|
||||
from .use_loader import use_loader
|
||||
from .LoaderInterface import LoaderInterface
|
||||
from .models.LoaderResult import LoaderResult, ContentType
|
||||
|
||||
# Public surface of the loaders package; keep in sync with the imports above.
__all__ = ["get_loader_engine", "use_loader", "LoaderInterface", "LoaderResult", "ContentType"]
|
||||
57
cognee/infrastructure/loaders/config.py
Normal file
57
cognee/infrastructure/loaders/config.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
from functools import lru_cache
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from cognee.root_dir import get_absolute_path
|
||||
|
||||
|
||||
class LoaderConfig(BaseSettings):
    """
    Configuration for the file loader system.

    Built on pydantic_settings.BaseSettings, following cognee's configuration
    pattern, so every option can be overridden through LOADER_-prefixed
    environment variables or a .env file.
    """

    loader_directories: List[str] = [
        get_absolute_path("infrastructure/loaders/core"),
        get_absolute_path("infrastructure/loaders/external"),
    ]
    default_loader_priority: List[str] = [
        "text_loader",
        "pypdf_loader",
        "unstructured_loader",
        "dlt_loader",
    ]
    auto_discover: bool = True
    fallback_loader: str = "text_loader"
    enable_dependency_validation: bool = True

    model_config = SettingsConfigDict(env_file=".env", extra="allow", env_prefix="LOADER_")

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert configuration to dictionary format.

        Returns:
            Dict containing all loader configuration settings
        """
        exported = (
            "loader_directories",
            "default_loader_priority",
            "auto_discover",
            "fallback_loader",
            "enable_dependency_validation",
        )
        return {name: getattr(self, name) for name in exported}
|
||||
|
||||
|
||||
@lru_cache
def get_loader_config() -> LoaderConfig:
    """
    Return the process-wide loader configuration.

    Memoized with ``lru_cache`` so repeated calls share one
    ``LoaderConfig`` instance, mirroring cognee's other configuration
    factory functions.

    Returns:
        The cached LoaderConfig built from current environment settings.
    """
    return LoaderConfig()
|
||||
5
cognee/infrastructure/loaders/core/__init__.py
Normal file
5
cognee/infrastructure/loaders/core/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
"""Core loader implementations that are always available."""
|
||||
|
||||
from .text_loader import TextLoader
|
||||
|
||||
__all__ = ["TextLoader"]
|
||||
128
cognee/infrastructure/loaders/core/text_loader.py
Normal file
128
cognee/infrastructure/loaders/core/text_loader.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
import os
|
||||
from typing import List
|
||||
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
|
||||
from cognee.infrastructure.loaders.models.LoaderResult import LoaderResult, ContentType
|
||||
|
||||
|
||||
class TextLoader(LoaderInterface):
    """
    Core text file loader that handles basic text file formats.

    This loader is always available and serves as the fallback for
    text-based files when no specialized loader is available.
    """

    @property
    def supported_extensions(self) -> List[str]:
        """Supported text file extensions."""
        return [".txt", ".md", ".csv", ".json", ".xml", ".yaml", ".yml", ".log"]

    @property
    def supported_mime_types(self) -> List[str]:
        """Supported MIME types for text content."""
        return [
            "text/plain",
            "text/markdown",
            "text/csv",
            "application/json",
            "text/xml",
            "application/xml",
            "text/yaml",
            "application/yaml",
        ]

    @property
    def loader_name(self) -> str:
        """Unique identifier for this loader."""
        return "text_loader"

    def can_handle(self, file_path: str, mime_type: str = None) -> bool:
        """
        Check if this loader can handle the given file.

        Args:
            file_path: Path to the file
            mime_type: Optional MIME type

        Returns:
            True if file can be handled, False otherwise
        """
        # Check by extension
        ext = os.path.splitext(file_path)[1].lower()
        if ext in self.supported_extensions:
            return True

        # Check by MIME type
        if mime_type and mime_type in self.supported_mime_types:
            return True

        # As fallback loader, can attempt to handle any text-like file.
        # BUG FIX: the previous implementation fell back to decoding the
        # sample as latin-1, but latin-1 maps *every* byte sequence, so
        # the decode could never fail and any readable non-empty file
        # (including pure binary such as images) was accepted.  Screen
        # for NUL bytes first - they essentially never occur in text -
        # so obvious binaries are rejected before the permissive decode.
        try:
            with open(file_path, "rb") as f:
                sample = f.read(512)
            if sample:
                if b"\x00" in sample:
                    # NUL bytes: almost certainly binary content.
                    return False
                try:
                    sample.decode("utf-8")
                    return True
                except UnicodeDecodeError:
                    # Not valid UTF-8, but NUL-free; treat as legacy
                    # single-byte text (latin-1 decodes any byte).
                    return True
        except (OSError, IOError):
            pass

        return False

    async def load(self, file_path: str, encoding: str = "utf-8", **kwargs) -> LoaderResult:
        """
        Load and process the text file.

        Args:
            file_path: Path to the file to load
            encoding: Text encoding to use (default: utf-8)
            **kwargs: Additional configuration (unused)

        Returns:
            LoaderResult containing the file content and metadata

        Raises:
            FileNotFoundError: If file doesn't exist
            UnicodeDecodeError: If file cannot be decoded with specified encoding
            OSError: If file cannot be read
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        try:
            with open(file_path, "r", encoding=encoding) as f:
                content = f.read()
        except UnicodeDecodeError:
            # Try with fallback encoding; latin-1 accepts any byte
            # sequence, so this retry cannot raise UnicodeDecodeError.
            if encoding == "utf-8":
                return await self.load(file_path, encoding="latin-1", **kwargs)
            else:
                raise

        # Extract basic metadata describing the loaded file.
        file_stat = os.stat(file_path)
        metadata = {
            "name": os.path.basename(file_path),
            "size": file_stat.st_size,
            "extension": os.path.splitext(file_path)[1],
            "encoding": encoding,
            "loader": self.loader_name,
            "lines": len(content.splitlines()) if content else 0,
            "characters": len(content),
        }

        return LoaderResult(
            content=content,
            metadata=metadata,
            content_type=ContentType.TEXT,
            source_info={"file_path": file_path, "encoding": encoding},
        )
|
||||
49
cognee/infrastructure/loaders/create_loader_engine.py
Normal file
49
cognee/infrastructure/loaders/create_loader_engine.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
from typing import List
|
||||
from .LoaderEngine import LoaderEngine
|
||||
from .supported_loaders import supported_loaders
|
||||
|
||||
|
||||
def create_loader_engine(
    loader_directories: List[str],
    default_loader_priority: List[str],
    auto_discover: bool = True,
    fallback_loader: str = "text_loader",
    enable_dependency_validation: bool = True,
) -> LoaderEngine:
    """
    Build a LoaderEngine from explicit configuration values.

    Follows cognee's pattern for engine creation functions used
    in database adapters.

    Args:
        loader_directories: Directories to search for loader implementations
        default_loader_priority: Priority order for loader selection
        auto_discover: Whether to auto-discover loaders from directories
        fallback_loader: Default loader to use when no other matches
        enable_dependency_validation: Whether to validate loader dependencies

    Returns:
        Configured LoaderEngine instance
    """
    engine = LoaderEngine(
        loader_directories=loader_directories,
        default_loader_priority=default_loader_priority,
        fallback_loader=fallback_loader,
        enable_dependency_validation=enable_dependency_validation,
    )

    # Instantiate and register every loader from the global registry.
    # Registration is best-effort: a single broken loader must not
    # prevent the engine from starting with the remaining loaders.
    for loader_name, loader_class in supported_loaders.items():
        try:
            engine.register_loader(loader_class())
        except Exception as e:
            engine.logger.warning(f"Failed to register loader {loader_name}: {e}")

    # Optionally scan the configured directories for additional loaders.
    if auto_discover:
        engine.discover_loaders()

    return engine
|
||||
34
cognee/infrastructure/loaders/external/__init__.py
vendored
Normal file
34
cognee/infrastructure/loaders/external/__init__.py
vendored
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
"""
|
||||
External loader implementations for cognee.
|
||||
|
||||
This module contains loaders that depend on external libraries:
|
||||
- pypdf_loader: PDF processing using pypdf
|
||||
- unstructured_loader: Document processing using unstructured
|
||||
- dlt_loader: Data lake/warehouse integration using DLT
|
||||
|
||||
These loaders are optional and only available if their dependencies are installed.
|
||||
"""
|
||||
|
||||
__all__ = []
|
||||
|
||||
# Conditional imports based on dependency availability
|
||||
try:
|
||||
from .pypdf_loader import PyPdfLoader
|
||||
|
||||
__all__.append("PyPdfLoader")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from .unstructured_loader import UnstructuredLoader
|
||||
|
||||
__all__.append("UnstructuredLoader")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from .dlt_loader import DltLoader
|
||||
|
||||
__all__.append("DltLoader")
|
||||
except ImportError:
|
||||
pass
|
||||
203
cognee/infrastructure/loaders/external/dlt_loader.py
vendored
Normal file
203
cognee/infrastructure/loaders/external/dlt_loader.py
vendored
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
import os
|
||||
from typing import List, Dict, Any, Optional
|
||||
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
|
||||
from cognee.infrastructure.loaders.models.LoaderResult import LoaderResult, ContentType
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
|
||||
class DltLoader(LoaderInterface):
    """
    Data loader using DLT (Data Load Tool) for various data sources.

    Supports loading data from REST APIs, databases, cloud storage,
    and other data sources through DLT pipelines.
    """

    def __init__(self):
        # Per-instance logger for load/error reporting.
        self.logger = get_logger(__name__)

    @property
    def supported_extensions(self) -> List[str]:
        """File extensions this loader accepts."""
        return [
            ".dlt",  # DLT pipeline configuration
            ".json",  # JSON data
            ".jsonl",  # JSON Lines
            ".csv",  # CSV data
            ".parquet",  # Parquet files
            ".yaml",  # YAML configuration
            ".yml",  # YAML configuration
        ]

    @property
    def supported_mime_types(self) -> List[str]:
        """MIME types this loader accepts."""
        return [
            "application/json",
            "application/x-ndjson",  # JSON Lines
            "text/csv",
            "application/x-parquet",
            "application/yaml",
            "text/yaml",
        ]

    @property
    def loader_name(self) -> str:
        """Unique registry identifier for this loader."""
        return "dlt_loader"

    def get_dependencies(self) -> List[str]:
        """External packages (pip requirement strings) this loader needs."""
        return ["dlt>=0.4.0"]

    def can_handle(self, file_path: str, mime_type: str = None) -> bool:
        """Check if file can be handled by this loader."""
        # Check file extension
        file_ext = os.path.splitext(file_path)[1].lower()
        if file_ext not in self.supported_extensions:
            return False

        # Check MIME type if provided
        if mime_type and mime_type not in self.supported_mime_types:
            return False

        # Validate dependencies
        return self.validate_dependencies()

    async def load(self, file_path: str, source_type: str = "auto", **kwargs) -> LoaderResult:
        """
        Load data using DLT pipeline.

        Args:
            file_path: Path to the data file or DLT configuration
            source_type: Type of data source ("auto", "json", "csv", "parquet", "api")
            **kwargs: Additional DLT pipeline configuration

        Returns:
            LoaderResult with loaded data and metadata

        Raises:
            ImportError: If DLT is not installed
            Exception: If data loading fails
        """
        try:
            import dlt
        except ImportError as e:
            raise ImportError(
                "dlt is required for data loading. Install with: pip install dlt"
            ) from e

        try:
            self.logger.info(f"Loading data with DLT: {file_path}")

            file_ext = os.path.splitext(file_path)[1].lower()
            file_name = os.path.basename(file_path)
            file_size = os.path.getsize(file_path)

            # Resolve "auto" into a concrete source type via the extension.
            if source_type == "auto":
                if file_ext == ".json":
                    source_type = "json"
                elif file_ext == ".jsonl":
                    source_type = "jsonl"
                elif file_ext == ".csv":
                    source_type = "csv"
                elif file_ext == ".parquet":
                    source_type = "parquet"
                elif file_ext in [".yaml", ".yml"]:
                    source_type = "yaml"
                else:
                    source_type = "file"

            # Dispatch to the matching per-format helper.
            if source_type == "json":
                content = self._load_json(file_path)
            elif source_type == "jsonl":
                content = self._load_jsonl(file_path)
            elif source_type == "csv":
                content = self._load_csv(file_path)
            elif source_type == "parquet":
                content = self._load_parquet(file_path)
            elif source_type == "yaml":
                content = self._load_yaml(file_path)
            else:
                # Default: read as text
                with open(file_path, "r", encoding="utf-8") as f:
                    content = f.read()

            # JSON(L) helpers return parsed dict/list objects; everything
            # else is already a plain string.
            if isinstance(content, (dict, list)):
                content_type = ContentType.STRUCTURED
                text_content = str(content)
            else:
                content_type = ContentType.TEXT
                text_content = content

            # Gather metadata
            metadata = {
                "name": file_name,
                "size": file_size,
                "extension": file_ext,
                "loader": self.loader_name,
                "source_type": source_type,
                "dlt_version": dlt.__version__,
            }

            # Add data-specific metadata
            if isinstance(content, list):
                metadata["records_count"] = len(content)
            elif isinstance(content, dict):
                metadata["keys_count"] = len(content)

            return LoaderResult(
                content=text_content,
                metadata=metadata,
                content_type=content_type,
                chunks=[text_content],  # Single chunk for now
                source_info={
                    "file_path": file_path,
                    "source_type": source_type,
                    "raw_data": content if isinstance(content, (dict, list)) else None,
                },
            )

        except Exception as e:
            self.logger.error(f"Failed to load data with DLT from {file_path}: {e}")
            raise Exception(f"DLT data loading failed: {e}") from e

    def _load_json(self, file_path: str) -> Dict[str, Any]:
        """Load JSON file."""
        import json

        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _load_jsonl(self, file_path: str) -> List[Dict[str, Any]]:
        """Load JSON Lines file (one JSON object per non-blank line)."""
        import json

        data = []
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    data.append(json.loads(line))
        return data

    def _load_csv(self, file_path: str) -> str:
        """Load CSV file as text."""
        with open(file_path, "r", encoding="utf-8") as f:
            return f.read()

    def _load_parquet(self, file_path: str) -> str:
        """Load Parquet file (requires pandas)."""
        try:
            import pandas as pd

            df = pd.read_parquet(file_path)
            return df.to_string()
        except ImportError:
            # Fallback: describe the file without decoding it.
            # BUG FIX: previously this read the entire file into memory
            # (len(f.read())) only to report its size; query the
            # filesystem instead.
            size = os.path.getsize(file_path)
            return f"<Parquet file: {os.path.basename(file_path)}, size: {size} bytes>"

    def _load_yaml(self, file_path: str) -> str:
        """Load YAML file as text."""
        with open(file_path, "r", encoding="utf-8") as f:
            return f.read()
|
||||
127
cognee/infrastructure/loaders/external/pypdf_loader.py
vendored
Normal file
127
cognee/infrastructure/loaders/external/pypdf_loader.py
vendored
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
import os
|
||||
from typing import List
|
||||
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
|
||||
from cognee.infrastructure.loaders.models.LoaderResult import LoaderResult, ContentType
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
|
||||
class PyPdfLoader(LoaderInterface):
    """
    PDF loader using pypdf library.

    Extracts text content from PDF files page by page, providing
    structured page information and handling PDF-specific errors.
    """

    def __init__(self):
        # Per-instance logger used for progress and error reporting.
        self.logger = get_logger(__name__)

    @property
    def supported_extensions(self) -> List[str]:
        """File extensions this loader accepts."""
        return [".pdf"]

    @property
    def supported_mime_types(self) -> List[str]:
        """MIME types this loader accepts."""
        return ["application/pdf"]

    @property
    def loader_name(self) -> str:
        """Unique registry identifier for this loader."""
        return "pypdf_loader"

    def get_dependencies(self) -> List[str]:
        """External packages (pip requirement strings) this loader needs."""
        return ["pypdf>=4.0.0"]

    def can_handle(self, file_path: str, mime_type: str = None) -> bool:
        """Check if file can be handled by this loader."""
        # Check file extension
        if not file_path.lower().endswith(".pdf"):
            return False

        # Check MIME type if provided
        if mime_type and mime_type != "application/pdf":
            return False

        # Validate dependencies
        return self.validate_dependencies()

    async def load(self, file_path: str, strict: bool = False, **kwargs) -> LoaderResult:
        """
        Load PDF file and extract text content.

        Args:
            file_path: Path to the PDF file
            strict: Whether to use strict mode for PDF reading
            **kwargs: Additional arguments

        Returns:
            LoaderResult with extracted text content and metadata

        Raises:
            ImportError: If pypdf is not installed
            Exception: If PDF processing fails
        """
        # Import lazily so the module can be loaded without pypdf installed.
        try:
            from pypdf import PdfReader
        except ImportError as e:
            raise ImportError(
                "pypdf is required for PDF processing. Install with: pip install pypdf"
            ) from e

        try:
            with open(file_path, "rb") as file:
                self.logger.info(f"Reading PDF: {file_path}")
                reader = PdfReader(file, strict=strict)

                content_parts = []
                page_texts = []

                # Extract text per page; a failing page is skipped (with a
                # warning) rather than aborting the whole document.
                for page_num, page in enumerate(reader.pages, 1):
                    try:
                        page_text = page.extract_text()
                        if page_text.strip():  # Only add non-empty pages
                            page_texts.append(page_text)
                            content_parts.append(f"Page {page_num}:\n{page_text}\n")
                    except Exception as e:
                        self.logger.warning(f"Failed to extract text from page {page_num}: {e}")
                        continue

                # Combine all content
                full_content = "\n".join(content_parts)

                # Gather metadata
                metadata = {
                    "name": os.path.basename(file_path),
                    "size": os.path.getsize(file_path),
                    "extension": ".pdf",
                    "pages": len(reader.pages),
                    "pages_with_text": len(page_texts),
                    "loader": self.loader_name,
                }

                # Add PDF metadata if available (document-info keys use the
                # PDF "/Name" convention; missing keys default to "").
                if reader.metadata:
                    metadata["pdf_metadata"] = {
                        "title": reader.metadata.get("/Title", ""),
                        "author": reader.metadata.get("/Author", ""),
                        "subject": reader.metadata.get("/Subject", ""),
                        "creator": reader.metadata.get("/Creator", ""),
                        "producer": reader.metadata.get("/Producer", ""),
                        "creation_date": str(reader.metadata.get("/CreationDate", "")),
                        "modification_date": str(reader.metadata.get("/ModDate", "")),
                    }

                return LoaderResult(
                    content=full_content,
                    metadata=metadata,
                    content_type=ContentType.TEXT,
                    chunks=page_texts,  # Pre-chunked by page
                    source_info={
                        "file_path": file_path,
                        "pages": len(reader.pages),
                        "strict_mode": strict,
                    },
                )

        except Exception as e:
            # Wrap any pypdf/IO failure in a generic Exception for callers.
            self.logger.error(f"Failed to process PDF {file_path}: {e}")
            raise Exception(f"PDF processing failed: {e}") from e
|
||||
168
cognee/infrastructure/loaders/external/unstructured_loader.py
vendored
Normal file
168
cognee/infrastructure/loaders/external/unstructured_loader.py
vendored
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
import os
|
||||
from typing import List
|
||||
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
|
||||
from cognee.infrastructure.loaders.models.LoaderResult import LoaderResult, ContentType
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
|
||||
class UnstructuredLoader(LoaderInterface):
    """
    Document loader using the unstructured library.

    Handles various document formats including docx, pptx, xlsx, odt, etc.
    Uses the unstructured library's auto-partition functionality.
    """

    def __init__(self):
        # Per-instance logger used for progress and error reporting.
        self.logger = get_logger(__name__)

    @property
    def supported_extensions(self) -> List[str]:
        """File extensions this loader accepts."""
        return [
            ".docx",
            ".doc",
            ".odt",  # Word documents
            ".xlsx",
            ".xls",
            ".ods",  # Spreadsheets
            ".pptx",
            ".ppt",
            ".odp",  # Presentations
            ".rtf",
            ".html",
            ".htm",  # Rich text and HTML
            ".eml",
            ".msg",  # Email formats
            ".epub",  # eBooks
        ]

    @property
    def supported_mime_types(self) -> List[str]:
        """MIME types this loader accepts (mirrors supported_extensions)."""
        return [
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document",  # docx
            "application/msword",  # doc
            "application/vnd.oasis.opendocument.text",  # odt
            "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",  # xlsx
            "application/vnd.ms-excel",  # xls
            "application/vnd.oasis.opendocument.spreadsheet",  # ods
            "application/vnd.openxmlformats-officedocument.presentationml.presentation",  # pptx
            "application/vnd.ms-powerpoint",  # ppt
            "application/vnd.oasis.opendocument.presentation",  # odp
            "application/rtf",  # rtf
            "text/html",  # html
            "message/rfc822",  # eml
            "application/epub+zip",  # epub
        ]

    @property
    def loader_name(self) -> str:
        """Unique registry identifier for this loader."""
        return "unstructured_loader"

    def get_dependencies(self) -> List[str]:
        """External packages (pip requirement strings) this loader needs."""
        return ["unstructured>=0.10.0"]

    def can_handle(self, file_path: str, mime_type: str = None) -> bool:
        """Check if file can be handled by this loader."""
        # Check file extension
        file_ext = os.path.splitext(file_path)[1].lower()
        if file_ext not in self.supported_extensions:
            return False

        # Check MIME type if provided
        if mime_type and mime_type not in self.supported_mime_types:
            return False

        # Validate dependencies
        return self.validate_dependencies()

    async def load(self, file_path: str, strategy: str = "auto", **kwargs) -> LoaderResult:
        """
        Load document using unstructured library.

        Args:
            file_path: Path to the document file
            strategy: Partitioning strategy ("auto", "fast", "hi_res", "ocr_only")
            **kwargs: Additional arguments passed to unstructured partition

        Returns:
            LoaderResult with extracted text content and metadata

        Raises:
            ImportError: If unstructured is not installed
            Exception: If document processing fails
        """
        # Import lazily so the module can be loaded without unstructured installed.
        try:
            from unstructured.partition.auto import partition
        except ImportError as e:
            raise ImportError(
                "unstructured is required for document processing. "
                "Install with: pip install unstructured"
            ) from e

        try:
            self.logger.info(f"Processing document: {file_path}")

            # Determine content type from file extension
            file_ext = os.path.splitext(file_path)[1].lower()

            # Get file size and basic info
            file_size = os.path.getsize(file_path)
            file_name = os.path.basename(file_path)

            # Set partitioning parameters (caller kwargs forwarded verbatim)
            partition_kwargs = {"filename": file_path, "strategy": strategy, **kwargs}

            # Use partition to extract elements
            elements = partition(**partition_kwargs)

            # Process elements into text content; element_info keeps a
            # truncated preview of each element for debugging purposes.
            text_parts = []
            element_info = []

            for element in elements:
                element_text = str(element).strip()
                if element_text:
                    text_parts.append(element_text)
                    element_info.append(
                        {
                            "type": type(element).__name__,
                            "text": element_text[:100] + "..."
                            if len(element_text) > 100
                            else element_text,
                        }
                    )

            # Combine all text content
            full_content = "\n\n".join(text_parts)

            # Determine content type based on structure: multiple elements
            # are treated as structured, a single element as plain text.
            content_type = ContentType.STRUCTURED if len(element_info) > 1 else ContentType.TEXT

            # Gather metadata
            metadata = {
                "name": file_name,
                "size": file_size,
                "extension": file_ext,
                "loader": self.loader_name,
                "elements_count": len(elements),
                "text_elements_count": len(text_parts),
                "strategy": strategy,
                "element_types": list(set(info["type"] for info in element_info)),
            }

            return LoaderResult(
                content=full_content,
                metadata=metadata,
                content_type=content_type,
                chunks=text_parts,  # Pre-chunked by elements
                source_info={
                    "file_path": file_path,
                    "strategy": strategy,
                    "elements": element_info[:10],  # First 10 elements for debugging
                    "total_elements": len(elements),
                },
            )

        except Exception as e:
            # Wrap any unstructured/IO failure in a generic Exception for callers.
            self.logger.error(f"Failed to process document {file_path}: {e}")
            raise Exception(f"Document processing failed: {e}") from e
|
||||
20
cognee/infrastructure/loaders/get_loader_engine.py
Normal file
20
cognee/infrastructure/loaders/get_loader_engine.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
from functools import lru_cache
|
||||
from .config import get_loader_config
|
||||
from .LoaderEngine import LoaderEngine
|
||||
from .create_loader_engine import create_loader_engine
|
||||
|
||||
|
||||
@lru_cache
def get_loader_engine() -> LoaderEngine:
    """
    Return the shared loader engine instance.

    The engine is built once from the current loader configuration and
    memoized with ``lru_cache``, following cognee's pattern for other
    engine factories.

    Returns:
        Cached LoaderEngine configured from environment/settings.
    """
    loader_settings = get_loader_config().to_dict()
    return create_loader_engine(**loader_settings)
|
||||
47
cognee/infrastructure/loaders/models/LoaderResult.py
Normal file
47
cognee/infrastructure/loaders/models/LoaderResult.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import Optional, Dict, Any, List
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ContentType(Enum):
    """Content type classification for loaded files"""

    TEXT = "text"  # plain text content
    STRUCTURED = "structured"  # parsed structured data (e.g. dicts/lists, multi-element documents)
    BINARY = "binary"  # raw binary content that was not decoded to text
|
||||
|
||||
|
||||
class LoaderResult(BaseModel):
    """
    Standardized output format for all file loaders.

    This model ensures consistent data structure across all loader implementations,
    following cognee's pattern of using Pydantic models for data validation.
    """

    content: str  # Primary text content extracted from file
    metadata: Dict[str, Any]  # File metadata (name, size, type, loader info, etc.)
    content_type: ContentType  # Content classification
    chunks: Optional[List[str]] = None  # Pre-chunked content if available
    source_info: Optional[Dict[str, Any]] = None  # Source-specific information

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the loader result to a dictionary format.

        Returns:
            Dict containing all loader result data with string-serialized content_type
        """
        # BUG FIX: with Config.use_enum_values = True, pydantic stores
        # content_type as the enum's *value* (a plain str), so accessing
        # `.value` here raised AttributeError.  Handle both the enum and
        # the already-unwrapped string form.
        content_type = self.content_type
        if isinstance(content_type, ContentType):
            content_type = content_type.value

        return {
            "content": self.content,
            "metadata": self.metadata,
            "content_type": content_type,
            "source_info": self.source_info or {},
            "chunks": self.chunks,
        }

    class Config:
        """Pydantic configuration following cognee patterns"""

        use_enum_values = True
        validate_assignment = True
||||
3
cognee/infrastructure/loaders/models/__init__.py
Normal file
3
cognee/infrastructure/loaders/models/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .LoaderResult import LoaderResult, ContentType
|
||||
|
||||
__all__ = ["LoaderResult", "ContentType"]
|
||||
3
cognee/infrastructure/loaders/supported_loaders.py
Normal file
3
cognee/infrastructure/loaders/supported_loaders.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# Registry for loader implementations
# Follows cognee's pattern used in databases/vector/supported_databases.py
# Maps loader_name (str) -> loader class; entries are added at runtime
# via use_loader() and consumed by create_loader_engine().
supported_loaders = {}
|
||||
22
cognee/infrastructure/loaders/use_loader.py
Normal file
22
cognee/infrastructure/loaders/use_loader.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from .supported_loaders import supported_loaders
|
||||
|
||||
|
||||
def use_loader(loader_name: str, loader_class) -> None:
    """
    Register a loader at runtime.

    Follows cognee's pattern used in databases for adapter registration.
    This allows external packages and custom loaders to be registered
    into the loader system.

    Note: registering under an existing name silently replaces the
    previous loader class.

    Args:
        loader_name: Unique name for the loader
        loader_class: Loader class implementing LoaderInterface

    Example:
        from cognee.infrastructure.loaders import use_loader
        from my_package import MyCustomLoader

        use_loader("my_custom_loader", MyCustomLoader)
    """
    supported_loaders[loader_name] = loader_class
|
||||
|
|
@ -30,9 +30,38 @@ class BinaryData(IngestionData):
|
|||
|
||||
async def ensure_metadata(self):
|
||||
if self.metadata is None:
|
||||
self.metadata = await get_file_metadata(self.data)
|
||||
# Handle case where file might be closed
|
||||
if hasattr(self.data, "closed") and self.data.closed:
|
||||
# Try to reopen the file if we have a file path
|
||||
if hasattr(self.data, "name") and self.data.name:
|
||||
try:
|
||||
with open(self.data.name, "rb") as reopened_file:
|
||||
self.metadata = await get_file_metadata(reopened_file)
|
||||
except (OSError, FileNotFoundError):
|
||||
# If we can't reopen, create minimal metadata
|
||||
self.metadata = {
|
||||
"name": self.name or "unknown",
|
||||
"file_path": getattr(self.data, "name", "unknown"),
|
||||
"extension": "txt",
|
||||
"mime_type": "text/plain",
|
||||
"content_hash": f"closed_file_{id(self.data)}",
|
||||
"file_size": 0,
|
||||
}
|
||||
else:
|
||||
# Create minimal metadata when file is closed and no path available
|
||||
self.metadata = {
|
||||
"name": self.name or "unknown",
|
||||
"file_path": "unknown",
|
||||
"extension": "txt",
|
||||
"mime_type": "text/plain",
|
||||
"content_hash": f"closed_file_{id(self.data)}",
|
||||
"file_size": 0,
|
||||
}
|
||||
else:
|
||||
# File is still open, proceed normally
|
||||
self.metadata = await get_file_metadata(self.data)
|
||||
|
||||
if self.metadata["name"] is None:
|
||||
if self.metadata.get("name") is None:
|
||||
self.metadata["name"] = self.name
|
||||
|
||||
@asynccontextmanager
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import os
|
|||
from typing import Optional
|
||||
from contextlib import asynccontextmanager
|
||||
from cognee.infrastructure.files import get_file_metadata, FileMetadata
|
||||
from cognee.infrastructure.utils import run_sync
|
||||
from cognee.infrastructure.utils.run_sync import run_sync
|
||||
from .IngestionData import IngestionData
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -16,9 +16,12 @@ class TextData(IngestionData):
|
|||
self.data = data
|
||||
|
||||
def get_identifier(self):
|
||||
keywords = extract_keywords(self.data)
|
||||
import hashlib
|
||||
|
||||
return "text/plain" + "_" + "|".join(keywords)
|
||||
content_bytes = self.data.encode("utf-8")
|
||||
content_hash = hashlib.md5(content_bytes).hexdigest()
|
||||
|
||||
return "text/plain" + "_" + content_hash
|
||||
|
||||
def get_metadata(self):
|
||||
self.ensure_metadata()
|
||||
|
|
@ -27,7 +30,20 @@ class TextData(IngestionData):
|
|||
|
||||
def ensure_metadata(self):
    """Lazily populate ``self.metadata`` for this text payload.

    On first call, computes keywords, an MD5 content hash, and basic file-like
    attributes; subsequent calls are no-ops because the cached dict is reused.
    """
    if self.metadata is not None:
        return

    import hashlib

    encoded = self.data.encode("utf-8")
    self.metadata = {
        "keywords": extract_keywords(self.data),
        "content_hash": hashlib.md5(encoded).hexdigest(),
        "content_type": "text/plain",
        "mime_type": "text/plain",
        "extension": "txt",
        "file_size": len(encoded),
    }
|
||||
@asynccontextmanager
|
||||
async def get_data(self):
|
||||
|
|
|
|||
|
|
@ -71,6 +71,25 @@ async def cognee_pipeline(
|
|||
if cognee_pipeline.first_run:
|
||||
from cognee.infrastructure.llm.utils import test_llm_connection, test_embedding_connection
|
||||
|
||||
# Ensure NLTK data is downloaded on first run
|
||||
def ensure_nltk_data():
    """Download required NLTK data if not already present."""
    try:
        import nltk

        # Download essential NLTK data used by the system; quiet mode avoids
        # noisy progress output, and nltk.download is a no-op when cached.
        for package in (
            "punkt_tab",
            "punkt",
            "averaged_perceptron_tagger",
            "averaged_perceptron_tagger_eng",
            "maxent_ne_chunker",
            "words",
        ):
            nltk.download(package, quiet=True)
        logger.info("NLTK data initialized successfully")
    except Exception as e:
        # Best-effort: a failed download must not abort the pipeline run.
        logger.warning(f"Failed to initialize NLTK data: {e}")
|
||||
ensure_nltk_data()
|
||||
|
||||
# Test LLM and Embedding configuration once before running Cognee
|
||||
await test_llm_connection()
|
||||
await test_embedding_connection()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,14 @@ import inspect
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.exceptions import (
|
||||
PipelineExecutionError,
|
||||
CogneeTransientError,
|
||||
CogneeSystemError,
|
||||
CogneeUserError,
|
||||
LLMConnectionError,
|
||||
DatabaseConnectionError,
|
||||
)
|
||||
|
||||
from ..tasks.task import Task
|
||||
|
||||
|
|
@ -16,15 +24,33 @@ async def handle_task(
|
|||
user: User,
|
||||
context: dict = None,
|
||||
):
|
||||
"""Handle common task workflow with logging, telemetry, and error handling around the core execution logic."""
|
||||
task_type = running_task.task_type
|
||||
"""
|
||||
Handle common task workflow with enhanced error handling and recovery strategies.
|
||||
|
||||
This function provides comprehensive error handling for pipeline tasks with:
|
||||
- Context-aware error reporting
|
||||
- Automatic retry for transient errors
|
||||
- Detailed error logging and telemetry
|
||||
- User-friendly error messages
|
||||
"""
|
||||
task_type = running_task.task_type
|
||||
task_name = running_task.executable.__name__
|
||||
|
||||
logger.info(
|
||||
f"{task_type} task started: `{task_name}`",
|
||||
extra={
|
||||
"task_type": task_type,
|
||||
"task_name": task_name,
|
||||
"user_id": user.id,
|
||||
"context": context,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"{task_type} task started: `{running_task.executable.__name__}`")
|
||||
send_telemetry(
|
||||
f"{task_type} Task Started",
|
||||
user_id=user.id,
|
||||
additional_properties={
|
||||
"task_name": running_task.executable.__name__,
|
||||
"task_name": task_name,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -35,36 +61,151 @@ async def handle_task(
|
|||
if has_context:
|
||||
args.append(context)
|
||||
|
||||
try:
|
||||
async for result_data in running_task.execute(args, next_task_batch_size):
|
||||
async for result in run_tasks_base(leftover_tasks, result_data, user, context):
|
||||
yield result
|
||||
# Retry configuration for transient errors
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
|
||||
logger.info(f"{task_type} task completed: `{running_task.executable.__name__}`")
|
||||
send_telemetry(
|
||||
f"{task_type} Task Completed",
|
||||
user_id=user.id,
|
||||
additional_properties={
|
||||
"task_name": running_task.executable.__name__,
|
||||
},
|
||||
)
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{task_type} task errored: `{running_task.executable.__name__}`\n{str(error)}\n",
|
||||
exc_info=True,
|
||||
)
|
||||
send_telemetry(
|
||||
f"{task_type} Task Errored",
|
||||
user_id=user.id,
|
||||
additional_properties={
|
||||
"task_name": running_task.executable.__name__,
|
||||
},
|
||||
)
|
||||
raise error
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
async for result_data in running_task.execute(args, next_task_batch_size):
|
||||
async for result in run_tasks_base(leftover_tasks, result_data, user, context):
|
||||
yield result
|
||||
|
||||
logger.info(
|
||||
f"{task_type} task completed: `{task_name}`",
|
||||
extra={
|
||||
"task_type": task_type,
|
||||
"task_name": task_name,
|
||||
"user_id": user.id,
|
||||
"retry_count": retry_count,
|
||||
},
|
||||
)
|
||||
|
||||
send_telemetry(
|
||||
f"{task_type} Task Completed",
|
||||
user_id=user.id,
|
||||
additional_properties={
|
||||
"task_name": task_name,
|
||||
"retry_count": retry_count,
|
||||
},
|
||||
)
|
||||
return # Success, exit retry loop
|
||||
|
||||
except CogneeTransientError as error:
|
||||
retry_count += 1
|
||||
if retry_count <= max_retries:
|
||||
logger.warning(
|
||||
f"Transient error in {task_type} task `{task_name}`, retrying ({retry_count}/{max_retries}): {error}",
|
||||
extra={
|
||||
"task_type": task_type,
|
||||
"task_name": task_name,
|
||||
"user_id": user.id,
|
||||
"retry_count": retry_count,
|
||||
"error_type": error.__class__.__name__,
|
||||
},
|
||||
)
|
||||
# Exponential backoff for retries
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(2**retry_count)
|
||||
continue
|
||||
else:
|
||||
# Max retries exceeded, raise enhanced error
|
||||
raise PipelineExecutionError(
|
||||
pipeline_name=f"{task_type}_pipeline",
|
||||
task_name=task_name,
|
||||
error_details=f"Max retries ({max_retries}) exceeded for transient error: {error}",
|
||||
)
|
||||
|
||||
except (CogneeUserError, CogneeSystemError) as error:
|
||||
# These errors shouldn't be retried, re-raise as pipeline execution error
|
||||
logger.error(
|
||||
f"{task_type} task failed: `{task_name}` - {error.__class__.__name__}: {error}",
|
||||
extra={
|
||||
"task_type": task_type,
|
||||
"task_name": task_name,
|
||||
"user_id": user.id,
|
||||
"error_type": error.__class__.__name__,
|
||||
"error_context": getattr(error, "context", {}),
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
send_telemetry(
|
||||
f"{task_type} Task Errored",
|
||||
user_id=user.id,
|
||||
additional_properties={
|
||||
"task_name": task_name,
|
||||
"error_type": error.__class__.__name__,
|
||||
},
|
||||
)
|
||||
|
||||
# Wrap in pipeline execution error with additional context
|
||||
raise PipelineExecutionError(
|
||||
pipeline_name=f"{task_type}_pipeline",
|
||||
task_name=task_name,
|
||||
error_details=f"{error.__class__.__name__}: {error}",
|
||||
context={
|
||||
"original_error": error.__class__.__name__,
|
||||
"original_context": getattr(error, "context", {}),
|
||||
"user_id": user.id,
|
||||
"task_args": str(args)[:200], # Truncate for logging
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as error:
|
||||
# Unexpected error, wrap in enhanced exception
|
||||
logger.error(
|
||||
f"{task_type} task encountered unexpected error: `{task_name}` - {error}",
|
||||
extra={
|
||||
"task_type": task_type,
|
||||
"task_name": task_name,
|
||||
"user_id": user.id,
|
||||
"error_type": error.__class__.__name__,
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
send_telemetry(
|
||||
f"{task_type} Task Errored",
|
||||
user_id=user.id,
|
||||
additional_properties={
|
||||
"task_name": task_name,
|
||||
"error_type": error.__class__.__name__,
|
||||
},
|
||||
)
|
||||
|
||||
# Check if this might be a known error type we can categorize
|
||||
error_message = str(error).lower()
|
||||
if any(term in error_message for term in ["connection", "timeout", "network"]):
|
||||
if (
|
||||
"llm" in error_message
|
||||
or "openai" in error_message
|
||||
or "anthropic" in error_message
|
||||
):
|
||||
raise LLMConnectionError(provider="Unknown", model="Unknown", reason=str(error))
|
||||
elif "database" in error_message or "sql" in error_message:
|
||||
raise DatabaseConnectionError(db_type="Unknown", reason=str(error))
|
||||
|
||||
# Default to pipeline execution error
|
||||
raise PipelineExecutionError(
|
||||
pipeline_name=f"{task_type}_pipeline",
|
||||
task_name=task_name,
|
||||
error_details=f"Unexpected error: {error}",
|
||||
context={
|
||||
"error_type": error.__class__.__name__,
|
||||
"user_id": user.id,
|
||||
"task_args": str(args)[:200], # Truncate for logging
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def run_tasks_base(tasks: list[Task], data=None, user: User = None, context: dict = None):
|
||||
"""Base function to execute tasks in a pipeline, handling task type detection and execution."""
|
||||
"""
|
||||
Base function to execute tasks in a pipeline with enhanced error handling.
|
||||
|
||||
Provides comprehensive error handling, logging, and recovery strategies for pipeline execution.
|
||||
"""
|
||||
if len(tasks) == 0:
|
||||
yield data
|
||||
return
|
||||
|
|
|
|||
11
cognee/tasks/ingestion/adapters/__init__.py
Normal file
11
cognee/tasks/ingestion/adapters/__init__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
"""
|
||||
Adapters for bridging the new loader system with existing ingestion pipeline.
|
||||
|
||||
This module provides compatibility layers to integrate the plugin-based loader
|
||||
system with cognee's existing data processing pipeline while maintaining
|
||||
backward compatibility and preserving permission logic.
|
||||
"""
|
||||
|
||||
from .loader_to_ingestion_adapter import LoaderToIngestionAdapter
|
||||
|
||||
__all__ = ["LoaderToIngestionAdapter"]
|
||||
241
cognee/tasks/ingestion/adapters/loader_to_ingestion_adapter.py
Normal file
241
cognee/tasks/ingestion/adapters/loader_to_ingestion_adapter.py
Normal file
|
|
@ -0,0 +1,241 @@
|
|||
import os
|
||||
import tempfile
|
||||
from typing import BinaryIO, Union, Optional, Any
|
||||
from io import StringIO, BytesIO
|
||||
|
||||
from cognee.infrastructure.loaders.models.LoaderResult import LoaderResult, ContentType
|
||||
from cognee.modules.ingestion.data_types import IngestionData, TextData, BinaryData
|
||||
from cognee.infrastructure.files import get_file_metadata
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
|
||||
class LoaderResultToIngestionData(IngestionData):
    """
    Adapter wrapping a LoaderResult so it satisfies the IngestionData interface.

    Keeps the existing cognee ingestion pipeline working unchanged while the
    content itself is produced by the new plugin-based loader system.
    """

    def __init__(self, loader_result: LoaderResult, original_file_path: str = None):
        # The loader output being adapted and, when known, the path it came from.
        self.loader_result = loader_result
        self.original_file_path = original_file_path
        # Metadata is computed on first access and memoized here.
        self._cached_metadata = None
        self.logger = get_logger(__name__)

    def get_identifier(self) -> str:
        """
        Return a deduplication identifier of the form ``<content_type>_<md5>``.

        The hash is always derived from the content itself so identifiers stay
        consistent with the existing system regardless of which loader ran.
        """
        import hashlib

        digest = hashlib.md5(self.loader_result.content.encode("utf-8")).hexdigest()
        return f"{self.loader_result.content_type.value}_{digest}"

    def get_metadata(self) -> dict:
        """
        Return file metadata in the shape the existing pipeline expects.

        Starts from the LoaderResult metadata and fills in any missing required
        fields (name, content_hash, file_path, mime_type). The result is cached
        after the first call.
        """
        if self._cached_metadata is None:
            meta = dict(self.loader_result.metadata)

            if "name" not in meta:
                if self.original_file_path:
                    meta["name"] = os.path.basename(self.original_file_path)
                else:
                    # No path available: derive a name from the first 8 chars
                    # of the content hash plus the known/default extension.
                    short_hash = self.get_identifier().split("_")[-1][:8]
                    extension = meta.get("extension", ".txt")
                    meta["name"] = f"content_{short_hash}{extension}"

            if "content_hash" not in meta:
                # Strip the content-type prefix (e.g. "text_abc123" -> "abc123")
                # so the deletion system sees the bare hash it expects.
                identifier = self.get_identifier()
                if "_" in identifier:
                    meta["content_hash"] = identifier.split("_", 1)[-1]
                else:
                    meta["content_hash"] = identifier

            if "file_path" not in meta and self.original_file_path:
                meta["file_path"] = self.original_file_path

            if "mime_type" not in meta:
                # Map by (dotted) extension; anything unknown falls back to a
                # generic binary MIME type.
                known_types = {
                    ".txt": "text/plain",
                    ".md": "text/markdown",
                    ".csv": "text/csv",
                    ".json": "application/json",
                    ".pdf": "application/pdf",
                    ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
                    ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
                    ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
                }
                extension = meta.get("extension", "").lower()
                meta["mime_type"] = known_types.get(extension, "application/octet-stream")

            self._cached_metadata = meta

        return self._cached_metadata

    def get_data(self) -> Union[str, BinaryIO]:
        """
        Return the content in the form the existing pipeline consumes.

        Text content is returned directly as a string.
        """
        if self.loader_result.content_type == ContentType.TEXT:
            return self.loader_result.content

        # Structured/binary content is currently also returned as a string,
        # since the existing pipeline expects text content for processing.
        return self.loader_result.content
||||
|
||||
|
||||
class LoaderToIngestionAdapter:
    """
    Bridge between the new plugin-based loader system and the existing
    ingestion pipeline.

    Files are processed through registered loaders when possible; on any
    failure the adapter falls back to the original classification path so
    existing behavior is fully preserved.
    """

    def __init__(self):
        self.logger = get_logger(__name__)

    async def process_file_with_loaders(
        self,
        file_path: str,
        s3fs: Optional[Any] = None,
        preferred_loaders: Optional[list] = None,
        loader_config: Optional[dict] = None,
    ) -> "IngestionData":
        """
        Process a file through the loader system and return IngestionData.

        Args:
            file_path: Path to the file to process
            s3fs: S3 filesystem (for compatibility with existing code)
            preferred_loaders: List of preferred loader names
            loader_config: Per-loader configuration, keyed by loader name

        Returns:
            IngestionData compatible object (falls back to the legacy
            classification system if the loader system fails)
        """
        from cognee.infrastructure.loaders import get_loader_engine

        try:
            engine = get_loader_engine()

            # Best-effort MIME detection; None is acceptable downstream.
            mime_type = None
            try:
                import mimetypes

                mime_type, _ = mimetypes.guess_type(file_path)
            except Exception:
                pass

            self.logger.info(f"Processing file with loaders: {file_path}")

            # Pull loader-specific options for whichever loader will handle
            # this file, if any configuration was supplied.
            kwargs = {}
            if loader_config:
                loader = engine.get_loader(file_path, mime_type, preferred_loaders)
                if loader and loader.loader_name in loader_config:
                    kwargs = loader_config[loader.loader_name]

            loader_result = await engine.load_file(
                file_path, mime_type=mime_type, preferred_loaders=preferred_loaders, **kwargs
            )

            return LoaderResultToIngestionData(loader_result, file_path)

        except Exception as e:
            # Any loader failure degrades gracefully to the legacy path.
            self.logger.warning(f"Loader system failed for {file_path}: {e}")
            return await self._fallback_to_existing_system(file_path, s3fs)

    async def _fallback_to_existing_system(
        self, file_path: str, s3fs: Optional[Any] = None
    ) -> "IngestionData":
        """
        Classify the file with the original ingestion.classify() path.

        Guarantees backward compatibility: even when the loader system cannot
        handle a file, the legacy classification system still processes it.
        """
        from cognee.modules.ingestion import classify

        self.logger.info(f"Falling back to existing classification system for: {file_path}")

        if file_path.startswith("s3://"):
            if not s3fs:
                raise ValueError("S3 file path provided but no s3fs available")
            with s3fs.open(file_path, "rb") as file:
                return classify(file)

        # Local files and file:// URLs.
        local_path = file_path.replace("file://", "")
        with open(local_path, "rb") as file:
            return classify(file)

    def is_text_content(self, data: Union[str, Any]) -> bool:
        """
        Return True when *data* is inline text rather than a file path.

        Non-strings are never text content; strings that look like absolute
        paths, file:// / s3:// URLs, or Windows drive paths are file paths.
        """
        if not isinstance(data, str):
            return False

        looks_like_path = data.startswith(("/", "file://", "s3://")) or (
            len(data) > 1 and data[1] == ":"  # Windows drive paths like C:\...
        )
        return not looks_like_path

    def create_text_ingestion_data(self, content: str) -> "IngestionData":
        """Wrap raw text content in a TextData ingestion object."""
        return TextData(content)
|
||||
|
|
@ -60,7 +60,7 @@ async def ingest_data(
|
|||
else:
|
||||
# Find existing dataset or create a new one
|
||||
existing_datasets = await get_authorized_existing_datasets(
|
||||
user=user, permission_type="write", datasets=[dataset_name]
|
||||
datasets=[dataset_name], permission_type="write", user=user
|
||||
)
|
||||
dataset = await load_or_create_datasets(
|
||||
dataset_names=[dataset_name],
|
||||
|
|
|
|||
309
cognee/tasks/ingestion/plugin_ingest_data.py
Normal file
309
cognee/tasks/ingestion/plugin_ingest_data.py
Normal file
|
|
@ -0,0 +1,309 @@
|
|||
import json
|
||||
import inspect
|
||||
from uuid import UUID
|
||||
from typing import Union, BinaryIO, Any, List, Optional
|
||||
|
||||
import cognee.modules.ingestion as ingestion
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.data.models import Data
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
|
||||
from cognee.modules.data.methods import (
|
||||
get_authorized_existing_datasets,
|
||||
get_dataset_data,
|
||||
load_or_create_datasets,
|
||||
)
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from .save_data_item_to_storage import save_data_item_to_storage
|
||||
from .adapters import LoaderToIngestionAdapter
|
||||
from cognee.infrastructure.files.storage.s3_config import get_s3_config
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def plugin_ingest_data(
    data: Any,
    dataset_name: str,
    user: User,
    node_set: Optional[List[str]] = None,
    dataset_id: UUID = None,
    preferred_loaders: Optional[List[str]] = None,
    loader_config: Optional[dict] = None,
):
    """
    Plugin-based data ingestion using the loader system.

    This function maintains full backward compatibility with the existing
    ingest_data function while adding support for the new loader system.

    Args:
        data: The data to ingest
        dataset_name: Name of the dataset
        user: User object for permissions
        node_set: Optional node set for organization
        dataset_id: Optional specific dataset ID
        preferred_loaders: List of preferred loader names to try first
        loader_config: Configuration for specific loaders

    Returns:
        List of Data objects that were ingested
    """
    if not user:
        user = await get_default_user()

    # Ensure NLTK data is downloaded (preserves automatic download behavior)
    def ensure_nltk_data():
        """Download required NLTK data if not already present."""
        try:
            import nltk

            # Download essential NLTK data used by the system
            nltk.download("punkt_tab", quiet=True)
            nltk.download("punkt", quiet=True)
            nltk.download("averaged_perceptron_tagger", quiet=True)
            nltk.download("averaged_perceptron_tagger_eng", quiet=True)
            nltk.download("maxent_ne_chunker", quiet=True)
            nltk.download("words", quiet=True)
            logger.info("NLTK data verified/downloaded successfully")
        except Exception as e:
            # Best-effort: a failed download must not abort ingestion.
            logger.warning(f"Failed to download NLTK data: {e}")

    # Download NLTK data once per session (flag stored on the function object)
    if not hasattr(plugin_ingest_data, "_nltk_initialized"):
        ensure_nltk_data()
        plugin_ingest_data._nltk_initialized = True

    # Initialize S3 support (maintain existing behavior); fs stays None when
    # no AWS credentials are configured.
    s3_config = get_s3_config()
    fs = None
    if s3_config.aws_access_key_id is not None and s3_config.aws_secret_access_key is not None:
        import s3fs

        fs = s3fs.S3FileSystem(
            key=s3_config.aws_access_key_id, secret=s3_config.aws_secret_access_key, anon=False
        )

    # Initialize the loader adapter
    loader_adapter = LoaderToIngestionAdapter()

    def open_data_file(file_path: str):
        """Open file with S3 support (preserves existing behavior)."""
        if file_path.startswith("s3://"):
            # NOTE(review): assumes fs was configured above when s3:// paths
            # appear; otherwise this raises AttributeError — verify callers.
            return fs.open(file_path, mode="rb")
        else:
            local_path = file_path.replace("file://", "")
            return open(local_path, mode="rb")

    def get_external_metadata_dict(data_item: Union[BinaryIO, str, Any]) -> dict[str, Any]:
        """Get external metadata (preserves existing behavior)."""
        # Objects exposing a bound .dict() method (e.g. pydantic models) are
        # serialized along with their origin type.
        if hasattr(data_item, "dict") and inspect.ismethod(getattr(data_item, "dict")):
            return {"metadata": data_item.dict(), "origin": str(type(data_item))}
        else:
            return {}

    async def store_data_to_dataset(
        data: Any,
        dataset_name: str,
        user: User,
        node_set: Optional[List[str]] = None,
        dataset_id: UUID = None,
    ):
        """
        Core data storage logic with plugin-based file processing.

        This function preserves all existing permission and database logic
        while using the new loader system for file processing.
        """
        logger.info(f"Plugin-based ingestion starting for dataset: {dataset_name}")

        # Preserve existing dataset creation and permission logic
        if dataset_id:
            # Retrieve existing dataset by ID
            dataset = await get_specific_user_permission_datasets(user.id, "write", [dataset_id])
            # Convert from list to Dataset element
            if isinstance(dataset, list):
                dataset = dataset[0]
        else:
            # Find existing dataset or create a new one by name
            existing_datasets = await get_authorized_existing_datasets(
                datasets=[dataset_name], permission_type="write", user=user
            )
            datasets = await load_or_create_datasets(
                dataset_names=[dataset_name], existing_datasets=existing_datasets, user=user
            )
            # NOTE(review): if load_or_create_datasets ever returns a non-list,
            # `dataset` would be unbound below — confirm its contract.
            if isinstance(datasets, list):
                dataset = datasets[0]

        new_datapoints = []
        existing_data_points = []
        dataset_new_data_points = []

        # Get existing dataset data for deduplication (preserve existing logic)
        dataset_data: list[Data] = await get_dataset_data(dataset.id)
        dataset_data_map = {str(data.id): True for data in dataset_data}

        # NOTE(review): assumes `data` is an iterable of data items — confirm
        # callers never pass a single bare item.
        for data_item in data:
            file_path = await save_data_item_to_storage(data_item)

            # NEW: Use loader system or existing classification based on data type
            try:
                if loader_adapter.is_text_content(data_item):
                    # Handle text content (preserve existing behavior)
                    logger.info("Processing text content with existing system")
                    classified_data = ingestion.classify(data_item)
                else:
                    # Use loader system for file paths
                    logger.info(f"Processing file with loader system: {file_path}")
                    classified_data = await loader_adapter.process_file_with_loaders(
                        file_path,
                        s3fs=fs,
                        preferred_loaders=preferred_loaders,
                        loader_config=loader_config,
                    )

            except Exception as e:
                logger.warning(f"Plugin system failed for {file_path}, falling back: {e}")
                # Fallback to existing system for full backward compatibility
                with open_data_file(file_path) as file:
                    classified_data = ingestion.classify(file)

            # Preserve all existing data processing logic
            data_id = ingestion.identify(classified_data, user)
            file_metadata = classified_data.get_metadata()

            # Ensure metadata has all required fields with fallbacks
            # (closure over file_path/data_id/classified_data of this item)
            def get_metadata_field(metadata, field_name, default_value=""):
                """Get metadata field with fallback handling."""
                if field_name in metadata and metadata[field_name] is not None:
                    return metadata[field_name]

                logger.warning(f"Missing metadata field '{field_name}', using fallback")

                # Provide fallbacks based on available information
                if field_name == "name":
                    if "file_path" in metadata and metadata["file_path"]:
                        import os

                        return os.path.basename(str(metadata["file_path"])).split(".")[0]
                    elif file_path:
                        import os

                        return os.path.basename(str(file_path)).split(".")[0]
                    else:
                        content_hash = metadata.get("content_hash", str(data_id))[:8]
                        return f"content_{content_hash}"
                elif field_name == "file_path":
                    # Use the actual file path returned by save_data_item_to_storage
                    return file_path
                elif field_name == "extension":
                    if "file_path" in metadata and metadata["file_path"]:
                        import os

                        _, ext = os.path.splitext(str(metadata["file_path"]))
                        return ext.lstrip(".") if ext else "txt"
                    elif file_path:
                        import os

                        _, ext = os.path.splitext(str(file_path))
                        return ext.lstrip(".") if ext else "txt"
                    return "txt"
                elif field_name == "mime_type":
                    ext = get_metadata_field(metadata, "extension", "txt")
                    mime_map = {
                        "txt": "text/plain",
                        "md": "text/markdown",
                        "pdf": "application/pdf",
                        "json": "application/json",
                        "csv": "text/csv",
                    }
                    return mime_map.get(ext.lower(), "text/plain")
                elif field_name == "content_hash":
                    # Extract the raw content hash for compatibility with deletion system
                    content_identifier = classified_data.get_identifier()
                    # Remove content type prefix if present (e.g., "text_abc123" -> "abc123")
                    if "_" in content_identifier:
                        return content_identifier.split("_", 1)[-1]
                    return content_identifier

                return default_value

            from sqlalchemy import select

            db_engine = get_relational_engine()

            # Check if data should be updated (preserve existing logic)
            async with db_engine.get_async_session() as session:
                data_point = (
                    await session.execute(select(Data).filter(Data.id == data_id))
                ).scalar_one_or_none()

                ext_metadata = get_external_metadata_dict(data_item)

                if node_set:
                    ext_metadata["node_set"] = node_set

                # Preserve existing data point creation/update logic
                if data_point is not None:
                    data_point.name = get_metadata_field(file_metadata, "name")
                    data_point.raw_data_location = get_metadata_field(file_metadata, "file_path")
                    data_point.extension = get_metadata_field(file_metadata, "extension")
                    data_point.mime_type = get_metadata_field(file_metadata, "mime_type")
                    data_point.owner_id = user.id
                    data_point.content_hash = get_metadata_field(file_metadata, "content_hash")
                    data_point.external_metadata = ext_metadata
                    data_point.node_set = json.dumps(node_set) if node_set else None

                    if str(data_point.id) in dataset_data_map:
                        existing_data_points.append(data_point)
                    else:
                        dataset_new_data_points.append(data_point)
                        dataset_data_map[str(data_point.id)] = True
                else:
                    if str(data_id) in dataset_data_map:
                        continue

                    data_point = Data(
                        id=data_id,
                        name=get_metadata_field(file_metadata, "name"),
                        raw_data_location=get_metadata_field(file_metadata, "file_path"),
                        extension=get_metadata_field(file_metadata, "extension"),
                        mime_type=get_metadata_field(file_metadata, "mime_type"),
                        owner_id=user.id,
                        content_hash=get_metadata_field(file_metadata, "content_hash"),
                        external_metadata=ext_metadata,
                        node_set=json.dumps(node_set) if node_set else None,
                        token_count=-1,
                    )

                    new_datapoints.append(data_point)
                    dataset_data_map[str(data_point.id)] = True

        # Preserve existing database operations
        # NOTE(review): db_engine is assigned inside the loop above; an empty
        # `data` iterable would leave it undefined here — verify upstream.
        async with db_engine.get_async_session() as session:
            if dataset not in session:
                session.add(dataset)

            if len(new_datapoints) > 0:
                dataset.data.extend(new_datapoints)

            if len(existing_data_points) > 0:
                for data_point in existing_data_points:
                    await session.merge(data_point)

            if len(dataset_new_data_points) > 0:
                dataset.data.extend(dataset_new_data_points)

            await session.merge(dataset)
            await session.commit()

        logger.info(
            f"Plugin-based ingestion completed. New: {len(new_datapoints)}, "
            + f"Updated: {len(existing_data_points)}, Dataset new: {len(dataset_new_data_points)}"
        )

        return existing_data_points + dataset_new_data_points + new_datapoints

    return await store_data_to_dataset(data, dataset_name, user, node_set, dataset_id)
|
||||
237
infrastructure/loaders/LoaderEngine.py
Normal file
237
infrastructure/loaders/LoaderEngine.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
import os
|
||||
import importlib.util
|
||||
from typing import Dict, List, Optional
|
||||
from .LoaderInterface import LoaderInterface
|
||||
from .models.LoaderResult import LoaderResult
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
|
||||
class LoaderEngine:
    """
    Main loader engine for managing file loaders.

    Follows cognee's adapter pattern similar to database engines,
    providing a centralized system for file loading operations.
    """

    def __init__(
        self,
        loader_directories: List[str],
        default_loader_priority: List[str],
        fallback_loader: str = "text_loader",
        enable_dependency_validation: bool = True,
    ):
        """
        Initialize the loader engine.

        Args:
            loader_directories: Directories to search for loader implementations
            default_loader_priority: Priority order for loader selection
            fallback_loader: Default loader to use when no other matches
            enable_dependency_validation: Whether to validate loader dependencies
        """
        # loader name -> loader instance
        self._loaders: Dict[str, LoaderInterface] = {}
        # lowercased extension -> loaders that declared it
        self._extension_map: Dict[str, List[LoaderInterface]] = {}
        # MIME type -> loaders that declared it
        self._mime_type_map: Dict[str, List[LoaderInterface]] = {}
        self.loader_directories = loader_directories
        self.default_loader_priority = default_loader_priority
        self.fallback_loader = fallback_loader
        self.enable_dependency_validation = enable_dependency_validation
        self.logger = get_logger(__name__)

    def register_loader(self, loader: LoaderInterface) -> bool:
        """
        Register a loader with the engine.

        Args:
            loader: LoaderInterface implementation to register

        Returns:
            True if loader was registered successfully, False otherwise
        """
        # Skip loaders whose dependencies are missing so a partially
        # installed environment still gets a working engine.
        if self.enable_dependency_validation and not loader.validate_dependencies():
            self.logger.warning(
                f"Skipping loader '{loader.loader_name}' - missing dependencies: "
                f"{loader.get_dependencies()}"
            )
            return False

        self._loaders[loader.loader_name] = loader

        # Map extensions to loaders (extensions match case-insensitively)
        for ext in loader.supported_extensions:
            self._extension_map.setdefault(ext.lower(), []).append(loader)

        # Map mime types to loaders
        for mime_type in loader.supported_mime_types:
            self._mime_type_map.setdefault(mime_type, []).append(loader)

        self.logger.info(f"Registered loader: {loader.loader_name}")
        return True

    def get_loader(
        self,
        file_path: str,
        mime_type: Optional[str] = None,
        preferred_loaders: Optional[List[str]] = None,
    ) -> Optional[LoaderInterface]:
        """
        Get appropriate loader for a file.

        Selection order: caller-preferred loaders, the configured priority
        list, MIME-type matches, extension matches, then the fallback loader.

        Args:
            file_path: Path to the file to be processed
            mime_type: Optional MIME type of the file
            preferred_loaders: List of preferred loader names to try first

        Returns:
            LoaderInterface that can handle the file, or None if not found
        """
        ext = os.path.splitext(file_path)[1].lower()

        # Try preferred loaders first
        if preferred_loaders:
            for loader_name in preferred_loaders:
                loader = self._loaders.get(loader_name)
                if loader is not None and loader.can_handle(file_path, mime_type):
                    return loader

        # Try priority order
        for loader_name in self.default_loader_priority:
            loader = self._loaders.get(loader_name)
            if loader is not None and loader.can_handle(file_path, mime_type):
                return loader

        # Try mime type mapping
        if mime_type and mime_type in self._mime_type_map:
            for loader in self._mime_type_map[mime_type]:
                if loader.can_handle(file_path, mime_type):
                    return loader

        # Try extension mapping
        if ext in self._extension_map:
            for loader in self._extension_map[ext]:
                if loader.can_handle(file_path, mime_type):
                    return loader

        # Fallback loader
        fallback = self._loaders.get(self.fallback_loader)
        if fallback is not None and fallback.can_handle(file_path, mime_type):
            return fallback

        return None

    async def load_file(
        self,
        file_path: str,
        mime_type: Optional[str] = None,
        preferred_loaders: Optional[List[str]] = None,
        **kwargs,
    ) -> LoaderResult:
        """
        Load file using appropriate loader.

        Args:
            file_path: Path to the file to be processed
            mime_type: Optional MIME type of the file
            preferred_loaders: List of preferred loader names to try first
            **kwargs: Additional loader-specific configuration

        Returns:
            LoaderResult containing processed content and metadata

        Raises:
            ValueError: If no suitable loader is found
            Exception: If file processing fails
        """
        loader = self.get_loader(file_path, mime_type, preferred_loaders)
        if not loader:
            raise ValueError(f"No loader found for file: {file_path}")

        self.logger.debug(f"Loading {file_path} with {loader.loader_name}")
        return await loader.load(file_path, **kwargs)

    def discover_loaders(self):
        """
        Auto-discover loaders from configured directories.

        Scans loader directories for Python modules containing
        LoaderInterface implementations and registers them.
        """
        for directory in self.loader_directories:
            if os.path.exists(directory):
                self._discover_in_directory(directory)

    def _discover_in_directory(self, directory: str):
        """
        Discover loaders in a specific directory.

        Imports every non-private ``.py`` module in the directory and
        registers any concrete LoaderInterface subclass found in it.
        Failures are logged and skipped so one bad module cannot abort
        discovery.

        Args:
            directory: Directory path to scan for loader implementations
        """
        try:
            for file_name in os.listdir(directory):
                if file_name.endswith(".py") and not file_name.startswith("_"):
                    module_name = file_name[:-3]
                    file_path = os.path.join(directory, file_name)

                    try:
                        spec = importlib.util.spec_from_file_location(module_name, file_path)
                        if spec and spec.loader:
                            module = importlib.util.module_from_spec(spec)
                            spec.loader.exec_module(module)

                            # Look for loader classes
                            for attr_name in dir(module):
                                attr = getattr(module, attr_name)
                                if (
                                    isinstance(attr, type)
                                    and issubclass(attr, LoaderInterface)
                                    and attr != LoaderInterface
                                ):
                                    # Instantiate and register the loader
                                    try:
                                        loader_instance = attr()
                                        self.register_loader(loader_instance)
                                    except Exception as e:
                                        self.logger.warning(
                                            f"Failed to instantiate loader {attr_name}: {e}"
                                        )

                    except Exception as e:
                        self.logger.warning(f"Failed to load module {module_name}: {e}")

        except OSError as e:
            self.logger.warning(f"Failed to scan directory {directory}: {e}")

    def get_available_loaders(self) -> List[str]:
        """
        Get list of available loader names.

        Returns:
            List of registered loader names
        """
        return list(self._loaders.keys())

    def get_loader_info(self, loader_name: str) -> Dict[str, Any]:
        """
        Get information about a specific loader.

        Args:
            loader_name: Name of the loader to inspect

        Returns:
            Dictionary containing loader information; empty dict if the
            loader is not registered
        """
        if loader_name not in self._loaders:
            return {}

        loader = self._loaders[loader_name]
        return {
            "name": loader.loader_name,
            "extensions": loader.supported_extensions,
            "mime_types": loader.supported_mime_types,
            "dependencies": loader.get_dependencies(),
            "available": loader.validate_dependencies(),
        }
|
||||
101
infrastructure/loaders/LoaderInterface.py
Normal file
101
infrastructure/loaders/LoaderInterface.py
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from .models.LoaderResult import LoaderResult
|
||||
|
||||
|
||||
class LoaderInterface(ABC):
    """
    Base interface for all file loaders in cognee.

    Mirrors cognee's established database-adapter pattern: every loader
    exposes capability descriptors (extensions, MIME types, name), a cheap
    ``can_handle`` probe, and an async ``load`` producing a LoaderResult.
    """

    @property
    @abstractmethod
    def supported_extensions(self) -> List[str]:
        """File extensions this loader accepts, dot included (e.g. ['.txt', '.md'])."""
        pass

    @property
    @abstractmethod
    def supported_mime_types(self) -> List[str]:
        """MIME type strings this loader accepts (e.g. ['text/plain', 'application/pdf'])."""
        pass

    @property
    @abstractmethod
    def loader_name(self) -> str:
        """Unique string identifier used for registration and configuration."""
        pass

    @abstractmethod
    def can_handle(self, file_path: str, mime_type: str = None) -> bool:
        """
        Report whether this loader can process the given file.

        Args:
            file_path: Path to the file to be processed
            mime_type: Optional MIME type of the file

        Returns:
            True if this loader can process the file, False otherwise
        """
        pass

    @abstractmethod
    async def load(self, file_path: str, **kwargs) -> LoaderResult:
        """
        Load and process the file, returning a standardized result.

        Args:
            file_path: Path to the file to be processed
            **kwargs: Additional loader-specific configuration

        Returns:
            LoaderResult containing processed content and metadata

        Raises:
            Exception: If the file cannot be processed
        """
        pass

    def get_dependencies(self) -> List[str]:
        """
        Optional: declare packages this loader needs.

        Returns:
            List of package names with optional version specifications
        """
        return []

    def validate_dependencies(self) -> bool:
        """
        Check whether every declared dependency is importable.

        Returns:
            True if all dependencies are installed, False otherwise
        """
        for requirement in self.get_dependencies():
            # Strip any version specifier (">=", "==", "<") to obtain the
            # bare package name before attempting the import.
            package = requirement
            for separator in (">=", "==", "<"):
                package = package.split(separator)[0]
            try:
                __import__(package)
            except ImportError:
                return False
        return True
|
||||
19
infrastructure/loaders/__init__.py
Normal file
19
infrastructure/loaders/__init__.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
"""
|
||||
File loader infrastructure for cognee.
|
||||
|
||||
This package provides a plugin-based system for loading different file formats
|
||||
into cognee, following the same patterns as database adapters.
|
||||
|
||||
Main exports:
|
||||
- get_loader_engine(): Factory function to get configured loader engine
|
||||
- use_loader(): Register custom loaders at runtime
|
||||
- LoaderInterface: Base interface for implementing loaders
|
||||
- LoaderResult, ContentType: Data models for loader results
|
||||
"""
|
||||
|
||||
from .get_loader_engine import get_loader_engine
|
||||
from .use_loader import use_loader
|
||||
from .LoaderInterface import LoaderInterface
|
||||
from .models.LoaderResult import LoaderResult, ContentType
|
||||
|
||||
__all__ = ["get_loader_engine", "use_loader", "LoaderInterface", "LoaderResult", "ContentType"]
|
||||
57
infrastructure/loaders/config.py
Normal file
57
infrastructure/loaders/config.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
from functools import lru_cache
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from cognee.root_dir import get_absolute_path
|
||||
|
||||
|
||||
class LoaderConfig(BaseSettings):
    """
    Configuration for the file loader system.

    Follows cognee's pattern of using pydantic_settings.BaseSettings so
    values can be overridden through environment variables (``LOADER_``
    prefix, ``.env`` file) and are validated on load.
    """

    loader_directories: List[str] = [
        get_absolute_path("cognee/infrastructure/loaders/core"),
        get_absolute_path("cognee/infrastructure/loaders/external"),
    ]
    default_loader_priority: List[str] = [
        "text_loader",
        "pypdf_loader",
        "unstructured_loader",
        "dlt_loader",
    ]
    auto_discover: bool = True
    fallback_loader: str = "text_loader"
    enable_dependency_validation: bool = True

    model_config = SettingsConfigDict(env_file=".env", extra="allow", env_prefix="LOADER_")

    def to_dict(self) -> Dict[str, Any]:
        """
        Return all loader settings as a plain dictionary.

        Returns:
            Dict containing all loader configuration settings
        """
        setting_names = (
            "loader_directories",
            "default_loader_priority",
            "auto_discover",
            "fallback_loader",
            "enable_dependency_validation",
        )
        return {name: getattr(self, name) for name in setting_names}
|
||||
|
||||
|
||||
@lru_cache
def get_loader_config() -> LoaderConfig:
    """
    Return the process-wide loader configuration.

    The @lru_cache decorator follows cognee's pattern for configuration
    objects: settings are read once and the same instance is handed out
    on every subsequent call.

    Returns:
        LoaderConfig instance with current settings
    """
    loader_config = LoaderConfig()
    return loader_config
|
||||
5
infrastructure/loaders/core/__init__.py
Normal file
5
infrastructure/loaders/core/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
"""Core loader implementations that are always available."""
|
||||
|
||||
from .text_loader import TextLoader
|
||||
|
||||
__all__ = ["TextLoader"]
|
||||
128
infrastructure/loaders/core/text_loader.py
Normal file
128
infrastructure/loaders/core/text_loader.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
import os
|
||||
from typing import List
|
||||
from ..LoaderInterface import LoaderInterface
|
||||
from ..models.LoaderResult import LoaderResult, ContentType
|
||||
|
||||
|
||||
class TextLoader(LoaderInterface):
    """
    Core text file loader that handles basic text file formats.

    This loader is always available and serves as the fallback for
    text-based files when no specialized loader is available.
    """

    @property
    def supported_extensions(self) -> List[str]:
        """Supported text file extensions."""
        return [".txt", ".md", ".csv", ".json", ".xml", ".yaml", ".yml", ".log"]

    @property
    def supported_mime_types(self) -> List[str]:
        """Supported MIME types for text content."""
        return [
            "text/plain",
            "text/markdown",
            "text/csv",
            "application/json",
            "text/xml",
            "application/xml",
            "text/yaml",
            "application/yaml",
        ]

    @property
    def loader_name(self) -> str:
        """Unique identifier for this loader."""
        return "text_loader"

    def can_handle(self, file_path: str, mime_type: str = None) -> bool:
        """
        Check if this loader can handle the given file.

        Args:
            file_path: Path to the file
            mime_type: Optional MIME type

        Returns:
            True if file can be handled, False otherwise
        """
        # Check by extension
        ext = os.path.splitext(file_path)[1].lower()
        if ext in self.supported_extensions:
            return True

        # Check by MIME type
        if mime_type and mime_type in self.supported_mime_types:
            return True

        # As fallback loader, attempt to detect text-like files.
        try:
            # Quick check if file appears to be text
            with open(file_path, "rb") as f:
                sample = f.read(512)
            if sample:
                # NUL bytes are a strong indicator of binary content.
                if b"\x00" in sample:
                    return False
                try:
                    sample.decode("utf-8")
                    return True
                except UnicodeDecodeError:
                    # BUGFIX: latin-1 decodes every byte sequence, so the old
                    # decode check always succeeded and classified arbitrary
                    # binary files as text. Apply the intended heuristic
                    # instead: if most bytes are printable ASCII or common
                    # whitespace, consider it text.
                    text_bytes = bytes(range(0x20, 0x7F)) + b"\t\n\r\x0b\x0c"
                    printable = sum(byte in text_bytes for byte in sample)
                    return printable >= 0.8 * len(sample)
        except (OSError, IOError):
            pass

        return False

    async def load(self, file_path: str, encoding: str = "utf-8", **kwargs) -> LoaderResult:
        """
        Load and process the text file.

        Args:
            file_path: Path to the file to load
            encoding: Text encoding to use (default: utf-8)
            **kwargs: Additional configuration (unused)

        Returns:
            LoaderResult containing the file content and metadata

        Raises:
            FileNotFoundError: If file doesn't exist
            UnicodeDecodeError: If file cannot be decoded with specified encoding
            OSError: If file cannot be read
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        try:
            with open(file_path, "r", encoding=encoding) as f:
                content = f.read()
        except UnicodeDecodeError:
            # Retry once with latin-1, which accepts any byte sequence.
            if encoding == "utf-8":
                return await self.load(file_path, encoding="latin-1", **kwargs)
            else:
                raise

        # Extract basic metadata
        file_stat = os.stat(file_path)
        metadata = {
            "name": os.path.basename(file_path),
            "size": file_stat.st_size,
            "extension": os.path.splitext(file_path)[1],
            "encoding": encoding,
            "loader": self.loader_name,
            "lines": len(content.splitlines()) if content else 0,
            "characters": len(content),
        }

        return LoaderResult(
            content=content,
            metadata=metadata,
            content_type=ContentType.TEXT,
            source_info={"file_path": file_path, "encoding": encoding},
        )
|
||||
49
infrastructure/loaders/create_loader_engine.py
Normal file
49
infrastructure/loaders/create_loader_engine.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
from typing import List
|
||||
from .LoaderEngine import LoaderEngine
|
||||
from .supported_loaders import supported_loaders
|
||||
|
||||
|
||||
def create_loader_engine(
    loader_directories: List[str],
    default_loader_priority: List[str],
    auto_discover: bool = True,
    fallback_loader: str = "text_loader",
    enable_dependency_validation: bool = True,
) -> LoaderEngine:
    """
    Build and populate a LoaderEngine from the given configuration.

    Mirrors the engine-creation helpers used by cognee's database adapters.

    Args:
        loader_directories: Directories to search for loader implementations
        default_loader_priority: Priority order for loader selection
        auto_discover: Whether to auto-discover loaders from directories
        fallback_loader: Default loader to use when no other matches
        enable_dependency_validation: Whether to validate loader dependencies

    Returns:
        Configured LoaderEngine instance
    """
    new_engine = LoaderEngine(
        loader_directories=loader_directories,
        default_loader_priority=default_loader_priority,
        fallback_loader=fallback_loader,
        enable_dependency_validation=enable_dependency_validation,
    )

    # Instantiate every loader from the shared registry; a single broken
    # loader must not prevent the engine from coming up.
    for loader_name, loader_class in supported_loaders.items():
        try:
            new_engine.register_loader(loader_class())
        except Exception as e:
            new_engine.logger.warning(f"Failed to register loader {loader_name}: {e}")

    # Optionally scan the configured directories for additional loaders.
    if auto_discover:
        new_engine.discover_loaders()

    return new_engine
|
||||
20
infrastructure/loaders/get_loader_engine.py
Normal file
20
infrastructure/loaders/get_loader_engine.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
from functools import lru_cache
|
||||
from .config import get_loader_config
|
||||
from .LoaderEngine import LoaderEngine
|
||||
from .create_loader_engine import create_loader_engine
|
||||
|
||||
|
||||
@lru_cache
def get_loader_engine() -> LoaderEngine:
    """
    Factory function returning the shared loader engine.

    Follows cognee's pattern with @lru_cache for efficient reuse of engine
    instances. Configuration is loaded from environment variables and
    settings.

    Returns:
        Cached LoaderEngine instance configured with current settings
    """
    loader_config = get_loader_config()
    return create_loader_engine(**loader_config.to_dict())
|
||||
47
infrastructure/loaders/models/LoaderResult.py
Normal file
47
infrastructure/loaders/models/LoaderResult.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import Optional, Dict, Any, List
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ContentType(Enum):
    """Classification of the content produced by a loader."""

    # Human-readable text content.
    TEXT = "text"
    # Structured data (exact semantics defined by the emitting loader).
    STRUCTURED = "structured"
    # Raw binary payloads that were not decoded to text.
    BINARY = "binary"
|
||||
|
||||
|
||||
class LoaderResult(BaseModel):
    """
    Standardized output format for all file loaders.

    This model ensures consistent data structure across all loader
    implementations, following cognee's pattern of using Pydantic models
    for data validation.
    """

    content: str  # Primary text content extracted from file
    metadata: Dict[str, Any]  # File metadata (name, size, type, loader info, etc.)
    content_type: ContentType  # Content classification
    chunks: Optional[List[str]] = None  # Pre-chunked content if available
    source_info: Optional[Dict[str, Any]] = None  # Source-specific information

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the loader result to a dictionary format.

        Returns:
            Dict containing all loader result data with string-serialized
            content_type
        """
        # BUGFIX: with ``use_enum_values = True`` (see Config below) pydantic
        # stores ``content_type`` as the plain string value, so calling
        # ``.value`` on it raised AttributeError. Accept both the enum and
        # its string representation.
        content_type = self.content_type
        if isinstance(content_type, ContentType):
            content_type = content_type.value
        return {
            "content": self.content,
            "metadata": self.metadata,
            "content_type": content_type,
            "source_info": self.source_info or {},
            "chunks": self.chunks,
        }

    class Config:
        """Pydantic configuration following cognee patterns"""

        use_enum_values = True
        validate_assignment = True
||||
3
infrastructure/loaders/models/__init__.py
Normal file
3
infrastructure/loaders/models/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .LoaderResult import LoaderResult, ContentType
|
||||
|
||||
__all__ = ["LoaderResult", "ContentType"]
|
||||
3
infrastructure/loaders/supported_loaders.py
Normal file
3
infrastructure/loaders/supported_loaders.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# Registry for loader implementations
# Follows cognee's pattern used in databases/vector/supported_databases.py
# Maps a unique loader name (str) to a loader class implementing
# LoaderInterface; populated via use_loader() and consumed by
# create_loader_engine() when instantiating the engine.
supported_loaders = {}
|
||||
22
infrastructure/loaders/use_loader.py
Normal file
22
infrastructure/loaders/use_loader.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from .supported_loaders import supported_loaders
|
||||
|
||||
|
||||
def use_loader(loader_name: str, loader_class):
    """
    Register a loader class under a unique name at runtime.

    Mirrors cognee's database-adapter registration pattern so that
    external packages and custom loaders can plug into the loader system.

    Args:
        loader_name: Unique name for the loader
        loader_class: Loader class implementing LoaderInterface

    Example:
        from cognee.infrastructure.loaders import use_loader
        from my_package import MyCustomLoader

        use_loader("my_custom_loader", MyCustomLoader)
    """
    supported_loaders.update({loader_name: loader_class})
|
||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
# Tests package
|
||||
1
tests/unit/__init__.py
Normal file
1
tests/unit/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
# Unit tests package
|
||||
1
tests/unit/infrastructure/__init__.py
Normal file
1
tests/unit/infrastructure/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
# Infrastructure tests package
|
||||
1
tests/unit/infrastructure/loaders/__init__.py
Normal file
1
tests/unit/infrastructure/loaders/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
# Loaders tests package
|
||||
252
tests/unit/infrastructure/loaders/test_loader_engine.py
Normal file
252
tests/unit/infrastructure/loaders/test_loader_engine.py
Normal file
|
|
@ -0,0 +1,252 @@
|
|||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
|
||||
from cognee.infrastructure.loaders.LoaderEngine import LoaderEngine
|
||||
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
|
||||
from cognee.infrastructure.loaders.models.LoaderResult import LoaderResult, ContentType
|
||||
|
||||
|
||||
class MockLoader(LoaderInterface):
    """Configurable stand-in loader used to exercise the engine in tests."""

    def __init__(self, name="mock_loader", extensions=None, mime_types=None, fail_deps=False):
        self._name = name
        self._extensions = [".mock"] if extensions is None else extensions
        self._mime_types = ["application/mock"] if mime_types is None else mime_types
        self._fail_deps = fail_deps

    @property
    def supported_extensions(self):
        return self._extensions

    @property
    def supported_mime_types(self):
        return self._mime_types

    @property
    def loader_name(self):
        return self._name

    def can_handle(self, file_path: str, mime_type: str = None) -> bool:
        # Accept either an extension match or an explicit MIME-type match.
        extension = os.path.splitext(file_path)[1].lower()
        if extension in self._extensions:
            return True
        return mime_type in self._mime_types

    async def load(self, file_path: str, **kwargs) -> LoaderResult:
        return LoaderResult(
            content=f"Mock content from {self._name}",
            metadata={"loader": self._name, "name": os.path.basename(file_path)},
            content_type=ContentType.TEXT,
        )

    def validate_dependencies(self) -> bool:
        # Simulate a missing-dependency failure when requested.
        return not self._fail_deps
|
||||
|
||||
|
||||
class TestLoaderEngine:
|
||||
"""Test the LoaderEngine class."""
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self):
|
||||
"""Create a LoaderEngine instance for testing."""
|
||||
return LoaderEngine(
|
||||
loader_directories=[],
|
||||
default_loader_priority=["loader1", "loader2"],
|
||||
fallback_loader="fallback",
|
||||
enable_dependency_validation=True,
|
||||
)
|
||||
|
||||
def test_engine_initialization(self, engine):
|
||||
"""Test LoaderEngine initialization."""
|
||||
assert engine.loader_directories == []
|
||||
assert engine.default_loader_priority == ["loader1", "loader2"]
|
||||
assert engine.fallback_loader == "fallback"
|
||||
assert engine.enable_dependency_validation is True
|
||||
assert len(engine.get_available_loaders()) == 0
|
||||
|
||||
def test_register_loader_success(self, engine):
|
||||
"""Test successful loader registration."""
|
||||
loader = MockLoader("test_loader", [".test"])
|
||||
|
||||
success = engine.register_loader(loader)
|
||||
|
||||
assert success is True
|
||||
assert "test_loader" in engine.get_available_loaders()
|
||||
assert engine._loaders["test_loader"] == loader
|
||||
assert ".test" in engine._extension_map
|
||||
assert "application/mock" in engine._mime_type_map
|
||||
|
||||
def test_register_loader_with_failed_dependencies(self, engine):
|
||||
"""Test loader registration with failed dependency validation."""
|
||||
loader = MockLoader("test_loader", [".test"], fail_deps=True)
|
||||
|
||||
success = engine.register_loader(loader)
|
||||
|
||||
assert success is False
|
||||
assert "test_loader" not in engine.get_available_loaders()
|
||||
|
||||
def test_register_loader_without_dependency_validation(self):
|
||||
"""Test loader registration without dependency validation."""
|
||||
engine = LoaderEngine(
|
||||
loader_directories=[], default_loader_priority=[], enable_dependency_validation=False
|
||||
)
|
||||
loader = MockLoader("test_loader", [".test"], fail_deps=True)
|
||||
|
||||
success = engine.register_loader(loader)
|
||||
|
||||
assert success is True
|
||||
assert "test_loader" in engine.get_available_loaders()
|
||||
|
||||
def test_get_loader_by_extension(self, engine):
|
||||
"""Test getting loader by file extension."""
|
||||
loader1 = MockLoader("loader1", [".txt"])
|
||||
loader2 = MockLoader("loader2", [".pdf"])
|
||||
|
||||
engine.register_loader(loader1)
|
||||
engine.register_loader(loader2)
|
||||
|
||||
result = engine.get_loader("test.txt")
|
||||
assert result == loader1
|
||||
|
||||
result = engine.get_loader("test.pdf")
|
||||
assert result == loader2
|
||||
|
||||
result = engine.get_loader("test.unknown")
|
||||
assert result is None
|
||||
|
||||
def test_get_loader_by_mime_type(self, engine):
|
||||
"""Test getting loader by MIME type."""
|
||||
loader = MockLoader("loader", [".txt"], ["text/plain"])
|
||||
engine.register_loader(loader)
|
||||
|
||||
result = engine.get_loader("test.unknown", mime_type="text/plain")
|
||||
assert result == loader
|
||||
|
||||
result = engine.get_loader("test.unknown", mime_type="application/pdf")
|
||||
assert result is None
|
||||
|
||||
def test_get_loader_with_preferences(self, engine):
|
||||
"""Test getting loader with preferred loaders."""
|
||||
loader1 = MockLoader("loader1", [".txt"])
|
||||
loader2 = MockLoader("loader2", [".txt"])
|
||||
|
||||
engine.register_loader(loader1)
|
||||
engine.register_loader(loader2)
|
||||
|
||||
# Should get preferred loader
|
||||
result = engine.get_loader("test.txt", preferred_loaders=["loader2"])
|
||||
assert result == loader2
|
||||
|
||||
# Should fallback to first available if preferred not found
|
||||
result = engine.get_loader("test.txt", preferred_loaders=["nonexistent"])
|
||||
assert result in [loader1, loader2] # One of them should be returned
|
||||
|
||||
def test_get_loader_with_priority(self, engine):
|
||||
"""Test loader selection with priority order."""
|
||||
engine.default_loader_priority = ["priority_loader", "other_loader"]
|
||||
|
||||
priority_loader = MockLoader("priority_loader", [".txt"])
|
||||
other_loader = MockLoader("other_loader", [".txt"])
|
||||
|
||||
# Register in reverse order
|
||||
engine.register_loader(other_loader)
|
||||
engine.register_loader(priority_loader)
|
||||
|
||||
# Should get priority loader even though other was registered first
|
||||
result = engine.get_loader("test.txt")
|
||||
assert result == priority_loader
|
||||
|
||||
def test_get_loader_fallback(self, engine):
|
||||
"""Test fallback loader selection."""
|
||||
fallback_loader = MockLoader("fallback", [".txt"])
|
||||
other_loader = MockLoader("other", [".pdf"])
|
||||
|
||||
engine.register_loader(fallback_loader)
|
||||
engine.register_loader(other_loader)
|
||||
engine.fallback_loader = "fallback"
|
||||
|
||||
# For .txt file, fallback should be considered
|
||||
result = engine.get_loader("test.txt")
|
||||
assert result == fallback_loader
|
||||
|
||||
# For unknown extension, should still get fallback if it can handle
|
||||
result = engine.get_loader("test.unknown")
|
||||
assert result == fallback_loader
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_file_success(self, engine):
|
||||
"""Test successful file loading."""
|
||||
loader = MockLoader("test_loader", [".txt"])
|
||||
engine.register_loader(loader)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f:
|
||||
f.write(b"test content")
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
result = await engine.load_file(temp_path)
|
||||
assert result.content == "Mock content from test_loader"
|
||||
assert result.metadata["loader"] == "test_loader"
|
||||
finally:
|
||||
if os.path.exists(temp_path):
|
||||
os.unlink(temp_path)
|
||||
|
||||
@pytest.mark.asyncio
async def test_load_file_no_loader(self, engine):
    """Loading a file no loader can handle raises a descriptive ValueError."""
    with pytest.raises(ValueError, match="No loader found for file"):
        await engine.load_file("test.unknown")
||||
@pytest.mark.asyncio
async def test_load_file_with_preferences(self, engine):
    """An explicit preferred_loaders list overrides the default selection."""
    first = MockLoader("loader1", [".txt"])
    second = MockLoader("loader2", [".txt"])

    engine.register_loader(first)
    engine.register_loader(second)

    with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as tmp:
        tmp.write(b"test content")
        tmp_path = tmp.name

    try:
        # Both loaders match .txt; the preference must pick loader2.
        loaded = await engine.load_file(tmp_path, preferred_loaders=["loader2"])
        assert loaded.metadata["loader"] == "loader2"
    finally:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
||||
def test_get_loader_info(self, engine):
    """get_loader_info reports registration details; unknown names yield {}."""
    engine.register_loader(MockLoader("test_loader", [".txt"], ["text/plain"]))

    details = engine.get_loader_info("test_loader")

    assert details["name"] == "test_loader"
    assert details["extensions"] == [".txt"]
    assert details["mime_types"] == ["text/plain"]
    assert details["available"] is True

    # An unknown loader name produces an empty mapping, not an error.
    assert engine.get_loader_info("nonexistent") == {}
||||
def test_discover_loaders_empty_directory(self, engine):
    """Discovery over an empty directory registers no loaders."""
    with tempfile.TemporaryDirectory() as empty_dir:
        engine.loader_directories = [empty_dir]
        engine.discover_loaders()

        # Nothing to discover, so nothing should be registered.
        assert len(engine.get_available_loaders()) == 0
||||
def test_discover_loaders_nonexistent_directory(self, engine):
    """Discovery tolerates a missing directory: no exception, no loaders."""
    engine.loader_directories = ["/nonexistent/directory"]

    # The engine is expected to log a warning rather than raise.
    engine.discover_loaders()
    assert len(engine.get_available_loaders()) == 0
99
tests/unit/infrastructure/loaders/test_loader_interface.py
Normal file
99
tests/unit/infrastructure/loaders/test_loader_interface.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
|
||||
from cognee.infrastructure.loaders.models.LoaderResult import LoaderResult, ContentType
|
||||
|
||||
|
||||
class TestLoaderInterface:
    """Test the LoaderInterface abstract base class."""

    def test_loader_interface_is_abstract(self):
        """Instantiating the abstract interface directly must fail."""
        with pytest.raises(TypeError):
            LoaderInterface()

    def test_dependency_validation_with_no_dependencies(self):
        """A loader that declares no dependencies always validates."""

        class _BareLoader(LoaderInterface):
            @property
            def supported_extensions(self):
                return [".txt"]

            @property
            def supported_mime_types(self):
                return ["text/plain"]

            @property
            def loader_name(self):
                return "mock_loader"

            def can_handle(self, file_path: str, mime_type: str = None) -> bool:
                return True

            async def load(self, file_path: str, **kwargs) -> LoaderResult:
                return LoaderResult(content="test", metadata={}, content_type=ContentType.TEXT)

        subject = _BareLoader()
        # Default contract: empty dependency list, validation succeeds.
        assert subject.get_dependencies() == []
        assert subject.validate_dependencies() is True

    def test_dependency_validation_with_missing_dependencies(self):
        """Declaring an uninstallable package makes validation fail."""

        class _MissingDepsLoader(LoaderInterface):
            @property
            def supported_extensions(self):
                return [".txt"]

            @property
            def supported_mime_types(self):
                return ["text/plain"]

            @property
            def loader_name(self):
                return "mock_loader_deps"

            def get_dependencies(self):
                return ["non_existent_package>=1.0.0"]

            def can_handle(self, file_path: str, mime_type: str = None) -> bool:
                return True

            async def load(self, file_path: str, **kwargs) -> LoaderResult:
                return LoaderResult(content="test", metadata={}, content_type=ContentType.TEXT)

        subject = _MissingDepsLoader()
        # The phantom package is reported, and validation rejects it.
        assert subject.validate_dependencies() is False
        assert "non_existent_package>=1.0.0" in subject.get_dependencies()

    def test_dependency_validation_with_existing_dependencies(self):
        """Declaring an importable module makes validation succeed."""

        class _StdlibDepsLoader(LoaderInterface):
            @property
            def supported_extensions(self):
                return [".txt"]

            @property
            def supported_mime_types(self):
                return ["text/plain"]

            @property
            def loader_name(self):
                return "mock_loader_existing"

            def get_dependencies(self):
                # Built-in module that is always importable.
                return ["os"]

            def can_handle(self, file_path: str, mime_type: str = None) -> bool:
                return True

            async def load(self, file_path: str, **kwargs) -> LoaderResult:
                return LoaderResult(content="test", metadata={}, content_type=ContentType.TEXT)

        assert _StdlibDepsLoader().validate_dependencies() is True
||||
157
tests/unit/infrastructure/loaders/test_text_loader.py
Normal file
157
tests/unit/infrastructure/loaders/test_text_loader.py
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from cognee.infrastructure.loaders.core.text_loader import TextLoader
|
||||
from cognee.infrastructure.loaders.models.LoaderResult import ContentType
|
||||
|
||||
|
||||
class TestTextLoader:
    """Test the TextLoader implementation.

    Covers loader identity, the three can_handle paths (extension, MIME
    type, content heuristic), encoding handling, and edge cases (missing
    file, empty file, dependency-free operation).
    """

    @pytest.fixture
    def text_loader(self):
        """Create a TextLoader instance for testing."""
        return TextLoader()

    @pytest.fixture
    def temp_text_file(self):
        """Create a temporary UTF-8 text file; remove it after the test."""
        with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
            f.write("This is a test file.\nIt has multiple lines.\n")
            temp_path = f.name

        yield temp_path

        # Cleanup — the test may already have renamed or removed the file.
        if os.path.exists(temp_path):
            os.unlink(temp_path)

    @pytest.fixture
    def temp_binary_file(self):
        """Create a temporary file of non-text bytes; remove it after the test."""
        with tempfile.NamedTemporaryFile(mode="wb", suffix=".bin", delete=False) as f:
            f.write(b"\x00\x01\x02\x03\x04\x05")
            temp_path = f.name

        yield temp_path

        if os.path.exists(temp_path):
            os.unlink(temp_path)

    def test_loader_properties(self, text_loader):
        """The loader advertises its name, extensions, and MIME types."""
        assert text_loader.loader_name == "text_loader"
        assert ".txt" in text_loader.supported_extensions
        assert ".md" in text_loader.supported_extensions
        assert "text/plain" in text_loader.supported_mime_types
        assert "application/json" in text_loader.supported_mime_types

    def test_can_handle_by_extension(self, text_loader):
        """File handling is decided by the filename extension."""
        assert text_loader.can_handle("test.txt")
        assert text_loader.can_handle("test.md")
        assert text_loader.can_handle("test.json")
        assert text_loader.can_handle("test.TXT")  # Case insensitive
        assert not text_loader.can_handle("test.pdf")

    def test_can_handle_by_mime_type(self, text_loader):
        """An explicit MIME type overrides an unrecognized extension."""
        assert text_loader.can_handle("test.unknown", mime_type="text/plain")
        assert text_loader.can_handle("test.unknown", mime_type="application/json")
        assert not text_loader.can_handle("test.unknown", mime_type="application/pdf")

    def test_can_handle_text_file_heuristic(self, text_loader, temp_text_file):
        """Extension-less text files are recognized by content inspection."""
        # Strip only the final extension. (str.replace(".txt", "") would
        # also mangle any ".txt" occurring earlier in the path, e.g. in a
        # directory name.)
        no_ext_path = os.path.splitext(temp_text_file)[0]
        os.rename(temp_text_file, no_ext_path)

        try:
            assert text_loader.can_handle(no_ext_path)
        finally:
            if os.path.exists(no_ext_path):
                os.unlink(no_ext_path)

    def test_cannot_handle_binary_file(self, text_loader, temp_binary_file):
        """Files containing non-text bytes are rejected by the heuristic."""
        assert not text_loader.can_handle(temp_binary_file)

    @pytest.mark.asyncio
    async def test_load_text_file(self, text_loader, temp_text_file):
        """Loading a text file yields its content plus descriptive metadata."""
        result = await text_loader.load(temp_text_file)

        assert isinstance(result.content, str)
        assert "This is a test file." in result.content
        assert result.content_type == ContentType.TEXT
        assert result.metadata["loader"] == "text_loader"
        assert result.metadata["name"] == os.path.basename(temp_text_file)
        assert result.metadata["lines"] == 2
        assert result.metadata["encoding"] == "utf-8"
        assert result.source_info["file_path"] == temp_text_file

    @pytest.mark.asyncio
    async def test_load_with_custom_encoding(self, text_loader):
        """A caller-specified encoding is honored and recorded in metadata."""
        # Create a file with latin-1 encoding.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".txt", delete=False, encoding="latin-1"
        ) as f:
            f.write("Test with åéîøü characters")
            temp_path = f.name

        try:
            result = await text_loader.load(temp_path, encoding="latin-1")
            assert "åéîøü" in result.content
            assert result.metadata["encoding"] == "latin-1"
        finally:
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    @pytest.mark.asyncio
    async def test_load_with_fallback_encoding(self, text_loader):
        """Bytes invalid as UTF-8 trigger automatic fallback to latin-1."""
        with tempfile.NamedTemporaryFile(mode="wb", suffix=".txt", delete=False) as f:
            # Latin-1 encoded bytes that are invalid in UTF-8.
            f.write(b"Test with \xe5\xe9\xee\xf8\xfc characters")
            temp_path = f.name

        try:
            result = await text_loader.load(temp_path)
            assert result.metadata["encoding"] == "latin-1"
            assert len(result.content) > 0
        finally:
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    @pytest.mark.asyncio
    async def test_load_nonexistent_file(self, text_loader):
        """Loading a missing file raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError):
            await text_loader.load("/nonexistent/file.txt")

    @pytest.mark.asyncio
    async def test_load_empty_file(self, text_loader):
        """An empty file loads as empty content with zeroed counters."""
        with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
            # Create an empty file — nothing is written.
            temp_path = f.name

        try:
            result = await text_loader.load(temp_path)
            assert result.content == ""
            assert result.metadata["lines"] == 0
            assert result.metadata["characters"] == 0
        finally:
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    def test_no_dependencies(self, text_loader):
        """TextLoader requires no external dependencies."""
        assert text_loader.get_dependencies() == []
        assert text_loader.validate_dependencies() is True
Loading…
Add table
Reference in a new issue