Compare commits
14 commits: main...new_error_

Commits (SHA1): 4324df5a8b, 19e22d14b8, dc03a52541, f0ba618f0c, 725061fbef, bf191ae6d0, 9d423f5e16, 411e9a6205, 3429af32c2, 9110a2b59b, 1cd0fe0dcf, f2e96d5c62, 1c378dabdb, 98882ba1d1
53 changed files with 4249 additions and 151 deletions
@@ -15,6 +15,7 @@ from fastapi.exceptions import RequestValidationError
 from fastapi.openapi.utils import get_openapi
 from cognee.exceptions import CogneeApiError
+from cognee.exceptions.enhanced_exceptions import CogneeBaseError
 from cognee.shared.logging_utils import get_logger, setup_logging
 from cognee.api.v1.permissions.routers import get_permissions_router
 from cognee.api.v1.settings.routers import get_settings_router

@@ -120,8 +121,27 @@ async def request_validation_exception_handler(request: Request, exc: RequestValidationError):
     )


+@app.exception_handler(CogneeBaseError)
+async def enhanced_exception_handler(_: Request, exc: CogneeBaseError) -> JSONResponse:
+    """
+    Enhanced exception handler for the new exception hierarchy.
+    Provides standardized error responses with rich context and user guidance.
+    """
+    # Log the full stack trace for debugging
+    logger.error(f"Enhanced exception caught: {exc.__class__.__name__}", exc_info=True)
+
+    # Create standardized error response
+    error_response = {"error": exc.to_dict()}
+
+    return JSONResponse(status_code=exc.status_code, content=error_response)
+
+
 @app.exception_handler(CogneeApiError)
-async def exception_handler(_: Request, exc: CogneeApiError) -> JSONResponse:
+async def legacy_exception_handler(_: Request, exc: CogneeApiError) -> JSONResponse:
+    """
+    Legacy exception handler for backward compatibility.
+    Handles old CogneeApiError instances with fallback formatting.
+    """
     detail = {}

     if exc.name and exc.message and exc.status_code:

@@ -136,7 +156,54 @@ async def exception_handler(_: Request, exc: CogneeApiError) -> JSONResponse:
     # log the stack trace for easier serverside debugging
     logger.error(format_exc())
-    return JSONResponse(status_code=status_code, content={"detail": detail["message"]})
+
+    # Convert to new format for consistency
+    error_response = {
+        "error": {
+            "type": exc.__class__.__name__,
+            "message": detail["message"],
+            "technical_message": detail["message"],
+            "suggestions": [
+                "Check the logs for more details",
+                "Try again or contact support if the issue persists",
+            ],
+            "docs_link": "https://docs.cognee.ai/troubleshooting",
+            "is_retryable": False,
+            "context": {},
+            "operation": None,
+        }
+    }
+
+    return JSONResponse(status_code=status_code, content=error_response)
+
+
+@app.exception_handler(Exception)
+async def global_exception_handler(request: Request, exc: Exception) -> JSONResponse:
+    """
+    Global exception handler for any unhandled exceptions.
+    Ensures all errors return a consistent format.
+    """
+    logger.error(f"Unhandled exception in {request.url.path}: {str(exc)}", exc_info=True)
+
+    # Create a standardized error response for unexpected errors
+    error_response = {
+        "error": {
+            "type": "UnexpectedError",
+            "message": "An unexpected error occurred. Please try again.",
+            "technical_message": str(exc) if app_environment != "prod" else "Internal server error",
+            "suggestions": [
+                "Try your request again",
+                "Check if the issue persists",
+                "Contact support if the problem continues",
+            ],
+            "docs_link": "https://docs.cognee.ai/troubleshooting",
+            "is_retryable": True,
+            "context": {"path": str(request.url.path), "method": request.method},
+            "operation": None,
+        }
+    }
+
+    return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=error_response)
+
+
 @app.get("/")
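Taken together, the three handlers above normalize every API failure into one `{"error": {...}}` envelope. A minimal sketch of client-side handling, assuming a locally running instance; the base URL and route are illustrative, not taken from this diff:

```python
import requests

# Hypothetical call against a local Cognee API instance (port and route are assumptions).
response = requests.post("http://localhost:8000/api/v1/cognify", json={})

if not response.ok:
    # Every handler returns {"error": {...}} with the fields of CogneeBaseError.to_dict().
    error = response.json()["error"]
    print(f"{error['type']}: {error['message']}")
    for suggestion in error["suggestions"]:
        print(f"  - {suggestion}")
    if error["is_retryable"]:
        print("Marked retryable; the request can be attempted again.")
```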
@@ -15,14 +15,19 @@ async def add(
     vector_db_config: dict = None,
     graph_db_config: dict = None,
     dataset_id: UUID = None,
+    preferred_loaders: Optional[List[str]] = None,
+    loader_config: Optional[dict] = None,
 ):
     """
-    Add data to Cognee for knowledge graph processing.
+    Add data to Cognee for knowledge graph processing using a plugin-based loader system.

     This is the first step in the Cognee workflow - it ingests raw data and prepares it
     for processing. The function accepts various data formats including text, files, and
     binary streams, then stores them in a specified dataset for further processing.

+    This version supports both the original ingestion system (for backward compatibility)
+    and the new plugin-based loader system (when loader parameters are provided).
+
     Prerequisites:
     - **LLM_API_KEY**: Must be set in environment variables for content processing
     - **Database Setup**: Relational and vector databases must be configured

@@ -38,16 +43,38 @@ async def add(
     - **Lists**: Multiple files or text strings in a single call

     Supported File Formats:
-    - Text files (.txt, .md, .csv)
-    - PDFs (.pdf)
+    - Text files (.txt, .md, .csv) - processed by text_loader
+    - PDFs (.pdf) - processed by pypdf_loader (if available)
     - Images (.png, .jpg, .jpeg) - extracted via OCR/vision models
     - Audio files (.mp3, .wav) - transcribed to text
     - Code files (.py, .js, .ts, etc.) - parsed for structure and content
-    - Office documents (.docx, .pptx)
+    - Office documents (.docx, .pptx) - processed by unstructured_loader (if available)
+    - Data files (.json, .jsonl, .parquet) - processed by dlt_loader (if available)
+
+    Plugin System:
+    The function automatically uses the best available loader for each file type.
+    You can customize this behavior using the loader parameters:
+
+    ```python
+    # Use specific loaders in priority order
+    await cognee.add(
+        "/path/to/document.pdf",
+        preferred_loaders=["pypdf_loader", "text_loader"]
+    )
+
+    # Configure loader-specific options
+    await cognee.add(
+        "/path/to/document.pdf",
+        loader_config={
+            "pypdf_loader": {"strict": False},
+            "unstructured_loader": {"strategy": "hi_res"}
+        }
+    )
+    ```

     Workflow:
     1. **Data Resolution**: Resolves file paths and validates accessibility
-    2. **Content Extraction**: Extracts text content from various file formats
+    2. **Content Extraction**: Uses plugin system or falls back to existing classification
     3. **Dataset Storage**: Stores processed content in the specified dataset
     4. **Metadata Tracking**: Records file metadata, timestamps, and user permissions
     5. **Permission Assignment**: Grants user read/write/delete/share permissions on dataset

@@ -70,6 +97,10 @@ async def add(
         vector_db_config: Optional configuration for vector database (for custom setups).
         graph_db_config: Optional configuration for graph database (for custom setups).
         dataset_id: Optional specific dataset UUID to use instead of dataset_name.
+        preferred_loaders: Optional list of loader names to try first (e.g., ["pypdf_loader", "text_loader"]).
+            If not provided, uses default loader priority.
+        loader_config: Optional configuration for specific loaders. Dictionary mapping loader names
+            to their configuration options (e.g., {"pypdf_loader": {"strict": False}}).

     Returns:
         PipelineRunInfo: Information about the ingestion pipeline execution including:

@@ -138,10 +169,32 @@ async def add(
         UnsupportedFileTypeError: If file format cannot be processed
         InvalidValueError: If LLM_API_KEY is not set or invalid
     """
+
+    # Determine which ingestion system to use
+    # use_plugin_system = preferred_loaders is not None or loader_config is not None
+
+    # if use_plugin_system:
+    #     # Use new plugin-based ingestion system
+    from cognee.tasks.ingestion.plugin_ingest_data import plugin_ingest_data

     tasks = [
         Task(resolve_data_directories, include_subdirectories=True),
-        Task(ingest_data, dataset_name, user, node_set, dataset_id),
+        Task(
+            plugin_ingest_data,
+            dataset_name,
+            user,
+            node_set,
+            dataset_id,
+            preferred_loaders,
+            loader_config,
+        ),
     ]
+    # else:
+    #     # Use existing ingestion system for backward compatibility
+    #     tasks = [
+    #         Task(resolve_data_directories, include_subdirectories=True),
+    #         Task(ingest_data, dataset_name, user, node_set, dataset_id),
+    #     ]

     pipeline_run_info = None
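A short usage sketch of the new parameters, using the loader names listed in the docstring; the file path and dataset name are hypothetical:

```python
import asyncio

import cognee


async def main():
    # Try the PDF loader first and fall back to plain-text extraction;
    # per-loader options are passed through loader_config.
    await cognee.add(
        "/path/to/report.pdf",  # hypothetical local file
        "quarterly_reports",  # dataset_name
        preferred_loaders=["pypdf_loader", "text_loader"],
        loader_config={"pypdf_loader": {"strict": False}},
    )


asyncio.run(main())
```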
@@ -11,6 +11,13 @@ import requests

 from cognee.modules.users.models import User
 from cognee.modules.users.methods import get_authenticated_user
+from cognee.exceptions import (
+    UnsupportedFileFormatError,
+    FileAccessError,
+    DatasetNotFoundError,
+    CogneeValidationError,
+    CogneeSystemError,
+)

 logger = get_logger()

@@ -49,49 +56,143 @@ def get_add_router() -> APIRouter:
         - Any relevant metadata from the ingestion process

         ## Error Codes
-        - **400 Bad Request**: Neither datasetId nor datasetName provided
-        - **409 Conflict**: Error during add operation
+        - **400 Bad Request**: Missing required parameters or invalid input
+        - **422 Unprocessable Entity**: Unsupported file format or validation error
         - **403 Forbidden**: User doesn't have permission to add to dataset
+        - **500 Internal Server Error**: System error during processing

         ## Notes
         - To add data to datasets not owned by the user, use dataset_id (when ENABLE_BACKEND_ACCESS_CONTROL is set to True)
         - GitHub repositories are cloned and all files are processed
         - HTTP URLs are fetched and their content is processed
         - The ALLOW_HTTP_REQUESTS environment variable controls URL processing
+        - Enhanced error messages provide specific guidance for fixing issues
         """
         from cognee.api.v1.add import add as cognee_add

+        # Input validation with enhanced exceptions
         if not datasetId and not datasetName:
-            raise ValueError("Either datasetId or datasetName must be provided.")
+            raise CogneeValidationError(
+                message="Either datasetId or datasetName must be provided",
+                user_message="You must specify either a dataset name or dataset ID.",
+                suggestions=[
+                    "Provide a datasetName parameter (e.g., 'my_dataset')",
+                    "Provide a datasetId parameter with a valid UUID",
+                    "Check the API documentation for parameter examples",
+                ],
+                docs_link="https://docs.cognee.ai/api/add",
+                context={"provided_dataset_name": datasetName, "provided_dataset_id": datasetId},
+                operation="add",
+            )

-        try:
+        if not data or len(data) == 0:
+            raise CogneeValidationError(
+                message="No data provided for upload",
+                user_message="You must provide data to add to the dataset.",
+                suggestions=[
+                    "Upload one or more files",
+                    "Provide a valid URL (if URL processing is enabled)",
+                    "Check that your request includes the data parameter",
+                ],
+                docs_link="https://docs.cognee.ai/guides/adding-data",
+                operation="add",
+            )
+
+        logger.info(
+            f"Adding {len(data)} items to dataset",
+            extra={
+                "dataset_name": datasetName,
+                "dataset_id": datasetId,
+                "user_id": user.id,
+                "item_count": len(data),
+            },
+        )
+
+        # Handle URL-based data (GitHub repos, HTTP URLs)
         if (
-            isinstance(data, str)
-            and data.startswith("http")
+            len(data) == 1
+            and hasattr(data[0], "filename")
+            and isinstance(data[0].filename, str)
+            and data[0].filename.startswith("http")
             and (os.getenv("ALLOW_HTTP_REQUESTS", "true").lower() == "true")
         ):
-            if "github" in data:
-                # Perform git clone if the URL is from GitHub
-                repo_name = data.split("/")[-1].replace(".git", "")
-                subprocess.run(["git", "clone", data, f".data/{repo_name}"], check=True)
-                # TODO: Update add call with dataset info
-                await cognee_add(
-                    "data://.data/",
-                    f"{repo_name}",
-                )
+            url = data[0].filename
+
+            if "github" in url:
+                try:
+                    # Perform git clone if the URL is from GitHub
+                    repo_name = url.split("/")[-1].replace(".git", "")
+                    subprocess.run(["git", "clone", url, f".data/{repo_name}"], check=True)
+                    # TODO: Update add call with dataset info
+                    result = await cognee_add(
+                        "data://.data/",
+                        f"{repo_name}",
+                    )
+                except subprocess.CalledProcessError as e:
+                    raise CogneeSystemError(
+                        message=f"Failed to clone GitHub repository: {e}",
+                        user_message=f"Could not clone the GitHub repository '{url}'.",
+                        suggestions=[
+                            "Check if the repository URL is correct",
+                            "Verify the repository is public or you have access",
+                            "Try cloning the repository manually to test access",
+                        ],
+                        context={"url": url, "repo_name": repo_name, "error": str(e)},
+                        operation="add",
+                    )
             else:
-                # Fetch and store the data from other types of URL using curl
-                response = requests.get(data)
-                response.raise_for_status()
-                file_data = await response.content()
-                # TODO: Update add call with dataset info
-                return await cognee_add(file_data)
+                try:
+                    # Fetch and store the data from other types of URL
+                    response = requests.get(url, timeout=30)
+                    response.raise_for_status()
+
+                    file_data = response.content
+                    # TODO: Update add call with dataset info
+                    result = await cognee_add(file_data)
+                except requests.RequestException as e:
+                    raise CogneeSystemError(
+                        message=f"Failed to fetch URL: {e}",
+                        user_message=f"Could not fetch content from '{url}'.",
+                        suggestions=[
+                            "Check if the URL is accessible",
+                            "Verify your internet connection",
+                            "Try accessing the URL in a browser",
+                            "Check if the URL requires authentication",
+                        ],
+                        context={"url": url, "error": str(e)},
+                        operation="add",
+                    )
         else:
-            add_run = await cognee_add(data, datasetName, user=user, dataset_id=datasetId)
+            # Handle regular file uploads
+            # Validate file types before processing
+            supported_extensions = [
+                ".txt",
+                ".pdf",
+                ".docx",
+                ".md",
+                ".csv",
+                ".json",
+                ".py",
+                ".js",
+                ".ts",
+            ]

-            return add_run.model_dump()
-        except Exception as error:
-            return JSONResponse(status_code=409, content={"error": str(error)})
+            for file in data:
+                if file.filename:
+                    file_ext = os.path.splitext(file.filename)[1].lower()
+                    if file_ext and file_ext not in supported_extensions:
+                        raise UnsupportedFileFormatError(
+                            file_path=file.filename, supported_formats=supported_extensions
+                        )
+
+            # Process the files
+            result = await cognee_add(data, datasetName, user=user, dataset_id=datasetId)
+
+        logger.info(
+            "Successfully added data to dataset",
+            extra={"dataset_name": datasetName, "dataset_id": datasetId, "user_id": user.id},
+        )
+
+        return result.model_dump() if hasattr(result, "model_dump") else result

     return router
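For illustration, a hedged sketch of how the new 422 responses surface to an HTTP client; the base URL, route prefix, and multipart field names are assumptions, not confirmed by this diff:

```python
import requests

# Hypothetical upload of a file whose extension is outside supported_extensions.
with open("archive.xyz", "rb") as f:
    response = requests.post(
        "http://localhost:8000/api/v1/add",  # assumed route
        files={"data": ("archive.xyz", f)},
        data={"datasetName": "my_dataset"},
    )

if response.status_code == 422:
    # UnsupportedFileFormatError is serialized by the enhanced exception handler.
    error = response.json()["error"]
    print(error["type"], "-", error["message"])
    print("Try:", "; ".join(error["suggestions"]))
```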
@@ -24,6 +24,13 @@ from cognee.modules.pipelines.queues.pipeline_run_info_queues import (
     remove_queue,
 )
 from cognee.shared.logging_utils import get_logger
+from cognee.exceptions import (
+    CogneeValidationError,
+    EmptyDatasetError,
+    DatasetNotFoundError,
+    MissingAPIKeyError,
+    NoDataToProcessError,
+)

 logger = get_logger("api.cognify")

@@ -66,8 +73,10 @@ def get_cognify_router() -> APIRouter:
         - **Background execution**: Pipeline run metadata including pipeline_run_id for status monitoring via WebSocket subscription

         ## Error Codes
-        - **400 Bad Request**: When neither datasets nor dataset_ids are provided, or when specified datasets don't exist
-        - **409 Conflict**: When processing fails due to system errors, missing LLM API keys, database connection failures, or corrupted content
+        - **400 Bad Request**: Missing required parameters or invalid input
+        - **422 Unprocessable Entity**: No data to process or validation errors
+        - **404 Not Found**: Specified datasets don't exist
+        - **500 Internal Server Error**: System errors, missing API keys, database connection failures

         ## Example Request
         ```json

@@ -84,23 +93,53 @@ def get_cognify_router() -> APIRouter:
         ## Next Steps
         After successful processing, use the search endpoints to query the generated knowledge graph for insights, relationships, and semantic search.
         """
+        # Input validation with enhanced exceptions
         if not payload.datasets and not payload.dataset_ids:
-            return JSONResponse(
-                status_code=400, content={"error": "No datasets or dataset_ids provided"}
+            raise CogneeValidationError(
+                message="No datasets or dataset_ids provided",
+                user_message="You must specify which datasets to process.",
+                suggestions=[
+                    "Provide dataset names using the 'datasets' parameter",
+                    "Provide dataset UUIDs using the 'dataset_ids' parameter",
+                    "Use cognee.datasets() to see available datasets",
+                ],
+                docs_link="https://docs.cognee.ai/api/cognify",
+                context={
+                    "provided_datasets": payload.datasets,
+                    "provided_dataset_ids": payload.dataset_ids,
+                },
+                operation="cognify",
             )

+        # Check for LLM API key early to provide better error messaging
+        llm_api_key = os.getenv("LLM_API_KEY")
+        if not llm_api_key:
+            raise MissingAPIKeyError(service="LLM", env_var="LLM_API_KEY")
+
         from cognee.api.v1.cognify import cognify as cognee_cognify

-        try:
         datasets = payload.dataset_ids if payload.dataset_ids else payload.datasets
+
+        logger.info(
+            f"Starting cognify process for user {user.id}",
+            extra={
+                "user_id": user.id,
+                "datasets": datasets,
+                "run_in_background": payload.run_in_background,
+            },
+        )
+
+        # The enhanced exception handler will catch and format any errors from cognee_cognify
         cognify_run = await cognee_cognify(
             datasets, user, run_in_background=payload.run_in_background
         )

+        logger.info(
+            f"Cognify process completed for user {user.id}",
+            extra={"user_id": user.id, "datasets": datasets},
+        )
+
         return cognify_run
-        except Exception as error:
-            return JSONResponse(status_code=409, content={"error": str(error)})

     @router.websocket("/subscribe/{pipeline_run_id}")
     async def subscribe_to_cognify_info(websocket: WebSocket, pipeline_run_id: str):

@@ -124,6 +163,14 @@ def get_cognify_router() -> APIRouter:
                 user_manager=user_manager,
                 bearer=None,
             )
+            logger.info(
+                f"WebSocket user authenticated for pipeline {pipeline_run_id}",
+                extra={
+                    "user_id": user.id,
+                    "user_email": user.email,
+                    "pipeline_run_id": str(pipeline_run_id),
+                },
+            )
         except Exception as error:
             logger.error(f"Authentication failed: {str(error)}")
             await websocket.close(code=WS_1008_POLICY_VIOLATION, reason="Unauthorized")

@@ -135,31 +182,43 @@ def get_cognify_router() -> APIRouter:

         initialize_queue(pipeline_run_id)

-        while True:
-            pipeline_run_info = get_from_queue(pipeline_run_id)
-
-            if not pipeline_run_info:
-                await asyncio.sleep(2)
-                continue
-
-            if not isinstance(pipeline_run_info, PipelineRunInfo):
-                continue
-
-            try:
-                await websocket.send_json(
-                    {
-                        "pipeline_run_id": str(pipeline_run_info.pipeline_run_id),
-                        "status": pipeline_run_info.status,
-                        "payload": await get_formatted_graph_data(pipeline_run.dataset_id, user.id),
-                    }
-                )
-                if isinstance(pipeline_run_info, PipelineRunCompleted):
-                    remove_queue(pipeline_run_id)
-                    await websocket.close(code=WS_1000_NORMAL_CLOSURE)
-                    break
-            except WebSocketDisconnect:
-                remove_queue(pipeline_run_id)
-                break
+        try:
+            # If the pipeline is already completed, send the completion status
+            if isinstance(pipeline_run, PipelineRunCompleted):
+                graph_data = await get_formatted_graph_data()
+                pipeline_run.payload = {
+                    "nodes": graph_data.get("nodes", []),
+                    "edges": graph_data.get("edges", []),
+                }
+
+                await websocket.send_json(pipeline_run.model_dump())
+                await websocket.close(code=WS_1000_NORMAL_CLOSURE)
+                return
+
+            # Stream pipeline updates
+            while True:
+                try:
+                    pipeline_run_info = await asyncio.wait_for(
+                        get_from_queue(pipeline_run_id), timeout=10.0
+                    )
+
+                    if pipeline_run_info:
+                        await websocket.send_json(pipeline_run_info.model_dump())
+
+                    if isinstance(pipeline_run_info, PipelineRunCompleted):
+                        break
+                except asyncio.TimeoutError:
+                    # Send a heartbeat to keep the connection alive
+                    await websocket.send_json({"type": "heartbeat"})
+                except Exception as e:
+                    logger.error(f"Error in WebSocket communication: {str(e)}")
+                    break
+
+        except WebSocketDisconnect:
+            logger.info(f"WebSocket disconnected for pipeline {pipeline_run_id}")
+        except Exception as error:
+            logger.error(f"WebSocket error: {str(error)}")
+        finally:
+            remove_queue(pipeline_run_id)

     return router
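A sketch of a subscriber that tolerates the new heartbeat frames, using the third-party `websockets` package; the host, port, and route prefix are assumptions:

```python
import asyncio
import json

import websockets  # pip install websockets


async def follow_pipeline(pipeline_run_id: str) -> None:
    # Path mirrors the @router.websocket("/subscribe/{pipeline_run_id}") route;
    # the "/api/v1/cognify" prefix is an assumption.
    uri = f"ws://localhost:8000/api/v1/cognify/subscribe/{pipeline_run_id}"
    async with websockets.connect(uri) as ws:
        async for raw in ws:
            message = json.loads(raw)
            if message.get("type") == "heartbeat":
                continue  # sent by the server on the 10 s queue timeout
            print("pipeline update:", message.get("status"))


asyncio.run(follow_pipeline("00000000-0000-0000-0000-000000000000"))
```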
@@ -84,6 +84,12 @@ async def delete(
     # Get the content hash for deletion
     content_hash = data_point.content_hash

+    # Debug logging
+    logger.info(
+        f"🔍 Retrieved from database - data_id: {data_id}, content_hash: {content_hash}"
+    )
+    logger.info(f"🔍 Document name in database: {data_point.name}")
+
     # Use the existing comprehensive deletion logic
     return await delete_single_document(content_hash, dataset.id, mode)
@@ -1,22 +1,24 @@
 from uuid import UUID
-from typing import Optional
-from datetime import datetime
-from fastapi import Depends, APIRouter
+from typing import List, Optional
+from fastapi import APIRouter, Depends
 from fastapi.responses import JSONResponse

+from cognee.api.DTO import InDTO
 from cognee.modules.search.types import SearchType
-from cognee.api.DTO import InDTO, OutDTO
-from cognee.modules.users.exceptions.exceptions import PermissionDeniedError
-from cognee.modules.users.models import User
-from cognee.modules.search.operations import get_history
 from cognee.modules.users.methods import get_authenticated_user
+from cognee.modules.users.models import User
+from cognee.modules.users.exceptions import PermissionDeniedError
+from cognee.modules.data.methods import get_history
+from cognee.shared.logging_utils import get_logger
+from cognee.exceptions import UnsupportedSearchTypeError, InvalidQueryError, NoDataToProcessError
+
+logger = get_logger()

-# Note: Datasets sent by name will only map to datasets owned by the request sender
-# To search for datasets not owned by the request sender dataset UUID is needed
 class SearchPayloadDTO(InDTO):
     search_type: SearchType
-    datasets: Optional[list[str]] = None
-    dataset_ids: Optional[list[UUID]] = None
+    datasets: Optional[List[str]] = None
+    dataset_ids: Optional[List[UUID]] = None
     query: str
     top_k: Optional[int] = 10

@@ -24,36 +26,23 @@ class SearchPayloadDTO(InDTO):
 def get_search_router() -> APIRouter:
     router = APIRouter()

-    class SearchHistoryItem(OutDTO):
-        id: UUID
-        text: str
-        user: str
-        created_at: datetime
-
-    @router.get("", response_model=list[SearchHistoryItem])
+    @router.get("/history", response_model=list)
     async def get_search_history(user: User = Depends(get_authenticated_user)):
         """
         Get search history for the authenticated user.

-        This endpoint retrieves the search history for the authenticated user,
-        returning a list of previously executed searches with their timestamps.
+        This endpoint retrieves the search history for the current user,
+        showing previous queries and their results.

         ## Response
-        Returns a list of search history items containing:
-        - **id**: Unique identifier for the search
-        - **text**: The search query text
-        - **user**: User who performed the search
-        - **created_at**: When the search was performed
+        Returns a list of historical search queries and their metadata.

         ## Error Codes
-        - **500 Internal Server Error**: Error retrieving search history
+        - **500 Internal Server Error**: Database or system error while retrieving history
         """
-        try:
-            history = await get_history(user.id, limit=0)
+        # Remove try-catch to let enhanced exception handler deal with it
+        history = await get_history(user.id, limit=0)

-            return history
-        except Exception as error:
-            return JSONResponse(status_code=500, content={"error": str(error)})
+        return history

     @router.post("", response_model=list)
     async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):

@@ -75,17 +64,46 @@ def get_search_router() -> APIRouter:
         Returns a list of search results containing relevant nodes from the graph.

         ## Error Codes
-        - **409 Conflict**: Error during search operation
-        - **403 Forbidden**: User doesn't have permission to search datasets (returns empty list)
+        - **400 Bad Request**: Invalid query or search parameters
+        - **404 Not Found**: No data found to search
+        - **422 Unprocessable Entity**: Unsupported search type
+        - **403 Forbidden**: User doesn't have permission to search datasets
+        - **500 Internal Server Error**: System error during search

         ## Notes
         - Datasets sent by name will only map to datasets owned by the request sender
         - To search datasets not owned by the request sender, dataset UUID is needed
-        - If permission is denied, returns empty list instead of error
+        - Enhanced error messages provide actionable suggestions for fixing issues
         """
         from cognee.api.v1.search import search as cognee_search

+        # Input validation with enhanced exceptions
+        if not payload.query or not payload.query.strip():
+            raise InvalidQueryError(query=payload.query or "", reason="Query cannot be empty")
+
+        if len(payload.query.strip()) < 2:
+            raise InvalidQueryError(
+                query=payload.query, reason="Query must be at least 2 characters long"
+            )
+
+        # Check if search type is supported
         try:
+            search_type = payload.search_type
+            logger.info(
+                f"Search type validated: {search_type.value}",
+                extra={
+                    "search_type": search_type.value,
+                    "user_id": user.id,
+                    "query_length": len(payload.query),
+                },
+            )
+        except ValueError:
+            raise UnsupportedSearchTypeError(
+                search_type=str(payload.search_type), supported_types=[t.value for t in SearchType]
+            )
+
+        # Permission denied errors will be caught and handled by the enhanced exception handler
+        # Other exceptions will also be properly formatted by the global handler
         results = await cognee_search(
             query_text=payload.query,
             query_type=payload.search_type,

@@ -95,10 +113,10 @@ def get_search_router() -> APIRouter:
             top_k=payload.top_k,
         )

-        return results
-        except PermissionDeniedError:
+        # If no results found, that's not necessarily an error, just return empty list
+        if not results:
             return []
-        except Exception as error:
-            return JSONResponse(status_code=409, content={"error": str(error)})
+
+        return results

     return router
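To make the new validation path concrete, a hedged client sketch; the route and the exact field casing produced by InDTO serialization are assumptions:

```python
import requests

# A one-character query should trip the "at least 2 characters" check above.
payload = {"searchType": "GRAPH_COMPLETION", "query": "x", "topK": 5}  # casing assumed
response = requests.post("http://localhost:8000/api/v1/search", json=payload)  # assumed route

if response.status_code == 422:
    error = response.json()["error"]
    # InvalidQueryError carries the offending query and reason in its context.
    print(error["type"], error["context"])
```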
@@ -1,10 +1,11 @@
 """
 Custom exceptions for the Cognee API.

-This module defines a set of exceptions for handling various application errors,
-such as service failures, resource conflicts, and invalid operations.
+This module defines a comprehensive set of exceptions for handling various application errors,
+with enhanced error context, user-friendly messages, and actionable suggestions.
 """

+# Import original exceptions for backward compatibility
 from .exceptions import (
     CogneeApiError,
     ServiceError,

@@ -12,3 +13,83 @@ from .exceptions import (
     InvalidAttributeError,
     CriticalError,
 )
+
+# Import enhanced exception hierarchy
+from .enhanced_exceptions import (
+    CogneeBaseError,
+    CogneeUserError,
+    CogneeSystemError,
+    CogneeTransientError,
+    CogneeConfigurationError,
+    CogneeValidationError,
+    CogneeAuthenticationError,
+    CogneePermissionError,
+    CogneeNotFoundError,
+    CogneeRateLimitError,
+)
+
+# Import domain-specific exceptions
+from .domain_exceptions import (
+    # Data/Input Errors
+    UnsupportedFileFormatError,
+    EmptyDatasetError,
+    DatasetNotFoundError,
+    InvalidQueryError,
+    FileAccessError,
+    # Processing Errors
+    LLMConnectionError,
+    LLMRateLimitError,
+    ProcessingTimeoutError,
+    DatabaseConnectionError,
+    InsufficientResourcesError,
+    # Configuration Errors
+    MissingAPIKeyError,
+    InvalidDatabaseConfigError,
+    UnsupportedSearchTypeError,
+    # Pipeline Errors
+    PipelineExecutionError,
+    DataExtractionError,
+    NoDataToProcessError,
+)
+
+# For backward compatibility, create aliases
+# These will allow existing code to continue working while we migrate
+DatasetNotFoundError_Legacy = InvalidValueError  # For existing dataset not found errors
+PermissionDeniedError_Legacy = CogneeApiError  # For existing permission errors
+
+__all__ = [
+    # Original exceptions (backward compatibility)
+    "CogneeApiError",
+    "ServiceError",
+    "InvalidValueError",
+    "InvalidAttributeError",
+    "CriticalError",
+    # Enhanced base exceptions
+    "CogneeBaseError",
+    "CogneeUserError",
+    "CogneeSystemError",
+    "CogneeTransientError",
+    "CogneeConfigurationError",
+    "CogneeValidationError",
+    "CogneeAuthenticationError",
+    "CogneePermissionError",
+    "CogneeNotFoundError",
+    "CogneeRateLimitError",
+    # Domain-specific exceptions
+    "UnsupportedFileFormatError",
+    "EmptyDatasetError",
+    "DatasetNotFoundError",
+    "InvalidQueryError",
+    "FileAccessError",
+    "LLMConnectionError",
+    "LLMRateLimitError",
+    "ProcessingTimeoutError",
+    "DatabaseConnectionError",
+    "InsufficientResourcesError",
+    "MissingAPIKeyError",
+    "InvalidDatabaseConfigError",
+    "UnsupportedSearchTypeError",
+    "PipelineExecutionError",
+    "DataExtractionError",
+    "NoDataToProcessError",
+]
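Both generations of the hierarchy are now importable from the same package, so call sites can migrate incrementally. A small sketch using only names exported above:

```python
from cognee.exceptions import DatasetNotFoundError

try:
    raise DatasetNotFoundError("missing_dataset", available_datasets=["demo"])
except DatasetNotFoundError as exc:
    payload = exc.to_dict()
    print(payload["message"])       # "Could not find dataset 'missing_dataset'."
    print(payload["suggestions"])   # includes "Available datasets: demo"
    print(payload["is_retryable"])  # False, the CogneeBaseError default
```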
cognee/exceptions/domain_exceptions.py (new file, 337 lines)

@@ -0,0 +1,337 @@
from typing import List, Optional, Dict, Any
from .enhanced_exceptions import (
    CogneeUserError,
    CogneeSystemError,
    CogneeTransientError,
    CogneeConfigurationError,
    CogneeValidationError,
    CogneeNotFoundError,
    CogneePermissionError,
)


# ========== DATA/INPUT ERRORS (User-fixable) ==========


class UnsupportedFileFormatError(CogneeValidationError):
    """File format not supported by Cognee"""

    def __init__(self, file_path: str, supported_formats: List[str], **kwargs):
        super().__init__(
            message=f"File format not supported: {file_path}",
            user_message=f"The file '{file_path}' has an unsupported format.",
            suggestions=[
                f"Use one of these supported formats: {', '.join(supported_formats)}",
                "Convert your file to a supported format",
                "Check our documentation for the complete list of supported formats",
            ],
            docs_link="https://docs.cognee.ai/guides/file-formats",
            context={"file_path": file_path, "supported_formats": supported_formats},
            operation="add",
            **kwargs,
        )


class EmptyDatasetError(CogneeValidationError):
    """Dataset is empty or contains no processable content"""

    def __init__(self, dataset_name: str, **kwargs):
        super().__init__(
            message=f"Dataset '{dataset_name}' is empty",
            user_message=f"The dataset '{dataset_name}' contains no data to process.",
            suggestions=[
                "Add some data to the dataset first using cognee.add()",
                "Check if your files contain readable text content",
                "Verify that your data was uploaded successfully",
            ],
            docs_link="https://docs.cognee.ai/guides/adding-data",
            context={"dataset_name": dataset_name},
            operation="cognify",
            **kwargs,
        )


class DatasetNotFoundError(CogneeNotFoundError):
    """Dataset not found or not accessible"""

    def __init__(
        self, dataset_identifier: str, available_datasets: Optional[List[str]] = None, **kwargs
    ):
        suggestions = ["Check the dataset name for typos"]
        if available_datasets:
            suggestions.extend(
                [
                    f"Available datasets: {', '.join(available_datasets)}",
                    "Use cognee.datasets() to see all your datasets",
                ]
            )
        else:
            suggestions.append("Create the dataset first by adding data to it")

        super().__init__(
            message=f"Dataset not found: {dataset_identifier}",
            user_message=f"Could not find dataset '{dataset_identifier}'.",
            suggestions=suggestions,
            docs_link="https://docs.cognee.ai/guides/datasets",
            context={
                "dataset_identifier": dataset_identifier,
                "available_datasets": available_datasets,
            },
            **kwargs,
        )


class InvalidQueryError(CogneeValidationError):
    """Search query is invalid or malformed"""

    def __init__(self, query: str, reason: str, **kwargs):
        super().__init__(
            message=f"Invalid query: {reason}",
            user_message=f"Your search query '{query}' is invalid: {reason}",
            suggestions=[
                "Try rephrasing your query",
                "Use simpler, more specific terms",
                "Check our query examples in the documentation",
            ],
            docs_link="https://docs.cognee.ai/guides/search",
            context={"query": query, "reason": reason},
            operation="search",
            **kwargs,
        )


class FileAccessError(CogneeUserError):
    """Cannot access or read the specified file"""

    def __init__(self, file_path: str, reason: str, **kwargs):
        super().__init__(
            message=f"Cannot access file: {file_path} - {reason}",
            user_message=f"Unable to read the file '{file_path}': {reason}",
            suggestions=[
                "Check if the file exists at the specified path",
                "Verify you have read permissions for the file",
                "Ensure the file is not locked by another application",
            ],
            context={"file_path": file_path, "reason": reason},
            operation="add",
            **kwargs,
        )


# ========== PROCESSING ERRORS (System/LLM errors) ==========


class LLMConnectionError(CogneeTransientError):
    """LLM service connection failure"""

    def __init__(self, provider: str, model: str, reason: str, **kwargs):
        super().__init__(
            message=f"LLM connection failed: {provider}/{model} - {reason}",
            user_message=f"Cannot connect to the {provider} language model service.",
            suggestions=[
                "Check your internet connection",
                "Verify your API key is correct and has sufficient credits",
                "Try again in a few moments",
                "Check the service status page",
            ],
            docs_link="https://docs.cognee.ai/troubleshooting/llm-connection",
            context={"provider": provider, "model": model, "reason": reason},
            **kwargs,
        )


class LLMRateLimitError(CogneeTransientError):
    """LLM service rate limit exceeded"""

    def __init__(self, provider: str, retry_after: Optional[int] = None, **kwargs):
        suggestions = [
            "Wait a moment before retrying",
            "Consider upgrading your API plan",
            "Use smaller batch sizes to reduce token usage",
        ]
        if retry_after:
            suggestions.insert(0, f"Wait {retry_after} seconds before retrying")

        super().__init__(
            message=f"Rate limit exceeded for {provider}",
            user_message=f"You've exceeded the rate limit for {provider}.",
            suggestions=suggestions,
            context={"provider": provider, "retry_after": retry_after},
            **kwargs,
        )


class ProcessingTimeoutError(CogneeTransientError):
    """Processing operation timed out"""

    def __init__(self, operation: str, timeout_seconds: int, **kwargs):
        super().__init__(
            message=f"Operation '{operation}' timed out after {timeout_seconds}s",
            user_message=f"The {operation} operation took too long and was cancelled.",
            suggestions=[
                "Try processing smaller amounts of data at a time",
                "Check your internet connection stability",
                "Retry the operation",
                "Use background processing for large datasets",
            ],
            context={"operation": operation, "timeout_seconds": timeout_seconds},
            **kwargs,
        )


class DatabaseConnectionError(CogneeSystemError):
    """Database connection failure"""

    def __init__(self, db_type: str, reason: str, **kwargs):
        super().__init__(
            message=f"{db_type} database connection failed: {reason}",
            user_message=f"Cannot connect to the {db_type} database.",
            suggestions=[
                "Check if the database service is running",
                "Verify database connection configuration",
                "Check network connectivity",
                "Contact support if the issue persists",
            ],
            docs_link="https://docs.cognee.ai/troubleshooting/database",
            context={"db_type": db_type, "reason": reason},
            **kwargs,
        )


class InsufficientResourcesError(CogneeSystemError):
    """System has insufficient resources to complete the operation"""

    def __init__(self, resource_type: str, required: str, available: str, **kwargs):
        super().__init__(
            message=f"Insufficient {resource_type}: need {required}, have {available}",
            user_message=f"Not enough {resource_type} available to complete this operation.",
            suggestions=[
                "Try processing smaller amounts of data",
                "Free up system resources",
                "Wait for other operations to complete",
                "Consider upgrading your system resources",
            ],
            context={"resource_type": resource_type, "required": required, "available": available},
            **kwargs,
        )


# ========== CONFIGURATION ERRORS ==========


class MissingAPIKeyError(CogneeConfigurationError):
    """Required API key is missing"""

    def __init__(self, service: str, env_var: str, **kwargs):
        super().__init__(
            message=f"Missing API key for {service}",
            user_message=f"API key for {service} is not configured.",
            suggestions=[
                f"Set the {env_var} environment variable",
                f"Add your {service} API key to your .env file",
                "Check the setup documentation for detailed instructions",
            ],
            docs_link="https://docs.cognee.ai/setup/api-keys",
            context={"service": service, "env_var": env_var},
            **kwargs,
        )


class InvalidDatabaseConfigError(CogneeConfigurationError):
    """Database configuration is invalid"""

    def __init__(self, db_type: str, config_issue: str, **kwargs):
        super().__init__(
            message=f"Invalid {db_type} database configuration: {config_issue}",
            user_message=f"The {db_type} database is not properly configured: {config_issue}",
            suggestions=[
                "Check your database configuration settings",
                "Verify connection strings and credentials",
                "Review the database setup documentation",
                "Ensure the database server is accessible",
            ],
            docs_link="https://docs.cognee.ai/setup/databases",
            context={"db_type": db_type, "config_issue": config_issue},
            **kwargs,
        )


class UnsupportedSearchTypeError(CogneeValidationError):
    """Search type is not supported"""

    def __init__(self, search_type: str, supported_types: List[str], **kwargs):
        super().__init__(
            message=f"Unsupported search type: {search_type}",
            user_message=f"The search type '{search_type}' is not supported.",
            suggestions=[
                f"Use one of these supported search types: {', '.join(supported_types)}",
                "Check the search documentation for available types",
                "Try using GRAPH_COMPLETION for general queries",
            ],
            docs_link="https://docs.cognee.ai/guides/search-types",
            context={"search_type": search_type, "supported_types": supported_types},
            operation="search",
            **kwargs,
        )


# ========== PIPELINE ERRORS ==========


class PipelineExecutionError(CogneeSystemError):
    """Pipeline execution failed"""

    def __init__(self, pipeline_name: str, task_name: str, error_details: str, **kwargs):
        super().__init__(
            message=f"Pipeline '{pipeline_name}' failed at task '{task_name}': {error_details}",
            user_message=f"Processing failed during the {task_name} step.",
            suggestions=[
                "Check the logs for more detailed error information",
                "Verify your data is in a supported format",
                "Try processing smaller amounts of data",
                "Contact support if the issue persists",
            ],
            context={
                "pipeline_name": pipeline_name,
                "task_name": task_name,
                "error_details": error_details,
            },
            **kwargs,
        )


class DataExtractionError(CogneeSystemError):
    """Failed to extract content from data"""

    def __init__(self, source: str, reason: str, **kwargs):
        super().__init__(
            message=f"Data extraction failed for {source}: {reason}",
            user_message=f"Could not extract readable content from '{source}'.",
            suggestions=[
                "Verify the file is not corrupted",
                "Try converting to a different format",
                "Check if the file contains readable text",
                "Use a supported file format",
            ],
            context={"source": source, "reason": reason},
            operation="add",
            **kwargs,
        )


class NoDataToProcessError(CogneeValidationError):
    """No data available to process"""

    def __init__(self, operation: str, **kwargs):
        super().__init__(
            message=f"No data available for {operation}",
            user_message=f"There's no data to process for the {operation} operation.",
            suggestions=[
                "Add some data first using cognee.add()",
                "Check if your previous data upload was successful",
                "Verify the dataset contains processable content",
            ],
            docs_link="https://docs.cognee.ai/guides/adding-data",
            context={"operation": operation},
            **kwargs,
        )
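A quick sketch of one of these classes in isolation, showing how the constructor arguments surface in the serialized payload; the expected values follow directly from the definitions above:

```python
from cognee.exceptions.domain_exceptions import MissingAPIKeyError

try:
    raise MissingAPIKeyError(service="LLM", env_var="LLM_API_KEY")
except MissingAPIKeyError as exc:
    payload = exc.to_dict()
    print(payload["type"])            # "MissingAPIKeyError"
    print(payload["message"])         # "API key for LLM is not configured."
    print(payload["suggestions"][0])  # "Set the LLM_API_KEY environment variable"
    print(exc.status_code)            # 422, defaulted by CogneeConfigurationError
```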
cognee/exceptions/enhanced_exceptions.py (new file, 184 lines)
@@ -0,0 +1,184 @@
from typing import Dict, List, Optional, Any
from fastapi import status
from cognee.shared.logging_utils import get_logger

logger = get_logger()


class CogneeBaseError(Exception):
    """
    Base exception for all Cognee errors with enhanced context and user experience.

    This class provides a foundation for all Cognee exceptions with:
    - Rich error context
    - User-friendly messages
    - Actionable suggestions
    - Documentation links
    - Retry information
    """

    def __init__(
        self,
        message: str,
        user_message: Optional[str] = None,
        suggestions: Optional[List[str]] = None,
        docs_link: Optional[str] = None,
        context: Optional[Dict[str, Any]] = None,
        is_retryable: bool = False,
        status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR,
        log_level: str = "ERROR",
        operation: Optional[str] = None,
    ):
        self.message = message
        self.user_message = user_message or message
        self.suggestions = suggestions or []
        self.docs_link = docs_link
        self.context = context or {}
        self.is_retryable = is_retryable
        self.status_code = status_code
        self.operation = operation

        # Automatically log the exception
        if log_level == "ERROR":
            logger.error(f"CogneeError in {operation or 'unknown'}: {message}", extra=self.context)
        elif log_level == "WARNING":
            logger.warning(
                f"CogneeWarning in {operation or 'unknown'}: {message}", extra=self.context
            )
        elif log_level == "INFO":
            logger.info(f"CogneeInfo in {operation or 'unknown'}: {message}", extra=self.context)

        super().__init__(self.message)

    def __str__(self):
        return f"{self.__class__.__name__}: {self.message}"

    def to_dict(self) -> Dict[str, Any]:
        """Convert exception to dictionary for API responses"""
        return {
            "type": self.__class__.__name__,
            "message": self.user_message,
            "technical_message": self.message,
            "suggestions": self.suggestions,
            "docs_link": self.docs_link,
            "is_retryable": self.is_retryable,
            "context": self.context,
            "operation": self.operation,
        }


class CogneeUserError(CogneeBaseError):
    """
    User-fixable errors (4xx status codes).

    These are errors caused by user input or actions that can be corrected
    by the user. Examples: invalid file format, missing required field.
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("status_code", status.HTTP_400_BAD_REQUEST)
        kwargs.setdefault("log_level", "WARNING")
        super().__init__(*args, **kwargs)


class CogneeSystemError(CogneeBaseError):
    """
    System/infrastructure errors (5xx status codes).

    These are errors caused by system issues that require technical intervention.
    Examples: database connection failure, service unavailable.
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("status_code", status.HTTP_500_INTERNAL_SERVER_ERROR)
        kwargs.setdefault("log_level", "ERROR")
        super().__init__(*args, **kwargs)


class CogneeTransientError(CogneeBaseError):
    """
    Temporary/retryable errors.

    These are errors that might succeed if retried, often due to temporary
    resource constraints or network issues.
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("status_code", status.HTTP_503_SERVICE_UNAVAILABLE)
        kwargs.setdefault("is_retryable", True)
        kwargs.setdefault("log_level", "WARNING")
        super().__init__(*args, **kwargs)


class CogneeConfigurationError(CogneeBaseError):
    """
    Setup/configuration errors.

    These are errors related to missing or invalid configuration that
    prevent the system from operating correctly.
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("status_code", status.HTTP_422_UNPROCESSABLE_ENTITY)
        kwargs.setdefault("log_level", "ERROR")
        super().__init__(*args, **kwargs)


class CogneeValidationError(CogneeUserError):
    """
    Input validation errors.

    Specific type of user error for invalid input data.
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("status_code", status.HTTP_422_UNPROCESSABLE_ENTITY)
        super().__init__(*args, **kwargs)


class CogneeAuthenticationError(CogneeUserError):
    """
    Authentication and authorization errors.
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("status_code", status.HTTP_401_UNAUTHORIZED)
        super().__init__(*args, **kwargs)


class CogneePermissionError(CogneeUserError):
    """
    Permission denied errors.
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("status_code", status.HTTP_403_FORBIDDEN)
        super().__init__(*args, **kwargs)


class CogneeNotFoundError(CogneeUserError):
    """
    Resource not found errors.
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("status_code", status.HTTP_404_NOT_FOUND)
        super().__init__(*args, **kwargs)


class CogneeRateLimitError(CogneeTransientError):
    """
    Rate limiting errors.
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("status_code", status.HTTP_429_TOO_MANY_REQUESTS)
        kwargs.setdefault(
            "suggestions",
            [
                "Wait a moment before retrying",
                "Check your API rate limits",
                "Consider using smaller batch sizes",
            ],
        )
        super().__init__(*args, **kwargs)
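For reference, a minimal sketch (not part of the diff) of how this hierarchy is meant to be consumed; the dataset name and messages are hypothetical:

# Hypothetical usage of the enhanced exception hierarchy defined above.
from cognee.exceptions.enhanced_exceptions import CogneeNotFoundError

try:
    raise CogneeNotFoundError(
        message="Dataset 'reports' not found for user 42",  # technical message
        user_message="The dataset you asked for does not exist.",
        suggestions=["Check the dataset name", "Create the dataset with cognee.add()"],
        operation="search",
    )
except CogneeNotFoundError as error:
    payload = error.to_dict()
    # payload["type"] == "CogneeNotFoundError" and error.status_code == 404;
    # the FastAPI handler shown earlier in this diff would return {"error": payload}.
    print(payload["message"], error.status_code)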
cognee/infrastructure/loaders/LoaderEngine.py (new file, 237 lines)
@@ -0,0 +1,237 @@
import os
import importlib.util
from typing import Any, Dict, List, Optional
from .LoaderInterface import LoaderInterface
from .models.LoaderResult import LoaderResult
from cognee.shared.logging_utils import get_logger


class LoaderEngine:
    """
    Main loader engine for managing file loaders.

    Follows cognee's adapter pattern similar to database engines,
    providing a centralized system for file loading operations.
    """

    def __init__(
        self,
        loader_directories: List[str],
        default_loader_priority: List[str],
        fallback_loader: str = "text_loader",
        enable_dependency_validation: bool = True,
    ):
        """
        Initialize the loader engine.

        Args:
            loader_directories: Directories to search for loader implementations
            default_loader_priority: Priority order for loader selection
            fallback_loader: Default loader to use when no other matches
            enable_dependency_validation: Whether to validate loader dependencies
        """
        self._loaders: Dict[str, LoaderInterface] = {}
        self._extension_map: Dict[str, List[LoaderInterface]] = {}
        self._mime_type_map: Dict[str, List[LoaderInterface]] = {}
        self.loader_directories = loader_directories
        self.default_loader_priority = default_loader_priority
        self.fallback_loader = fallback_loader
        self.enable_dependency_validation = enable_dependency_validation
        self.logger = get_logger(__name__)

    def register_loader(self, loader: LoaderInterface) -> bool:
        """
        Register a loader with the engine.

        Args:
            loader: LoaderInterface implementation to register

        Returns:
            True if loader was registered successfully, False otherwise
        """
        # Validate dependencies if enabled
        if self.enable_dependency_validation and not loader.validate_dependencies():
            self.logger.warning(
                f"Skipping loader '{loader.loader_name}' - missing dependencies: "
                f"{loader.get_dependencies()}"
            )
            return False

        self._loaders[loader.loader_name] = loader

        # Map extensions to loaders
        for ext in loader.supported_extensions:
            ext_lower = ext.lower()
            if ext_lower not in self._extension_map:
                self._extension_map[ext_lower] = []
            self._extension_map[ext_lower].append(loader)

        # Map mime types to loaders
        for mime_type in loader.supported_mime_types:
            if mime_type not in self._mime_type_map:
                self._mime_type_map[mime_type] = []
            self._mime_type_map[mime_type].append(loader)

        self.logger.info(f"Registered loader: {loader.loader_name}")
        return True

    def get_loader(
        self, file_path: str, mime_type: str = None, preferred_loaders: List[str] = None
    ) -> Optional[LoaderInterface]:
        """
        Get appropriate loader for a file.

        Args:
            file_path: Path to the file to be processed
            mime_type: Optional MIME type of the file
            preferred_loaders: List of preferred loader names to try first

        Returns:
            LoaderInterface that can handle the file, or None if not found
        """
        ext = os.path.splitext(file_path)[1].lower()

        # Try preferred loaders first
        if preferred_loaders:
            for loader_name in preferred_loaders:
                if loader_name in self._loaders:
                    loader = self._loaders[loader_name]
                    if loader.can_handle(file_path, mime_type):
                        return loader

        # Try priority order
        for loader_name in self.default_loader_priority:
            if loader_name in self._loaders:
                loader = self._loaders[loader_name]
                if loader.can_handle(file_path, mime_type):
                    return loader

        # Try mime type mapping
        if mime_type and mime_type in self._mime_type_map:
            for loader in self._mime_type_map[mime_type]:
                if loader.can_handle(file_path, mime_type):
                    return loader

        # Try extension mapping
        if ext in self._extension_map:
            for loader in self._extension_map[ext]:
                if loader.can_handle(file_path, mime_type):
                    return loader

        # Fallback loader
        if self.fallback_loader in self._loaders:
            fallback = self._loaders[self.fallback_loader]
            if fallback.can_handle(file_path, mime_type):
                return fallback

        return None

    async def load_file(
        self, file_path: str, mime_type: str = None, preferred_loaders: List[str] = None, **kwargs
    ) -> LoaderResult:
        """
        Load file using appropriate loader.

        Args:
            file_path: Path to the file to be processed
            mime_type: Optional MIME type of the file
            preferred_loaders: List of preferred loader names to try first
            **kwargs: Additional loader-specific configuration

        Returns:
            LoaderResult containing processed content and metadata

        Raises:
            ValueError: If no suitable loader is found
            Exception: If file processing fails
        """
        loader = self.get_loader(file_path, mime_type, preferred_loaders)
        if not loader:
            raise ValueError(f"No loader found for file: {file_path}")

        self.logger.debug(f"Loading {file_path} with {loader.loader_name}")
        return await loader.load(file_path, **kwargs)

    def discover_loaders(self):
        """
        Auto-discover loaders from configured directories.

        Scans loader directories for Python modules containing
        LoaderInterface implementations and registers them.
        """
        for directory in self.loader_directories:
            if os.path.exists(directory):
                self._discover_in_directory(directory)

    def _discover_in_directory(self, directory: str):
        """
        Discover loaders in a specific directory.

        Args:
            directory: Directory path to scan for loader implementations
        """
        try:
            for file_name in os.listdir(directory):
                if file_name.endswith(".py") and not file_name.startswith("_"):
                    module_name = file_name[:-3]
                    file_path = os.path.join(directory, file_name)

                    try:
                        spec = importlib.util.spec_from_file_location(module_name, file_path)
                        if spec and spec.loader:
                            module = importlib.util.module_from_spec(spec)
                            spec.loader.exec_module(module)

                            # Look for loader classes
                            for attr_name in dir(module):
                                attr = getattr(module, attr_name)
                                if (
                                    isinstance(attr, type)
                                    and issubclass(attr, LoaderInterface)
                                    and attr != LoaderInterface
                                ):
                                    # Instantiate and register the loader
                                    try:
                                        loader_instance = attr()
                                        self.register_loader(loader_instance)
                                    except Exception as e:
                                        self.logger.warning(
                                            f"Failed to instantiate loader {attr_name}: {e}"
                                        )

                    except Exception as e:
                        self.logger.warning(f"Failed to load module {module_name}: {e}")

        except OSError as e:
            self.logger.warning(f"Failed to scan directory {directory}: {e}")

    def get_available_loaders(self) -> List[str]:
        """
        Get list of available loader names.

        Returns:
            List of registered loader names
        """
        return list(self._loaders.keys())

    def get_loader_info(self, loader_name: str) -> Dict[str, Any]:
        """
        Get information about a specific loader.

        Args:
            loader_name: Name of the loader to inspect

        Returns:
            Dictionary containing loader information
        """
        if loader_name not in self._loaders:
            return {}

        loader = self._loaders[loader_name]
        return {
            "name": loader.loader_name,
            "extensions": loader.supported_extensions,
            "mime_types": loader.supported_mime_types,
            "dependencies": loader.get_dependencies(),
            "available": loader.validate_dependencies(),
        }
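A short sketch (not part of the diff) of the intended call flow through the engine; the file path and MIME type are illustrative:

# Illustrative use of the engine's loader selection and loading flow.
# Assumes text_loader is registered and /tmp/notes.md exists.
import asyncio
from cognee.infrastructure.loaders import get_loader_engine


async def main():
    engine = get_loader_engine()
    print(engine.get_available_loaders())  # e.g. ["text_loader", ...]

    loader = engine.get_loader("/tmp/notes.md", mime_type="text/markdown")
    if loader is not None:
        result = await engine.load_file("/tmp/notes.md")
        print(result.metadata["loader"], result.content_type)


asyncio.run(main())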
cognee/infrastructure/loaders/LoaderInterface.py (new file, 101 lines)
@@ -0,0 +1,101 @@
from abc import ABC, abstractmethod
from typing import List
from .models.LoaderResult import LoaderResult


class LoaderInterface(ABC):
    """
    Base interface for all file loaders in cognee.

    This interface follows cognee's established pattern for database adapters,
    ensuring consistent behavior across all loader implementations.
    """

    @property
    @abstractmethod
    def supported_extensions(self) -> List[str]:
        """
        List of file extensions this loader supports.

        Returns:
            List of extensions including the dot (e.g., ['.txt', '.md'])
        """
        pass

    @property
    @abstractmethod
    def supported_mime_types(self) -> List[str]:
        """
        List of MIME types this loader supports.

        Returns:
            List of MIME type strings (e.g., ['text/plain', 'application/pdf'])
        """
        pass

    @property
    @abstractmethod
    def loader_name(self) -> str:
        """
        Unique name identifier for this loader.

        Returns:
            String identifier used for registration and configuration
        """
        pass

    @abstractmethod
    def can_handle(self, file_path: str, mime_type: str = None) -> bool:
        """
        Check if this loader can handle the given file.

        Args:
            file_path: Path to the file to be processed
            mime_type: Optional MIME type of the file

        Returns:
            True if this loader can process the file, False otherwise
        """
        pass

    @abstractmethod
    async def load(self, file_path: str, **kwargs) -> LoaderResult:
        """
        Load and process the file, returning standardized result.

        Args:
            file_path: Path to the file to be processed
            **kwargs: Additional loader-specific configuration

        Returns:
            LoaderResult containing processed content and metadata

        Raises:
            Exception: If file cannot be processed
        """
        pass

    def get_dependencies(self) -> List[str]:
        """
        Optional: Return list of required dependencies for this loader.

        Returns:
            List of package names with optional version specifications
        """
        return []

    def validate_dependencies(self) -> bool:
        """
        Check if all required dependencies are available.

        Returns:
            True if all dependencies are installed, False otherwise
        """
        for dep in self.get_dependencies():
            # Extract package name from version specification
            package_name = dep.split(">=")[0].split("==")[0].split("<")[0]
            try:
                __import__(package_name)
            except ImportError:
                return False
        return True
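To illustrate the contract, a minimal conforming implementation might look like the following sketch; the TsvLoader class is hypothetical and not part of this diff:

# Hypothetical loader showing the minimum the interface requires.
from typing import List

from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
from cognee.infrastructure.loaders.models.LoaderResult import LoaderResult, ContentType


class TsvLoader(LoaderInterface):
    @property
    def supported_extensions(self) -> List[str]:
        return [".tsv"]

    @property
    def supported_mime_types(self) -> List[str]:
        return ["text/tab-separated-values"]

    @property
    def loader_name(self) -> str:
        return "tsv_loader"

    def can_handle(self, file_path: str, mime_type: str = None) -> bool:
        return file_path.lower().endswith(".tsv")

    async def load(self, file_path: str, **kwargs) -> LoaderResult:
        with open(file_path, "r", encoding="utf-8") as f:
            content = f.read()
        # chunks and source_info are optional and default to None.
        return LoaderResult(
            content=content,
            metadata={"name": file_path, "loader": self.loader_name},
            content_type=ContentType.TEXT,
        )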
cognee/infrastructure/loaders/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
"""
File loader infrastructure for cognee.

This package provides a plugin-based system for loading different file formats
into cognee, following the same patterns as database adapters.

Main exports:
- get_loader_engine(): Factory function to get configured loader engine
- use_loader(): Register custom loaders at runtime
- LoaderInterface: Base interface for implementing loaders
- LoaderResult, ContentType: Data models for loader results
"""

from .get_loader_engine import get_loader_engine
from .use_loader import use_loader
from .LoaderInterface import LoaderInterface
from .models.LoaderResult import LoaderResult, ContentType

__all__ = ["get_loader_engine", "use_loader", "LoaderInterface", "LoaderResult", "ContentType"]
cognee/infrastructure/loaders/config.py (new file, 57 lines)
@@ -0,0 +1,57 @@
from functools import lru_cache
from typing import List, Optional, Dict, Any
from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.root_dir import get_absolute_path


class LoaderConfig(BaseSettings):
    """
    Configuration for file loader system.

    Follows cognee's pattern using pydantic_settings.BaseSettings for
    environment variable support and validation.
    """

    loader_directories: List[str] = [
        get_absolute_path("infrastructure/loaders/core"),
        get_absolute_path("infrastructure/loaders/external"),
    ]
    default_loader_priority: List[str] = [
        "text_loader",
        "pypdf_loader",
        "unstructured_loader",
        "dlt_loader",
    ]
    auto_discover: bool = True
    fallback_loader: str = "text_loader"
    enable_dependency_validation: bool = True

    model_config = SettingsConfigDict(env_file=".env", extra="allow", env_prefix="LOADER_")

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert configuration to dictionary format.

        Returns:
            Dict containing all loader configuration settings
        """
        return {
            "loader_directories": self.loader_directories,
            "default_loader_priority": self.default_loader_priority,
            "auto_discover": self.auto_discover,
            "fallback_loader": self.fallback_loader,
            "enable_dependency_validation": self.enable_dependency_validation,
        }


@lru_cache
def get_loader_config() -> LoaderConfig:
    """
    Get cached loader configuration.

    Uses LRU cache following cognee's pattern for configuration objects.

    Returns:
        LoaderConfig instance with current settings
    """
    return LoaderConfig()
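Because the settings class uses env_prefix="LOADER_", fields can presumably be overridden through the environment; a sketch under that assumption (pydantic-settings parses list-typed fields from JSON):

# Sketch: overriding loader settings via environment variables.
# Must run before the first (cached) get_loader_config() call.
import os

os.environ["LOADER_FALLBACK_LOADER"] = "text_loader"
os.environ["LOADER_DEFAULT_LOADER_PRIORITY"] = '["pypdf_loader", "text_loader"]'

from cognee.infrastructure.loaders.config import get_loader_config

config = get_loader_config()
print(config.fallback_loader)          # text_loader
print(config.default_loader_priority)  # ['pypdf_loader', 'text_loader']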
cognee/infrastructure/loaders/core/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
"""Core loader implementations that are always available."""

from .text_loader import TextLoader

__all__ = ["TextLoader"]
cognee/infrastructure/loaders/core/text_loader.py (new file, 128 lines)
@@ -0,0 +1,128 @@
import os
from typing import List
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
from cognee.infrastructure.loaders.models.LoaderResult import LoaderResult, ContentType


class TextLoader(LoaderInterface):
    """
    Core text file loader that handles basic text file formats.

    This loader is always available and serves as the fallback for
    text-based files when no specialized loader is available.
    """

    @property
    def supported_extensions(self) -> List[str]:
        """Supported text file extensions."""
        return [".txt", ".md", ".csv", ".json", ".xml", ".yaml", ".yml", ".log"]

    @property
    def supported_mime_types(self) -> List[str]:
        """Supported MIME types for text content."""
        return [
            "text/plain",
            "text/markdown",
            "text/csv",
            "application/json",
            "text/xml",
            "application/xml",
            "text/yaml",
            "application/yaml",
        ]

    @property
    def loader_name(self) -> str:
        """Unique identifier for this loader."""
        return "text_loader"

    def can_handle(self, file_path: str, mime_type: str = None) -> bool:
        """
        Check if this loader can handle the given file.

        Args:
            file_path: Path to the file
            mime_type: Optional MIME type

        Returns:
            True if file can be handled, False otherwise
        """
        # Check by extension
        ext = os.path.splitext(file_path)[1].lower()
        if ext in self.supported_extensions:
            return True

        # Check by MIME type
        if mime_type and mime_type in self.supported_mime_types:
            return True

        # As fallback loader, can attempt to handle any text-like file
        # This is useful when other loaders fail
        try:
            # Quick check if file appears to be text
            with open(file_path, "rb") as f:
                sample = f.read(512)
            # Simple heuristic: if the sample decodes as text, consider it text
            if sample:
                try:
                    sample.decode("utf-8")
                    return True
                except UnicodeDecodeError:
                    try:
                        # Note: latin-1 maps every byte value, so this decode cannot
                        # fail; any non-empty readable file is accepted here.
                        sample.decode("latin-1")
                        return True
                    except UnicodeDecodeError:
                        pass
        except (OSError, IOError):
            pass

        return False

    async def load(self, file_path: str, encoding: str = "utf-8", **kwargs) -> LoaderResult:
        """
        Load and process the text file.

        Args:
            file_path: Path to the file to load
            encoding: Text encoding to use (default: utf-8)
            **kwargs: Additional configuration (unused)

        Returns:
            LoaderResult containing the file content and metadata

        Raises:
            FileNotFoundError: If file doesn't exist
            UnicodeDecodeError: If file cannot be decoded with specified encoding
            OSError: If file cannot be read
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        try:
            with open(file_path, "r", encoding=encoding) as f:
                content = f.read()
        except UnicodeDecodeError:
            # Try with fallback encoding
            if encoding == "utf-8":
                return await self.load(file_path, encoding="latin-1", **kwargs)
            else:
                raise

        # Extract basic metadata
        file_stat = os.stat(file_path)
        metadata = {
            "name": os.path.basename(file_path),
            "size": file_stat.st_size,
            "extension": os.path.splitext(file_path)[1],
            "encoding": encoding,
            "loader": self.loader_name,
            "lines": len(content.splitlines()) if content else 0,
            "characters": len(content),
        }

        return LoaderResult(
            content=content,
            metadata=metadata,
            content_type=ContentType.TEXT,
            source_info={"file_path": file_path, "encoding": encoding},
        )
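A small usage sketch (not part of the diff) of the encoding fallback described above; the file path is illustrative:

# Sketch: TextLoader retries with latin-1 when utf-8 decoding fails.
import asyncio
from cognee.infrastructure.loaders.core.text_loader import TextLoader


async def main():
    loader = TextLoader()
    result = await loader.load("/tmp/legacy.log")  # illustrative path
    print(result.metadata["encoding"])  # "utf-8", or "latin-1" after fallback
    print(result.metadata["lines"], result.metadata["characters"])


asyncio.run(main())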
cognee/infrastructure/loaders/create_loader_engine.py (new file, 49 lines)
@@ -0,0 +1,49 @@
from typing import List
from .LoaderEngine import LoaderEngine
from .supported_loaders import supported_loaders


def create_loader_engine(
    loader_directories: List[str],
    default_loader_priority: List[str],
    auto_discover: bool = True,
    fallback_loader: str = "text_loader",
    enable_dependency_validation: bool = True,
) -> LoaderEngine:
    """
    Create loader engine with given configuration.

    Follows cognee's pattern for engine creation functions used
    in database adapters.

    Args:
        loader_directories: Directories to search for loader implementations
        default_loader_priority: Priority order for loader selection
        auto_discover: Whether to auto-discover loaders from directories
        fallback_loader: Default loader to use when no other matches
        enable_dependency_validation: Whether to validate loader dependencies

    Returns:
        Configured LoaderEngine instance
    """
    engine = LoaderEngine(
        loader_directories=loader_directories,
        default_loader_priority=default_loader_priority,
        fallback_loader=fallback_loader,
        enable_dependency_validation=enable_dependency_validation,
    )

    # Register supported loaders from registry
    for loader_name, loader_class in supported_loaders.items():
        try:
            loader_instance = loader_class()
            engine.register_loader(loader_instance)
        except Exception as e:
            # Log but don't fail - allow engine to continue with other loaders
            engine.logger.warning(f"Failed to register loader {loader_name}: {e}")

    # Auto-discover loaders if enabled
    if auto_discover:
        engine.discover_loaders()

    return engine
cognee/infrastructure/loaders/external/__init__.py (new file, 34 lines, vendored)
@@ -0,0 +1,34 @@
"""
External loader implementations for cognee.

This module contains loaders that depend on external libraries:
- pypdf_loader: PDF processing using pypdf
- unstructured_loader: Document processing using unstructured
- dlt_loader: Data lake/warehouse integration using DLT

These loaders are optional and only available if their dependencies are installed.
"""

__all__ = []

# Conditional imports based on dependency availability
try:
    from .pypdf_loader import PyPdfLoader

    __all__.append("PyPdfLoader")
except ImportError:
    pass

try:
    from .unstructured_loader import UnstructuredLoader

    __all__.append("UnstructuredLoader")
except ImportError:
    pass

try:
    from .dlt_loader import DltLoader

    __all__.append("DltLoader")
except ImportError:
    pass
cognee/infrastructure/loaders/external/dlt_loader.py (new file, 203 lines, vendored)
@@ -0,0 +1,203 @@
import os
from typing import List, Dict, Any, Optional
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
from cognee.infrastructure.loaders.models.LoaderResult import LoaderResult, ContentType
from cognee.shared.logging_utils import get_logger


class DltLoader(LoaderInterface):
    """
    Data loader using DLT (Data Load Tool) for various data sources.

    Supports loading data from REST APIs, databases, cloud storage,
    and other data sources through DLT pipelines.
    """

    def __init__(self):
        self.logger = get_logger(__name__)

    @property
    def supported_extensions(self) -> List[str]:
        return [
            ".dlt",  # DLT pipeline configuration
            ".json",  # JSON data
            ".jsonl",  # JSON Lines
            ".csv",  # CSV data
            ".parquet",  # Parquet files
            ".yaml",  # YAML configuration
            ".yml",  # YAML configuration
        ]

    @property
    def supported_mime_types(self) -> List[str]:
        return [
            "application/json",
            "application/x-ndjson",  # JSON Lines
            "text/csv",
            "application/x-parquet",
            "application/yaml",
            "text/yaml",
        ]

    @property
    def loader_name(self) -> str:
        return "dlt_loader"

    def get_dependencies(self) -> List[str]:
        return ["dlt>=0.4.0"]

    def can_handle(self, file_path: str, mime_type: str = None) -> bool:
        """Check if file can be handled by this loader."""
        # Check file extension
        file_ext = os.path.splitext(file_path)[1].lower()
        if file_ext not in self.supported_extensions:
            return False

        # Check MIME type if provided
        if mime_type and mime_type not in self.supported_mime_types:
            return False

        # Validate dependencies
        return self.validate_dependencies()

    async def load(self, file_path: str, source_type: str = "auto", **kwargs) -> LoaderResult:
        """
        Load data using DLT pipeline.

        Args:
            file_path: Path to the data file or DLT configuration
            source_type: Type of data source ("auto", "json", "csv", "parquet", "api")
            **kwargs: Additional DLT pipeline configuration

        Returns:
            LoaderResult with loaded data and metadata

        Raises:
            ImportError: If DLT is not installed
            Exception: If data loading fails
        """
        try:
            import dlt
        except ImportError as e:
            raise ImportError(
                "dlt is required for data loading. Install with: pip install dlt"
            ) from e

        try:
            self.logger.info(f"Loading data with DLT: {file_path}")

            file_ext = os.path.splitext(file_path)[1].lower()
            file_name = os.path.basename(file_path)
            file_size = os.path.getsize(file_path)

            # Determine source type if auto
            if source_type == "auto":
                if file_ext == ".json":
                    source_type = "json"
                elif file_ext == ".jsonl":
                    source_type = "jsonl"
                elif file_ext == ".csv":
                    source_type = "csv"
                elif file_ext == ".parquet":
                    source_type = "parquet"
                elif file_ext in [".yaml", ".yml"]:
                    source_type = "yaml"
                else:
                    source_type = "file"

            # Load data based on source type
            if source_type == "json":
                content = self._load_json(file_path)
            elif source_type == "jsonl":
                content = self._load_jsonl(file_path)
            elif source_type == "csv":
                content = self._load_csv(file_path)
            elif source_type == "parquet":
                content = self._load_parquet(file_path)
            elif source_type == "yaml":
                content = self._load_yaml(file_path)
            else:
                # Default: read as text
                with open(file_path, "r", encoding="utf-8") as f:
                    content = f.read()

            # Determine content type
            if isinstance(content, (dict, list)):
                content_type = ContentType.STRUCTURED
                text_content = str(content)
            else:
                content_type = ContentType.TEXT
                text_content = content

            # Gather metadata
            metadata = {
                "name": file_name,
                "size": file_size,
                "extension": file_ext,
                "loader": self.loader_name,
                "source_type": source_type,
                "dlt_version": dlt.__version__,
            }

            # Add data-specific metadata
            if isinstance(content, list):
                metadata["records_count"] = len(content)
            elif isinstance(content, dict):
                metadata["keys_count"] = len(content)

            return LoaderResult(
                content=text_content,
                metadata=metadata,
                content_type=content_type,
                chunks=[text_content],  # Single chunk for now
                source_info={
                    "file_path": file_path,
                    "source_type": source_type,
                    "raw_data": content if isinstance(content, (dict, list)) else None,
                },
            )

        except Exception as e:
            self.logger.error(f"Failed to load data with DLT from {file_path}: {e}")
            raise Exception(f"DLT data loading failed: {e}") from e

    def _load_json(self, file_path: str) -> Dict[str, Any]:
        """Load JSON file."""
        import json

        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _load_jsonl(self, file_path: str) -> List[Dict[str, Any]]:
        """Load JSON Lines file."""
        import json

        data = []
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    data.append(json.loads(line))
        return data

    def _load_csv(self, file_path: str) -> str:
        """Load CSV file as text."""
        with open(file_path, "r", encoding="utf-8") as f:
            return f.read()

    def _load_parquet(self, file_path: str) -> str:
        """Load Parquet file (requires pandas)."""
        try:
            import pandas as pd

            df = pd.read_parquet(file_path)
            return df.to_string()
        except ImportError:
            # Fallback: read as binary and convert to string representation
            with open(file_path, "rb") as f:
                return f"<Parquet file: {os.path.basename(file_path)}, size: {len(f.read())} bytes>"

    def _load_yaml(self, file_path: str) -> str:
        """Load YAML file as text."""
        with open(file_path, "r", encoding="utf-8") as f:
            return f.read()
cognee/infrastructure/loaders/external/pypdf_loader.py (new file, 127 lines, vendored)
@@ -0,0 +1,127 @@
import os
from typing import List
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
from cognee.infrastructure.loaders.models.LoaderResult import LoaderResult, ContentType
from cognee.shared.logging_utils import get_logger


class PyPdfLoader(LoaderInterface):
    """
    PDF loader using pypdf library.

    Extracts text content from PDF files page by page, providing
    structured page information and handling PDF-specific errors.
    """

    def __init__(self):
        self.logger = get_logger(__name__)

    @property
    def supported_extensions(self) -> List[str]:
        return [".pdf"]

    @property
    def supported_mime_types(self) -> List[str]:
        return ["application/pdf"]

    @property
    def loader_name(self) -> str:
        return "pypdf_loader"

    def get_dependencies(self) -> List[str]:
        return ["pypdf>=4.0.0"]

    def can_handle(self, file_path: str, mime_type: str = None) -> bool:
        """Check if file can be handled by this loader."""
        # Check file extension
        if not file_path.lower().endswith(".pdf"):
            return False

        # Check MIME type if provided
        if mime_type and mime_type != "application/pdf":
            return False

        # Validate dependencies
        return self.validate_dependencies()

    async def load(self, file_path: str, strict: bool = False, **kwargs) -> LoaderResult:
        """
        Load PDF file and extract text content.

        Args:
            file_path: Path to the PDF file
            strict: Whether to use strict mode for PDF reading
            **kwargs: Additional arguments

        Returns:
            LoaderResult with extracted text content and metadata

        Raises:
            ImportError: If pypdf is not installed
            Exception: If PDF processing fails
        """
        try:
            from pypdf import PdfReader
        except ImportError as e:
            raise ImportError(
                "pypdf is required for PDF processing. Install with: pip install pypdf"
            ) from e

        try:
            with open(file_path, "rb") as file:
                self.logger.info(f"Reading PDF: {file_path}")
                reader = PdfReader(file, strict=strict)

                content_parts = []
                page_texts = []

                for page_num, page in enumerate(reader.pages, 1):
                    try:
                        page_text = page.extract_text()
                        if page_text.strip():  # Only add non-empty pages
                            page_texts.append(page_text)
                            content_parts.append(f"Page {page_num}:\n{page_text}\n")
                    except Exception as e:
                        self.logger.warning(f"Failed to extract text from page {page_num}: {e}")
                        continue

                # Combine all content
                full_content = "\n".join(content_parts)

                # Gather metadata
                metadata = {
                    "name": os.path.basename(file_path),
                    "size": os.path.getsize(file_path),
                    "extension": ".pdf",
                    "pages": len(reader.pages),
                    "pages_with_text": len(page_texts),
                    "loader": self.loader_name,
                }

                # Add PDF metadata if available
                if reader.metadata:
                    metadata["pdf_metadata"] = {
                        "title": reader.metadata.get("/Title", ""),
                        "author": reader.metadata.get("/Author", ""),
                        "subject": reader.metadata.get("/Subject", ""),
                        "creator": reader.metadata.get("/Creator", ""),
                        "producer": reader.metadata.get("/Producer", ""),
                        "creation_date": str(reader.metadata.get("/CreationDate", "")),
                        "modification_date": str(reader.metadata.get("/ModDate", "")),
                    }

                return LoaderResult(
                    content=full_content,
                    metadata=metadata,
                    content_type=ContentType.TEXT,
                    chunks=page_texts,  # Pre-chunked by page
                    source_info={
                        "file_path": file_path,
                        "pages": len(reader.pages),
                        "strict_mode": strict,
                    },
                )

        except Exception as e:
            self.logger.error(f"Failed to process PDF {file_path}: {e}")
            raise Exception(f"PDF processing failed: {e}") from e
cognee/infrastructure/loaders/external/unstructured_loader.py (new file, 168 lines, vendored)
@@ -0,0 +1,168 @@
import os
from typing import List
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
from cognee.infrastructure.loaders.models.LoaderResult import LoaderResult, ContentType
from cognee.shared.logging_utils import get_logger


class UnstructuredLoader(LoaderInterface):
    """
    Document loader using the unstructured library.

    Handles various document formats including docx, pptx, xlsx, odt, etc.
    Uses the unstructured library's auto-partition functionality.
    """

    def __init__(self):
        self.logger = get_logger(__name__)

    @property
    def supported_extensions(self) -> List[str]:
        return [
            ".docx",
            ".doc",
            ".odt",  # Word documents
            ".xlsx",
            ".xls",
            ".ods",  # Spreadsheets
            ".pptx",
            ".ppt",
            ".odp",  # Presentations
            ".rtf",
            ".html",
            ".htm",  # Rich text and HTML
            ".eml",
            ".msg",  # Email formats
            ".epub",  # eBooks
        ]

    @property
    def supported_mime_types(self) -> List[str]:
        return [
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document",  # docx
            "application/msword",  # doc
            "application/vnd.oasis.opendocument.text",  # odt
            "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",  # xlsx
            "application/vnd.ms-excel",  # xls
            "application/vnd.oasis.opendocument.spreadsheet",  # ods
            "application/vnd.openxmlformats-officedocument.presentationml.presentation",  # pptx
            "application/vnd.ms-powerpoint",  # ppt
            "application/vnd.oasis.opendocument.presentation",  # odp
            "application/rtf",  # rtf
            "text/html",  # html
            "message/rfc822",  # eml
            "application/epub+zip",  # epub
        ]

    @property
    def loader_name(self) -> str:
        return "unstructured_loader"

    def get_dependencies(self) -> List[str]:
        return ["unstructured>=0.10.0"]

    def can_handle(self, file_path: str, mime_type: str = None) -> bool:
        """Check if file can be handled by this loader."""
        # Check file extension
        file_ext = os.path.splitext(file_path)[1].lower()
        if file_ext not in self.supported_extensions:
            return False

        # Check MIME type if provided
        if mime_type and mime_type not in self.supported_mime_types:
            return False

        # Validate dependencies
        return self.validate_dependencies()

    async def load(self, file_path: str, strategy: str = "auto", **kwargs) -> LoaderResult:
        """
        Load document using unstructured library.

        Args:
            file_path: Path to the document file
            strategy: Partitioning strategy ("auto", "fast", "hi_res", "ocr_only")
            **kwargs: Additional arguments passed to unstructured partition

        Returns:
            LoaderResult with extracted text content and metadata

        Raises:
            ImportError: If unstructured is not installed
            Exception: If document processing fails
        """
        try:
            from unstructured.partition.auto import partition
        except ImportError as e:
            raise ImportError(
                "unstructured is required for document processing. "
                "Install with: pip install unstructured"
            ) from e

        try:
            self.logger.info(f"Processing document: {file_path}")

            # Determine content type from file extension
            file_ext = os.path.splitext(file_path)[1].lower()

            # Get file size and basic info
            file_size = os.path.getsize(file_path)
            file_name = os.path.basename(file_path)

            # Set partitioning parameters
            partition_kwargs = {"filename": file_path, "strategy": strategy, **kwargs}

            # Use partition to extract elements
            elements = partition(**partition_kwargs)

            # Process elements into text content
            text_parts = []
            element_info = []

            for element in elements:
                element_text = str(element).strip()
                if element_text:
                    text_parts.append(element_text)
                    element_info.append(
                        {
                            "type": type(element).__name__,
                            "text": element_text[:100] + "..."
                            if len(element_text) > 100
                            else element_text,
                        }
                    )

            # Combine all text content
            full_content = "\n\n".join(text_parts)

            # Determine content type based on structure
            content_type = ContentType.STRUCTURED if len(element_info) > 1 else ContentType.TEXT

            # Gather metadata
            metadata = {
                "name": file_name,
                "size": file_size,
                "extension": file_ext,
                "loader": self.loader_name,
                "elements_count": len(elements),
                "text_elements_count": len(text_parts),
                "strategy": strategy,
                "element_types": list(set(info["type"] for info in element_info)),
            }

            return LoaderResult(
                content=full_content,
                metadata=metadata,
                content_type=content_type,
                chunks=text_parts,  # Pre-chunked by elements
                source_info={
                    "file_path": file_path,
                    "strategy": strategy,
                    "elements": element_info[:10],  # First 10 elements for debugging
                    "total_elements": len(elements),
                },
            )

        except Exception as e:
            self.logger.error(f"Failed to process document {file_path}: {e}")
            raise Exception(f"Document processing failed: {e}") from e
cognee/infrastructure/loaders/get_loader_engine.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from functools import lru_cache
from .config import get_loader_config
from .LoaderEngine import LoaderEngine
from .create_loader_engine import create_loader_engine


@lru_cache
def get_loader_engine() -> LoaderEngine:
    """
    Factory function to get loader engine.

    Follows cognee's pattern with @lru_cache for efficient reuse
    of engine instances. Configuration is loaded from environment
    variables and settings.

    Returns:
        Cached LoaderEngine instance configured with current settings
    """
    config = get_loader_config()
    return create_loader_engine(**config.to_dict())
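Since the factory is wrapped in @lru_cache, repeated calls return the same engine instance; a short sketch (not part of the diff):

# Sketch: the cached factory returns a single shared engine instance.
from cognee.infrastructure.loaders import get_loader_engine

engine_a = get_loader_engine()
engine_b = get_loader_engine()
assert engine_a is engine_b  # same cached instance

# To rebuild after changing LOADER_* environment variables, the caches
# would need clearing first (get_loader_config is lru_cached as well).
get_loader_engine.cache_clear()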
cognee/infrastructure/loaders/models/LoaderResult.py (new file, 47 lines)
@@ -0,0 +1,47 @@
from pydantic import BaseModel
from typing import Optional, Dict, Any, List
from enum import Enum


class ContentType(Enum):
    """Content type classification for loaded files"""

    TEXT = "text"
    STRUCTURED = "structured"
    BINARY = "binary"


class LoaderResult(BaseModel):
    """
    Standardized output format for all file loaders.

    This model ensures consistent data structure across all loader implementations,
    following cognee's pattern of using Pydantic models for data validation.
    """

    content: str  # Primary text content extracted from file
    metadata: Dict[str, Any]  # File metadata (name, size, type, loader info, etc.)
    content_type: ContentType  # Content classification
    chunks: Optional[List[str]] = None  # Pre-chunked content if available
    source_info: Optional[Dict[str, Any]] = None  # Source-specific information

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the loader result to a dictionary format.

        Returns:
            Dict containing all loader result data with string-serialized content_type
        """
        # With use_enum_values enabled, content_type is already stored as a plain
        # string after validation, so handle both representations here.
        content_type = (
            self.content_type.value
            if isinstance(self.content_type, ContentType)
            else self.content_type
        )
        return {
            "content": self.content,
            "metadata": self.metadata,
            "content_type": content_type,
            "source_info": self.source_info or {},
            "chunks": self.chunks,
        }

    class Config:
        """Pydantic configuration following cognee patterns"""

        use_enum_values = True
        validate_assignment = True
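For illustration (not part of the diff), constructing a result by hand and round-tripping it through to_dict():

# Sketch: building a LoaderResult and serializing it.
from cognee.infrastructure.loaders.models.LoaderResult import LoaderResult, ContentType

result = LoaderResult(
    content="alpha\nbeta",
    metadata={"name": "sample.txt", "size": 10},
    content_type=ContentType.TEXT,
    chunks=["alpha", "beta"],
)

# use_enum_values means the enum is stored as its string value after validation.
print(result.to_dict()["content_type"])  # "text"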
cognee/infrastructure/loaders/models/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .LoaderResult import LoaderResult, ContentType

__all__ = ["LoaderResult", "ContentType"]
cognee/infrastructure/loaders/supported_loaders.py (new file, 3 lines)
@@ -0,0 +1,3 @@
# Registry for loader implementations
# Follows cognee's pattern used in databases/vector/supported_databases.py
supported_loaders = {}
cognee/infrastructure/loaders/use_loader.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from .supported_loaders import supported_loaders


def use_loader(loader_name: str, loader_class):
    """
    Register a loader at runtime.

    Follows cognee's pattern used in databases for adapter registration.
    This allows external packages and custom loaders to be registered
    into the loader system.

    Args:
        loader_name: Unique name for the loader
        loader_class: Loader class implementing LoaderInterface

    Example:
        from cognee.infrastructure.loaders import use_loader
        from my_package import MyCustomLoader

        use_loader("my_custom_loader", MyCustomLoader)
    """
    supported_loaders[loader_name] = loader_class
@@ -30,9 +30,38 @@ class BinaryData(IngestionData):

     async def ensure_metadata(self):
         if self.metadata is None:
-            self.metadata = await get_file_metadata(self.data)
+            # Handle case where file might be closed
+            if hasattr(self.data, "closed") and self.data.closed:
+                # Try to reopen the file if we have a file path
+                if hasattr(self.data, "name") and self.data.name:
+                    try:
+                        with open(self.data.name, "rb") as reopened_file:
+                            self.metadata = await get_file_metadata(reopened_file)
+                    except (OSError, FileNotFoundError):
+                        # If we can't reopen, create minimal metadata
+                        self.metadata = {
+                            "name": self.name or "unknown",
+                            "file_path": getattr(self.data, "name", "unknown"),
+                            "extension": "txt",
+                            "mime_type": "text/plain",
+                            "content_hash": f"closed_file_{id(self.data)}",
+                            "file_size": 0,
+                        }
+                else:
+                    # Create minimal metadata when file is closed and no path available
+                    self.metadata = {
+                        "name": self.name or "unknown",
+                        "file_path": "unknown",
+                        "extension": "txt",
+                        "mime_type": "text/plain",
+                        "content_hash": f"closed_file_{id(self.data)}",
+                        "file_size": 0,
+                    }
+            else:
+                # File is still open, proceed normally
+                self.metadata = await get_file_metadata(self.data)

-        if self.metadata["name"] is None:
+        if self.metadata.get("name") is None:
             self.metadata["name"] = self.name

     @asynccontextmanager
@@ -2,7 +2,7 @@ import os
 from typing import Optional
 from contextlib import asynccontextmanager
 from cognee.infrastructure.files import get_file_metadata, FileMetadata
-from cognee.infrastructure.utils import run_sync
+from cognee.infrastructure.utils.run_sync import run_sync
 from .IngestionData import IngestionData
@@ -16,9 +16,12 @@ class TextData(IngestionData):
         self.data = data

     def get_identifier(self):
-        keywords = extract_keywords(self.data)
+        import hashlib

-        return "text/plain" + "_" + "|".join(keywords)
+        content_bytes = self.data.encode("utf-8")
+        content_hash = hashlib.md5(content_bytes).hexdigest()
+
+        return "text/plain" + "_" + content_hash

     def get_metadata(self):
         self.ensure_metadata()
@@ -27,7 +30,20 @@ class TextData(IngestionData):

     def ensure_metadata(self):
         if self.metadata is None:
-            self.metadata = {}
+            import hashlib
+
+            keywords = extract_keywords(self.data)
+            content_bytes = self.data.encode("utf-8")
+            content_hash = hashlib.md5(content_bytes).hexdigest()
+
+            self.metadata = {
+                "keywords": keywords,
+                "content_hash": content_hash,
+                "content_type": "text/plain",
+                "mime_type": "text/plain",
+                "extension": "txt",
+                "file_size": len(content_bytes),
+            }

     @asynccontextmanager
     async def get_data(self):
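The identifier change above means two TextData objects with identical text now share an identifier, independent of keyword extraction; a sketch of the semantics (not part of the diff):

# Sketch: the new content-hash identifier is deterministic for equal text.
import hashlib


def text_identifier(text: str) -> str:
    # Mirrors TextData.get_identifier after this change.
    return "text/plain" + "_" + hashlib.md5(text.encode("utf-8")).hexdigest()


assert text_identifier("same text") == text_identifier("same text")
assert text_identifier("same text") != text_identifier("other text")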
@@ -71,6 +71,25 @@ async def cognee_pipeline(
     if cognee_pipeline.first_run:
         from cognee.infrastructure.llm.utils import test_llm_connection, test_embedding_connection

+        # Ensure NLTK data is downloaded on first run
+        def ensure_nltk_data():
+            """Download required NLTK data if not already present."""
+            try:
+                import nltk
+
+                # Download essential NLTK data used by the system
+                nltk.download("punkt_tab", quiet=True)
+                nltk.download("punkt", quiet=True)
+                nltk.download("averaged_perceptron_tagger", quiet=True)
+                nltk.download("averaged_perceptron_tagger_eng", quiet=True)
+                nltk.download("maxent_ne_chunker", quiet=True)
+                nltk.download("words", quiet=True)
+                logger.info("NLTK data initialized successfully")
+            except Exception as e:
+                logger.warning(f"Failed to initialize NLTK data: {e}")
+
+        ensure_nltk_data()
+
         # Test LLM and Embedding configuration once before running Cognee
         await test_llm_connection()
         await test_embedding_connection()
@@ -2,6 +2,14 @@ import inspect
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.users.models import User
 from cognee.shared.utils import send_telemetry
+from cognee.exceptions import (
+    PipelineExecutionError,
+    CogneeTransientError,
+    CogneeSystemError,
+    CogneeUserError,
+    LLMConnectionError,
+    DatabaseConnectionError,
+)

 from ..tasks.task import Task
@@ -16,15 +24,33 @@ async def handle_task(
     user: User,
     context: dict = None,
 ):
-    """Handle common task workflow with logging, telemetry, and error handling around the core execution logic."""
-    task_type = running_task.task_type
+    """
+    Handle common task workflow with enhanced error handling and recovery strategies.
+
+    This function provides comprehensive error handling for pipeline tasks with:
+    - Context-aware error reporting
+    - Automatic retry for transient errors
+    - Detailed error logging and telemetry
+    - User-friendly error messages
+    """
+    task_type = running_task.task_type
+    task_name = running_task.executable.__name__

-    logger.info(f"{task_type} task started: `{running_task.executable.__name__}`")
+    logger.info(
+        f"{task_type} task started: `{task_name}`",
+        extra={
+            "task_type": task_type,
+            "task_name": task_name,
+            "user_id": user.id,
+            "context": context,
+        },
+    )
+
     send_telemetry(
         f"{task_type} Task Started",
         user_id=user.id,
         additional_properties={
-            "task_name": running_task.executable.__name__,
+            "task_name": task_name,
         },
     )
@@ -35,36 +61,151 @@ async def handle_task(
     if has_context:
         args.append(context)

-    try:
-        async for result_data in running_task.execute(args, next_task_batch_size):
-            async for result in run_tasks_base(leftover_tasks, result_data, user, context):
-                yield result
-
-        logger.info(f"{task_type} task completed: `{running_task.executable.__name__}`")
-        send_telemetry(
-            f"{task_type} Task Completed",
-            user_id=user.id,
-            additional_properties={
-                "task_name": running_task.executable.__name__,
-            },
-        )
-    except Exception as error:
-        logger.error(
-            f"{task_type} task errored: `{running_task.executable.__name__}`\n{str(error)}\n",
-            exc_info=True,
-        )
-        send_telemetry(
-            f"{task_type} Task Errored",
-            user_id=user.id,
-            additional_properties={
-                "task_name": running_task.executable.__name__,
-            },
-        )
-        raise error
+    # Retry configuration for transient errors
+    max_retries = 3
+    retry_count = 0
+
+    while retry_count <= max_retries:
+        try:
+            async for result_data in running_task.execute(args, next_task_batch_size):
+                async for result in run_tasks_base(leftover_tasks, result_data, user, context):
+                    yield result
+
+            logger.info(
+                f"{task_type} task completed: `{task_name}`",
+                extra={
+                    "task_type": task_type,
+                    "task_name": task_name,
+                    "user_id": user.id,
+                    "retry_count": retry_count,
+                },
+            )
+
+            send_telemetry(
+                f"{task_type} Task Completed",
+                user_id=user.id,
+                additional_properties={
+                    "task_name": task_name,
+                    "retry_count": retry_count,
+                },
+            )
+            return  # Success, exit retry loop
+
+        except CogneeTransientError as error:
+            retry_count += 1
+            if retry_count <= max_retries:
+                logger.warning(
+                    f"Transient error in {task_type} task `{task_name}`, retrying ({retry_count}/{max_retries}): {error}",
+                    extra={
+                        "task_type": task_type,
+                        "task_name": task_name,
+                        "user_id": user.id,
+                        "retry_count": retry_count,
+                        "error_type": error.__class__.__name__,
+                    },
+                )
+                # Exponential backoff for retries
+                import asyncio
+
+                await asyncio.sleep(2**retry_count)
+                continue
+            else:
+                # Max retries exceeded, raise enhanced error
+                raise PipelineExecutionError(
+                    pipeline_name=f"{task_type}_pipeline",
+                    task_name=task_name,
+                    error_details=f"Max retries ({max_retries}) exceeded for transient error: {error}",
+                )
+
+        except (CogneeUserError, CogneeSystemError) as error:
+            # These errors shouldn't be retried, re-raise as pipeline execution error
+            logger.error(
+                f"{task_type} task failed: `{task_name}` - {error.__class__.__name__}: {error}",
+                extra={
+                    "task_type": task_type,
+                    "task_name": task_name,
+                    "user_id": user.id,
+                    "error_type": error.__class__.__name__,
+                    "error_context": getattr(error, "context", {}),
+                },
+                exc_info=True,
+            )
+
+            send_telemetry(
+                f"{task_type} Task Errored",
+                user_id=user.id,
+                additional_properties={
+                    "task_name": task_name,
+                    "error_type": error.__class__.__name__,
+                },
+            )
+
+            # Wrap in pipeline execution error with additional context
+            raise PipelineExecutionError(
+                pipeline_name=f"{task_type}_pipeline",
+                task_name=task_name,
+                error_details=f"{error.__class__.__name__}: {error}",
+                context={
+                    "original_error": error.__class__.__name__,
+                    "original_context": getattr(error, "context", {}),
+                    "user_id": user.id,
+                    "task_args": str(args)[:200],  # Truncate for logging
+                },
+            )
+
+        except Exception as error:
+            # Unexpected error, wrap in enhanced exception
+            logger.error(
+                f"{task_type} task encountered unexpected error: `{task_name}` - {error}",
+                extra={
+                    "task_type": task_type,
+                    "task_name": task_name,
+                    "user_id": user.id,
+                    "error_type": error.__class__.__name__,
+                },
+                exc_info=True,
+            )
+
+            send_telemetry(
+                f"{task_type} Task Errored",
+                user_id=user.id,
+                additional_properties={
+                    "task_name": task_name,
+                    "error_type": error.__class__.__name__,
+                },
+            )
+
+            # Check if this might be a known error type we can categorize
+            error_message = str(error).lower()
+            if any(term in error_message for term in ["connection", "timeout", "network"]):
+                if (
+                    "llm" in error_message
+                    or "openai" in error_message
+                    or "anthropic" in error_message
+                ):
+                    raise LLMConnectionError(provider="Unknown", model="Unknown", reason=str(error))
+                elif "database" in error_message or "sql" in error_message:
+                    raise DatabaseConnectionError(db_type="Unknown", reason=str(error))
+
+            # Default to pipeline execution error
+            raise PipelineExecutionError(
+                pipeline_name=f"{task_type}_pipeline",
+                task_name=task_name,
+                error_details=f"Unexpected error: {error}",
+                context={
+                    "error_type": error.__class__.__name__,
+                    "user_id": user.id,
+                    "task_args": str(args)[:200],  # Truncate for logging
+                },
+            )
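The retry loop above reduces to a standard exponential-backoff pattern. For reference, a minimal standalone sketch (hypothetical names, independent of cognee's exception hierarchy); with max_retries=3 the waits between attempts are 2, 4, and 8 seconds:

import asyncio

async def run_with_retries(operation, max_retries=3):
    # Retry an async operation on transient failures with exponential backoff.
    retry_count = 0
    while True:
        try:
            return await operation()
        except ConnectionError:  # stand-in for CogneeTransientError
            retry_count += 1
            if retry_count > max_retries:
                raise
            await asyncio.sleep(2**retry_count)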
 async def run_tasks_base(tasks: list[Task], data=None, user: User = None, context: dict = None):
-    """Base function to execute tasks in a pipeline, handling task type detection and execution."""
+    """
+    Base function to execute tasks in a pipeline with enhanced error handling.
+
+    Provides comprehensive error handling, logging, and recovery strategies for pipeline execution.
+    """
     if len(tasks) == 0:
         yield data
         return
11  cognee/tasks/ingestion/adapters/__init__.py  Normal file
@@ -0,0 +1,11 @@
"""
Adapters for bridging the new loader system with existing ingestion pipeline.

This module provides compatibility layers to integrate the plugin-based loader
system with cognee's existing data processing pipeline while maintaining
backward compatibility and preserving permission logic.
"""

from .loader_to_ingestion_adapter import LoaderToIngestionAdapter

__all__ = ["LoaderToIngestionAdapter"]
241  cognee/tasks/ingestion/adapters/loader_to_ingestion_adapter.py  Normal file
@@ -0,0 +1,241 @@
import os
import tempfile
from typing import BinaryIO, Union, Optional, Any
from io import StringIO, BytesIO

from cognee.infrastructure.loaders.models.LoaderResult import LoaderResult, ContentType
from cognee.modules.ingestion.data_types import IngestionData, TextData, BinaryData
from cognee.infrastructure.files import get_file_metadata
from cognee.shared.logging_utils import get_logger


class LoaderResultToIngestionData(IngestionData):
    """
    Adapter class that wraps LoaderResult to be compatible with IngestionData interface.

    This maintains backward compatibility with existing cognee ingestion pipeline
    while enabling the new loader system.
    """

    def __init__(self, loader_result: LoaderResult, original_file_path: str = None):
        self.loader_result = loader_result
        self.original_file_path = original_file_path
        self._cached_metadata = None
        self.logger = get_logger(__name__)

    def get_identifier(self) -> str:
        """
        Get content identifier for deduplication.

        Always generates hash from content to ensure consistency with existing system.
        """
        # Always generate hash from content for consistency
        import hashlib

        content_bytes = self.loader_result.content.encode("utf-8")
        content_hash = hashlib.md5(content_bytes).hexdigest()

        # Add content type prefix for better identification
        content_type = self.loader_result.content_type.value
        return f"{content_type}_{content_hash}"

    def get_metadata(self) -> dict:
        """
        Get file metadata in the format expected by existing pipeline.

        Converts LoaderResult metadata to the format used by IngestionData.
        """
        if self._cached_metadata is not None:
            return self._cached_metadata

        # Start with loader result metadata
        metadata = self.loader_result.metadata.copy()

        # Ensure required fields are present
        if "name" not in metadata:
            if self.original_file_path:
                metadata["name"] = os.path.basename(self.original_file_path)
            else:
                # Generate name from content hash
                content_hash = self.get_identifier().split("_")[-1][:8]
                ext = metadata.get("extension", ".txt")
                metadata["name"] = f"content_{content_hash}{ext}"

        if "content_hash" not in metadata:
            # Store content hash without prefix for compatibility with deletion system
            identifier = self.get_identifier()
            if "_" in identifier:
                # Remove content type prefix (e.g., "text_abc123" -> "abc123")
                metadata["content_hash"] = identifier.split("_", 1)[-1]
            else:
                metadata["content_hash"] = identifier

        if "file_path" not in metadata and self.original_file_path:
            metadata["file_path"] = self.original_file_path

        # Add mime type if not present
        if "mime_type" not in metadata:
            ext = metadata.get("extension", "").lower()
            mime_type_map = {
                ".txt": "text/plain",
                ".md": "text/markdown",
                ".csv": "text/csv",
                ".json": "application/json",
                ".pdf": "application/pdf",
                ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
                ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
                ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
            }
            metadata["mime_type"] = mime_type_map.get(ext, "application/octet-stream")

        self._cached_metadata = metadata
        return metadata

    def get_data(self) -> Union[str, BinaryIO]:
        """
        Get data content in format expected by existing pipeline.

        Returns content as string for text data or creates a file-like object
        for binary data to maintain compatibility.
        """
        if self.loader_result.content_type == ContentType.TEXT:
            return self.loader_result.content

        # For structured or binary content, return as string for now
        # The existing pipeline expects text content for processing
        return self.loader_result.content


class LoaderToIngestionAdapter:
    """
    Adapter that bridges the new loader system with existing ingestion pipeline.

    This class provides methods to process files using the loader system
    while maintaining compatibility with the existing IngestionData interface.
    """

    def __init__(self):
        self.logger = get_logger(__name__)

    async def process_file_with_loaders(
        self,
        file_path: str,
        s3fs: Optional[Any] = None,
        preferred_loaders: Optional[list] = None,
        loader_config: Optional[dict] = None,
    ) -> IngestionData:
        """
        Process a file using the loader system and return IngestionData.

        Args:
            file_path: Path to the file to process
            s3fs: S3 filesystem (for compatibility with existing code)
            preferred_loaders: List of preferred loader names
            loader_config: Configuration for specific loaders

        Returns:
            IngestionData compatible object

        Raises:
            Exception: If no loader can handle the file
        """
        from cognee.infrastructure.loaders import get_loader_engine

        try:
            # Get the loader engine
            engine = get_loader_engine()

            # Determine MIME type if possible
            mime_type = None
            try:
                import mimetypes

                mime_type, _ = mimetypes.guess_type(file_path)
            except Exception:
                pass

            # Load file using loader system
            self.logger.info(f"Processing file with loaders: {file_path}")

            # Extract loader-specific config if provided
            kwargs = {}
            if loader_config:
                # Find the first available loader that matches our preferred loaders
                loader = engine.get_loader(file_path, mime_type, preferred_loaders)
                if loader and loader.loader_name in loader_config:
                    kwargs = loader_config[loader.loader_name]

            loader_result = await engine.load_file(
                file_path, mime_type=mime_type, preferred_loaders=preferred_loaders, **kwargs
            )

            # Convert to IngestionData compatible format
            return LoaderResultToIngestionData(loader_result, file_path)

        except Exception as e:
            self.logger.warning(f"Loader system failed for {file_path}: {e}")
            # Fallback to existing classification system
            return await self._fallback_to_existing_system(file_path, s3fs)

    async def _fallback_to_existing_system(
        self, file_path: str, s3fs: Optional[Any] = None
    ) -> IngestionData:
        """
        Fallback to existing ingestion.classify() system for backward compatibility.

        This ensures that even if the loader system fails, we can still process
        files using the original classification method.
        """
        from cognee.modules.ingestion import classify

        self.logger.info(f"Falling back to existing classification system for: {file_path}")

        # Open file and classify using existing system
        if file_path.startswith("s3://"):
            if s3fs:
                with s3fs.open(file_path, "rb") as file:
                    return classify(file)
            else:
                raise ValueError("S3 file path provided but no s3fs available")
        else:
            # Handle local files and file:// URLs
            local_path = file_path.replace("file://", "")
            with open(local_path, "rb") as file:
                return classify(file)

    def is_text_content(self, data: Union[str, Any]) -> bool:
        """
        Check if the provided data is text content (not a file path).

        Args:
            data: The data to check

        Returns:
            True if data is text content, False if it's a file path
        """
        if not isinstance(data, str):
            return False

        # Check if it's a file path
        if (
            data.startswith("/")
            or data.startswith("file://")
            or data.startswith("s3://")
            or (len(data) > 1 and data[1] == ":")
        ):  # Windows drive paths
            return False

        return True

    def create_text_ingestion_data(self, content: str) -> IngestionData:
        """
        Create IngestionData for text content.

        Args:
            content: Text content to wrap

        Returns:
            IngestionData compatible object
        """

        return TextData(content)
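For orientation, a sketch of how the adapter is meant to be driven (assuming a configured cognee environment and that a local text file exists at the illustrative path; call signatures come from the class above):

import asyncio
from cognee.tasks.ingestion.adapters import LoaderToIngestionAdapter

async def demo():
    adapter = LoaderToIngestionAdapter()

    # Plain text is wrapped directly, no loader involved
    text_data = adapter.create_text_ingestion_data("some raw text")
    print(text_data.get_identifier())

    # File paths go through the loader engine, with classify() as fallback
    file_data = await adapter.process_file_with_loaders("/tmp/example.txt")
    print(file_data.get_metadata()["mime_type"])

asyncio.run(demo())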
@@ -60,7 +60,7 @@ async def ingest_data(
     else:
         # Find existing dataset or create a new one
         existing_datasets = await get_authorized_existing_datasets(
-            user=user, permission_type="write", datasets=[dataset_name]
+            datasets=[dataset_name], permission_type="write", user=user
         )
         dataset = await load_or_create_datasets(
             dataset_names=[dataset_name],
309  cognee/tasks/ingestion/plugin_ingest_data.py  Normal file
@@ -0,0 +1,309 @@
import json
import inspect
from uuid import UUID
from typing import Union, BinaryIO, Any, List, Optional

import cognee.modules.ingestion as ingestion
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import Data
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
from cognee.modules.data.methods import (
    get_authorized_existing_datasets,
    get_dataset_data,
    load_or_create_datasets,
)
from cognee.shared.logging_utils import get_logger

from .save_data_item_to_storage import save_data_item_to_storage
from .adapters import LoaderToIngestionAdapter
from cognee.infrastructure.files.storage.s3_config import get_s3_config


logger = get_logger(__name__)


async def plugin_ingest_data(
    data: Any,
    dataset_name: str,
    user: User,
    node_set: Optional[List[str]] = None,
    dataset_id: UUID = None,
    preferred_loaders: Optional[List[str]] = None,
    loader_config: Optional[dict] = None,
):
    """
    Plugin-based data ingestion using the loader system.

    This function maintains full backward compatibility with the existing
    ingest_data function while adding support for the new loader system.

    Args:
        data: The data to ingest
        dataset_name: Name of the dataset
        user: User object for permissions
        node_set: Optional node set for organization
        dataset_id: Optional specific dataset ID
        preferred_loaders: List of preferred loader names to try first
        loader_config: Configuration for specific loaders

    Returns:
        List of Data objects that were ingested
    """
    if not user:
        user = await get_default_user()

    # Ensure NLTK data is downloaded (preserves automatic download behavior)
    def ensure_nltk_data():
        """Download required NLTK data if not already present."""
        try:
            import nltk

            # Download essential NLTK data used by the system
            nltk.download("punkt_tab", quiet=True)
            nltk.download("punkt", quiet=True)
            nltk.download("averaged_perceptron_tagger", quiet=True)
            nltk.download("averaged_perceptron_tagger_eng", quiet=True)
            nltk.download("maxent_ne_chunker", quiet=True)
            nltk.download("words", quiet=True)
            logger.info("NLTK data verified/downloaded successfully")
        except Exception as e:
            logger.warning(f"Failed to download NLTK data: {e}")

    # Download NLTK data once per session
    if not hasattr(plugin_ingest_data, "_nltk_initialized"):
        ensure_nltk_data()
        plugin_ingest_data._nltk_initialized = True

    # Initialize S3 support (maintain existing behavior)
    s3_config = get_s3_config()
    fs = None
    if s3_config.aws_access_key_id is not None and s3_config.aws_secret_access_key is not None:
        import s3fs

        fs = s3fs.S3FileSystem(
            key=s3_config.aws_access_key_id, secret=s3_config.aws_secret_access_key, anon=False
        )

    # Initialize the loader adapter
    loader_adapter = LoaderToIngestionAdapter()

    def open_data_file(file_path: str):
        """Open file with S3 support (preserves existing behavior)."""
        if file_path.startswith("s3://"):
            return fs.open(file_path, mode="rb")
        else:
            local_path = file_path.replace("file://", "")
            return open(local_path, mode="rb")

    def get_external_metadata_dict(data_item: Union[BinaryIO, str, Any]) -> dict[str, Any]:
        """Get external metadata (preserves existing behavior)."""
        if hasattr(data_item, "dict") and inspect.ismethod(getattr(data_item, "dict")):
            return {"metadata": data_item.dict(), "origin": str(type(data_item))}
        else:
            return {}

    async def store_data_to_dataset(
        data: Any,
        dataset_name: str,
        user: User,
        node_set: Optional[List[str]] = None,
        dataset_id: UUID = None,
    ):
        """
        Core data storage logic with plugin-based file processing.

        This function preserves all existing permission and database logic
        while using the new loader system for file processing.
        """
        logger.info(f"Plugin-based ingestion starting for dataset: {dataset_name}")

        # Preserve existing dataset creation and permission logic
        if dataset_id:
            # Retrieve existing dataset by ID
            dataset = await get_specific_user_permission_datasets(user.id, "write", [dataset_id])
            # Convert from list to Dataset element
            if isinstance(dataset, list):
                dataset = dataset[0]
        else:
            # Find existing dataset or create a new one by name
            existing_datasets = await get_authorized_existing_datasets(
                datasets=[dataset_name], permission_type="write", user=user
            )
            datasets = await load_or_create_datasets(
                dataset_names=[dataset_name], existing_datasets=existing_datasets, user=user
            )
            if isinstance(datasets, list):
                dataset = datasets[0]

        new_datapoints = []
        existing_data_points = []
        dataset_new_data_points = []

        # Get existing dataset data for deduplication (preserve existing logic)
        dataset_data: list[Data] = await get_dataset_data(dataset.id)
        dataset_data_map = {str(data.id): True for data in dataset_data}

        for data_item in data:
            file_path = await save_data_item_to_storage(data_item)

            # NEW: Use loader system or existing classification based on data type
            try:
                if loader_adapter.is_text_content(data_item):
                    # Handle text content (preserve existing behavior)
                    logger.info("Processing text content with existing system")
                    classified_data = ingestion.classify(data_item)
                else:
                    # Use loader system for file paths
                    logger.info(f"Processing file with loader system: {file_path}")
                    classified_data = await loader_adapter.process_file_with_loaders(
                        file_path,
                        s3fs=fs,
                        preferred_loaders=preferred_loaders,
                        loader_config=loader_config,
                    )

            except Exception as e:
                logger.warning(f"Plugin system failed for {file_path}, falling back: {e}")
                # Fallback to existing system for full backward compatibility
                with open_data_file(file_path) as file:
                    classified_data = ingestion.classify(file)

            # Preserve all existing data processing logic
            data_id = ingestion.identify(classified_data, user)
            file_metadata = classified_data.get_metadata()

            # Ensure metadata has all required fields with fallbacks
            def get_metadata_field(metadata, field_name, default_value=""):
                """Get metadata field with fallback handling."""
                if field_name in metadata and metadata[field_name] is not None:
                    return metadata[field_name]

                logger.warning(f"Missing metadata field '{field_name}', using fallback")

                # Provide fallbacks based on available information
                if field_name == "name":
                    if "file_path" in metadata and metadata["file_path"]:
                        import os

                        return os.path.basename(str(metadata["file_path"])).split(".")[0]
                    elif file_path:
                        import os

                        return os.path.basename(str(file_path)).split(".")[0]
                    else:
                        content_hash = metadata.get("content_hash", str(data_id))[:8]
                        return f"content_{content_hash}"
                elif field_name == "file_path":
                    # Use the actual file path returned by save_data_item_to_storage
                    return file_path
                elif field_name == "extension":
                    if "file_path" in metadata and metadata["file_path"]:
                        import os

                        _, ext = os.path.splitext(str(metadata["file_path"]))
                        return ext.lstrip(".") if ext else "txt"
                    elif file_path:
                        import os

                        _, ext = os.path.splitext(str(file_path))
                        return ext.lstrip(".") if ext else "txt"
                    return "txt"
                elif field_name == "mime_type":
                    ext = get_metadata_field(metadata, "extension", "txt")
                    mime_map = {
                        "txt": "text/plain",
                        "md": "text/markdown",
                        "pdf": "application/pdf",
                        "json": "application/json",
                        "csv": "text/csv",
                    }
                    return mime_map.get(ext.lower(), "text/plain")
                elif field_name == "content_hash":
                    # Extract the raw content hash for compatibility with deletion system
                    content_identifier = classified_data.get_identifier()
                    # Remove content type prefix if present (e.g., "text_abc123" -> "abc123")
                    if "_" in content_identifier:
                        return content_identifier.split("_", 1)[-1]
                    return content_identifier

                return default_value

            from sqlalchemy import select

            db_engine = get_relational_engine()

            # Check if data should be updated (preserve existing logic)
            async with db_engine.get_async_session() as session:
                data_point = (
                    await session.execute(select(Data).filter(Data.id == data_id))
                ).scalar_one_or_none()

            ext_metadata = get_external_metadata_dict(data_item)

            if node_set:
                ext_metadata["node_set"] = node_set

            # Preserve existing data point creation/update logic
            if data_point is not None:
                data_point.name = get_metadata_field(file_metadata, "name")
                data_point.raw_data_location = get_metadata_field(file_metadata, "file_path")
                data_point.extension = get_metadata_field(file_metadata, "extension")
                data_point.mime_type = get_metadata_field(file_metadata, "mime_type")
                data_point.owner_id = user.id
                data_point.content_hash = get_metadata_field(file_metadata, "content_hash")
                data_point.external_metadata = ext_metadata
                data_point.node_set = json.dumps(node_set) if node_set else None

                if str(data_point.id) in dataset_data_map:
                    existing_data_points.append(data_point)
                else:
                    dataset_new_data_points.append(data_point)
                    dataset_data_map[str(data_point.id)] = True
            else:
                if str(data_id) in dataset_data_map:
                    continue

                data_point = Data(
                    id=data_id,
                    name=get_metadata_field(file_metadata, "name"),
                    raw_data_location=get_metadata_field(file_metadata, "file_path"),
                    extension=get_metadata_field(file_metadata, "extension"),
                    mime_type=get_metadata_field(file_metadata, "mime_type"),
                    owner_id=user.id,
                    content_hash=get_metadata_field(file_metadata, "content_hash"),
                    external_metadata=ext_metadata,
                    node_set=json.dumps(node_set) if node_set else None,
                    token_count=-1,
                )

                new_datapoints.append(data_point)
                dataset_data_map[str(data_point.id)] = True

        # Preserve existing database operations
        async with db_engine.get_async_session() as session:
            if dataset not in session:
                session.add(dataset)

            if len(new_datapoints) > 0:
                dataset.data.extend(new_datapoints)

            if len(existing_data_points) > 0:
                for data_point in existing_data_points:
                    await session.merge(data_point)

            if len(dataset_new_data_points) > 0:
                dataset.data.extend(dataset_new_data_points)

            await session.merge(dataset)
            await session.commit()

        logger.info(
            f"Plugin-based ingestion completed. New: {len(new_datapoints)}, "
            + f"Updated: {len(existing_data_points)}, Dataset new: {len(dataset_new_data_points)}"
        )

        return existing_data_points + dataset_new_data_points + new_datapoints

    return await store_data_to_dataset(data, dataset_name, user, node_set, dataset_id)
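A hedged sketch of a call site for the function above (dataset name and file path are made up; the keyword names match the signature, and a configured cognee database is assumed):

import asyncio
from cognee.tasks.ingestion.plugin_ingest_data import plugin_ingest_data

async def main():
    ingested = await plugin_ingest_data(
        data=["/tmp/report.pdf"],
        dataset_name="demo_dataset",
        user=None,  # falls back to get_default_user()
        preferred_loaders=["pypdf_loader"],
    )
    print(f"Ingested {len(ingested)} data points")

asyncio.run(main())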
237  infrastructure/loaders/LoaderEngine.py  Normal file
@@ -0,0 +1,237 @@
import os
import importlib.util
from typing import Dict, List, Optional
from .LoaderInterface import LoaderInterface
from .models.LoaderResult import LoaderResult
from cognee.shared.logging_utils import get_logger


class LoaderEngine:
    """
    Main loader engine for managing file loaders.

    Follows cognee's adapter pattern similar to database engines,
    providing a centralized system for file loading operations.
    """

    def __init__(
        self,
        loader_directories: List[str],
        default_loader_priority: List[str],
        fallback_loader: str = "text_loader",
        enable_dependency_validation: bool = True,
    ):
        """
        Initialize the loader engine.

        Args:
            loader_directories: Directories to search for loader implementations
            default_loader_priority: Priority order for loader selection
            fallback_loader: Default loader to use when no other matches
            enable_dependency_validation: Whether to validate loader dependencies
        """
        self._loaders: Dict[str, LoaderInterface] = {}
        self._extension_map: Dict[str, List[LoaderInterface]] = {}
        self._mime_type_map: Dict[str, List[LoaderInterface]] = {}
        self.loader_directories = loader_directories
        self.default_loader_priority = default_loader_priority
        self.fallback_loader = fallback_loader
        self.enable_dependency_validation = enable_dependency_validation
        self.logger = get_logger(__name__)

    def register_loader(self, loader: LoaderInterface) -> bool:
        """
        Register a loader with the engine.

        Args:
            loader: LoaderInterface implementation to register

        Returns:
            True if loader was registered successfully, False otherwise
        """
        # Validate dependencies if enabled
        if self.enable_dependency_validation and not loader.validate_dependencies():
            self.logger.warning(
                f"Skipping loader '{loader.loader_name}' - missing dependencies: "
                f"{loader.get_dependencies()}"
            )
            return False

        self._loaders[loader.loader_name] = loader

        # Map extensions to loaders
        for ext in loader.supported_extensions:
            ext_lower = ext.lower()
            if ext_lower not in self._extension_map:
                self._extension_map[ext_lower] = []
            self._extension_map[ext_lower].append(loader)

        # Map mime types to loaders
        for mime_type in loader.supported_mime_types:
            if mime_type not in self._mime_type_map:
                self._mime_type_map[mime_type] = []
            self._mime_type_map[mime_type].append(loader)

        self.logger.info(f"Registered loader: {loader.loader_name}")
        return True

    def get_loader(
        self, file_path: str, mime_type: str = None, preferred_loaders: List[str] = None
    ) -> Optional[LoaderInterface]:
        """
        Get appropriate loader for a file.

        Args:
            file_path: Path to the file to be processed
            mime_type: Optional MIME type of the file
            preferred_loaders: List of preferred loader names to try first

        Returns:
            LoaderInterface that can handle the file, or None if not found
        """
        ext = os.path.splitext(file_path)[1].lower()

        # Try preferred loaders first
        if preferred_loaders:
            for loader_name in preferred_loaders:
                if loader_name in self._loaders:
                    loader = self._loaders[loader_name]
                    if loader.can_handle(file_path, mime_type):
                        return loader

        # Try priority order
        for loader_name in self.default_loader_priority:
            if loader_name in self._loaders:
                loader = self._loaders[loader_name]
                if loader.can_handle(file_path, mime_type):
                    return loader

        # Try mime type mapping
        if mime_type and mime_type in self._mime_type_map:
            for loader in self._mime_type_map[mime_type]:
                if loader.can_handle(file_path, mime_type):
                    return loader

        # Try extension mapping
        if ext in self._extension_map:
            for loader in self._extension_map[ext]:
                if loader.can_handle(file_path, mime_type):
                    return loader

        # Fallback loader
        if self.fallback_loader in self._loaders:
            fallback = self._loaders[self.fallback_loader]
            if fallback.can_handle(file_path, mime_type):
                return fallback

        return None

    async def load_file(
        self, file_path: str, mime_type: str = None, preferred_loaders: List[str] = None, **kwargs
    ) -> LoaderResult:
        """
        Load file using appropriate loader.

        Args:
            file_path: Path to the file to be processed
            mime_type: Optional MIME type of the file
            preferred_loaders: List of preferred loader names to try first
            **kwargs: Additional loader-specific configuration

        Returns:
            LoaderResult containing processed content and metadata

        Raises:
            ValueError: If no suitable loader is found
            Exception: If file processing fails
        """
        loader = self.get_loader(file_path, mime_type, preferred_loaders)
        if not loader:
            raise ValueError(f"No loader found for file: {file_path}")

        self.logger.debug(f"Loading {file_path} with {loader.loader_name}")
        return await loader.load(file_path, **kwargs)

    def discover_loaders(self):
        """
        Auto-discover loaders from configured directories.

        Scans loader directories for Python modules containing
        LoaderInterface implementations and registers them.
        """
        for directory in self.loader_directories:
            if os.path.exists(directory):
                self._discover_in_directory(directory)

    def _discover_in_directory(self, directory: str):
        """
        Discover loaders in a specific directory.

        Args:
            directory: Directory path to scan for loader implementations
        """
        try:
            for file_name in os.listdir(directory):
                if file_name.endswith(".py") and not file_name.startswith("_"):
                    module_name = file_name[:-3]
                    file_path = os.path.join(directory, file_name)

                    try:
                        spec = importlib.util.spec_from_file_location(module_name, file_path)
                        if spec and spec.loader:
                            module = importlib.util.module_from_spec(spec)
                            spec.loader.exec_module(module)

                            # Look for loader classes
                            for attr_name in dir(module):
                                attr = getattr(module, attr_name)
                                if (
                                    isinstance(attr, type)
                                    and issubclass(attr, LoaderInterface)
                                    and attr != LoaderInterface
                                ):
                                    # Instantiate and register the loader
                                    try:
                                        loader_instance = attr()
                                        self.register_loader(loader_instance)
                                    except Exception as e:
                                        self.logger.warning(
                                            f"Failed to instantiate loader {attr_name}: {e}"
                                        )

                    except Exception as e:
                        self.logger.warning(f"Failed to load module {module_name}: {e}")

        except OSError as e:
            self.logger.warning(f"Failed to scan directory {directory}: {e}")

    def get_available_loaders(self) -> List[str]:
        """
        Get list of available loader names.

        Returns:
            List of registered loader names
        """
        return list(self._loaders.keys())

    def get_loader_info(self, loader_name: str) -> Dict[str, any]:
        """
        Get information about a specific loader.

        Args:
            loader_name: Name of the loader to inspect

        Returns:
            Dictionary containing loader information
        """
        if loader_name not in self._loaders:
            return {}

        loader = self._loaders[loader_name]
        return {
            "name": loader.loader_name,
            "extensions": loader.supported_extensions,
            "mime_types": loader.supported_mime_types,
            "dependencies": loader.get_dependencies(),
            "available": loader.validate_dependencies(),
        }
101  infrastructure/loaders/LoaderInterface.py  Normal file
@@ -0,0 +1,101 @@
from abc import ABC, abstractmethod
from typing import List
from .models.LoaderResult import LoaderResult


class LoaderInterface(ABC):
    """
    Base interface for all file loaders in cognee.

    This interface follows cognee's established pattern for database adapters,
    ensuring consistent behavior across all loader implementations.
    """

    @property
    @abstractmethod
    def supported_extensions(self) -> List[str]:
        """
        List of file extensions this loader supports.

        Returns:
            List of extensions including the dot (e.g., ['.txt', '.md'])
        """
        pass

    @property
    @abstractmethod
    def supported_mime_types(self) -> List[str]:
        """
        List of MIME types this loader supports.

        Returns:
            List of MIME type strings (e.g., ['text/plain', 'application/pdf'])
        """
        pass

    @property
    @abstractmethod
    def loader_name(self) -> str:
        """
        Unique name identifier for this loader.

        Returns:
            String identifier used for registration and configuration
        """
        pass

    @abstractmethod
    def can_handle(self, file_path: str, mime_type: str = None) -> bool:
        """
        Check if this loader can handle the given file.

        Args:
            file_path: Path to the file to be processed
            mime_type: Optional MIME type of the file

        Returns:
            True if this loader can process the file, False otherwise
        """
        pass

    @abstractmethod
    async def load(self, file_path: str, **kwargs) -> LoaderResult:
        """
        Load and process the file, returning standardized result.

        Args:
            file_path: Path to the file to be processed
            **kwargs: Additional loader-specific configuration

        Returns:
            LoaderResult containing processed content and metadata

        Raises:
            Exception: If file cannot be processed
        """
        pass

    def get_dependencies(self) -> List[str]:
        """
        Optional: Return list of required dependencies for this loader.

        Returns:
            List of package names with optional version specifications
        """
        return []

    def validate_dependencies(self) -> bool:
        """
        Check if all required dependencies are available.

        Returns:
            True if all dependencies are installed, False otherwise
        """
        for dep in self.get_dependencies():
            # Extract package name from version specification
            package_name = dep.split(">=")[0].split("==")[0].split("<")[0]
            try:
                __import__(package_name)
            except ImportError:
                return False
        return True
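To make the contract concrete, a minimal hypothetical loader implementing the interface above (the CSV handling is illustrative only; it could then be registered via LoaderEngine.register_loader() or, presumably, the exported use_loader() helper):

import csv
import os
from typing import List

from cognee.infrastructure.loaders import LoaderInterface, LoaderResult, ContentType

class CsvRowCountLoader(LoaderInterface):
    """Toy loader: reads a CSV and reports its row count in the metadata."""

    @property
    def supported_extensions(self) -> List[str]:
        return [".csv"]

    @property
    def supported_mime_types(self) -> List[str]:
        return ["text/csv"]

    @property
    def loader_name(self) -> str:
        return "csv_row_count_loader"

    def can_handle(self, file_path: str, mime_type: str = None) -> bool:
        return os.path.splitext(file_path)[1].lower() in self.supported_extensions

    async def load(self, file_path: str, **kwargs) -> LoaderResult:
        with open(file_path, "r", encoding="utf-8") as f:
            rows = list(csv.reader(f))
        content = "\n".join(",".join(row) for row in rows)
        return LoaderResult(
            content=content,
            metadata={"name": os.path.basename(file_path), "rows": len(rows)},
            content_type=ContentType.TEXT,
        )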
19  infrastructure/loaders/__init__.py  Normal file
@@ -0,0 +1,19 @@
"""
File loader infrastructure for cognee.

This package provides a plugin-based system for loading different file formats
into cognee, following the same patterns as database adapters.

Main exports:
- get_loader_engine(): Factory function to get configured loader engine
- use_loader(): Register custom loaders at runtime
- LoaderInterface: Base interface for implementing loaders
- LoaderResult, ContentType: Data models for loader results
"""

from .get_loader_engine import get_loader_engine
from .use_loader import use_loader
from .LoaderInterface import LoaderInterface
from .models.LoaderResult import LoaderResult, ContentType

__all__ = ["get_loader_engine", "use_loader", "LoaderInterface", "LoaderResult", "ContentType"]
57  infrastructure/loaders/config.py  Normal file
@@ -0,0 +1,57 @@
from functools import lru_cache
from typing import List, Optional, Dict, Any
from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.root_dir import get_absolute_path


class LoaderConfig(BaseSettings):
    """
    Configuration for file loader system.

    Follows cognee's pattern using pydantic_settings.BaseSettings for
    environment variable support and validation.
    """

    loader_directories: List[str] = [
        get_absolute_path("cognee/infrastructure/loaders/core"),
        get_absolute_path("cognee/infrastructure/loaders/external"),
    ]
    default_loader_priority: List[str] = [
        "text_loader",
        "pypdf_loader",
        "unstructured_loader",
        "dlt_loader",
    ]
    auto_discover: bool = True
    fallback_loader: str = "text_loader"
    enable_dependency_validation: bool = True

    model_config = SettingsConfigDict(env_file=".env", extra="allow", env_prefix="LOADER_")

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert configuration to dictionary format.

        Returns:
            Dict containing all loader configuration settings
        """
        return {
            "loader_directories": self.loader_directories,
            "default_loader_priority": self.default_loader_priority,
            "auto_discover": self.auto_discover,
            "fallback_loader": self.fallback_loader,
            "enable_dependency_validation": self.enable_dependency_validation,
        }


@lru_cache
def get_loader_config() -> LoaderConfig:
    """
    Get cached loader configuration.

    Uses LRU cache following cognee's pattern for configuration objects.

    Returns:
        LoaderConfig instance with current settings
    """
    return LoaderConfig()
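Because the settings class uses env_prefix="LOADER_", individual fields should be overridable from the environment or a .env file; a sketch (variable names follow pydantic-settings' derivation from the field names, and note that get_loader_config is lru_cached):

import os
from cognee.infrastructure.loaders.config import get_loader_config

# Hypothetical overrides; scalar fields map directly from LOADER_-prefixed env vars
os.environ["LOADER_FALLBACK_LOADER"] = "text_loader"
os.environ["LOADER_AUTO_DISCOVER"] = "false"

get_loader_config.cache_clear()  # config is cached, so clear before re-reading
config = get_loader_config()
print(config.fallback_loader, config.auto_discover)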
5  infrastructure/loaders/core/__init__.py  Normal file
@@ -0,0 +1,5 @@
"""Core loader implementations that are always available."""

from .text_loader import TextLoader

__all__ = ["TextLoader"]
128  infrastructure/loaders/core/text_loader.py  Normal file
@@ -0,0 +1,128 @@
import os
from typing import List
from ..LoaderInterface import LoaderInterface
from ..models.LoaderResult import LoaderResult, ContentType


class TextLoader(LoaderInterface):
    """
    Core text file loader that handles basic text file formats.

    This loader is always available and serves as the fallback for
    text-based files when no specialized loader is available.
    """

    @property
    def supported_extensions(self) -> List[str]:
        """Supported text file extensions."""
        return [".txt", ".md", ".csv", ".json", ".xml", ".yaml", ".yml", ".log"]

    @property
    def supported_mime_types(self) -> List[str]:
        """Supported MIME types for text content."""
        return [
            "text/plain",
            "text/markdown",
            "text/csv",
            "application/json",
            "text/xml",
            "application/xml",
            "text/yaml",
            "application/yaml",
        ]

    @property
    def loader_name(self) -> str:
        """Unique identifier for this loader."""
        return "text_loader"

    def can_handle(self, file_path: str, mime_type: str = None) -> bool:
        """
        Check if this loader can handle the given file.

        Args:
            file_path: Path to the file
            mime_type: Optional MIME type

        Returns:
            True if file can be handled, False otherwise
        """
        # Check by extension
        ext = os.path.splitext(file_path)[1].lower()
        if ext in self.supported_extensions:
            return True

        # Check by MIME type
        if mime_type and mime_type in self.supported_mime_types:
            return True

        # As fallback loader, can attempt to handle any text-like file
        # This is useful when other loaders fail
        try:
            # Quick check if file appears to be text
            with open(file_path, "rb") as f:
                sample = f.read(512)
            # Simple heuristic: if most bytes are printable, consider it text
            if sample:
                try:
                    sample.decode("utf-8")
                    return True
                except UnicodeDecodeError:
                    try:
                        sample.decode("latin-1")
                        return True
                    except UnicodeDecodeError:
                        pass
        except (OSError, IOError):
            pass

        return False

    async def load(self, file_path: str, encoding: str = "utf-8", **kwargs) -> LoaderResult:
        """
        Load and process the text file.

        Args:
            file_path: Path to the file to load
            encoding: Text encoding to use (default: utf-8)
            **kwargs: Additional configuration (unused)

        Returns:
            LoaderResult containing the file content and metadata

        Raises:
            FileNotFoundError: If file doesn't exist
            UnicodeDecodeError: If file cannot be decoded with specified encoding
            OSError: If file cannot be read
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        try:
            with open(file_path, "r", encoding=encoding) as f:
                content = f.read()
        except UnicodeDecodeError:
            # Try with fallback encoding
            if encoding == "utf-8":
                return await self.load(file_path, encoding="latin-1", **kwargs)
            else:
                raise

        # Extract basic metadata
        file_stat = os.stat(file_path)
        metadata = {
            "name": os.path.basename(file_path),
            "size": file_stat.st_size,
            "extension": os.path.splitext(file_path)[1],
            "encoding": encoding,
            "loader": self.loader_name,
            "lines": len(content.splitlines()) if content else 0,
            "characters": len(content),
        }

        return LoaderResult(
            content=content,
            metadata=metadata,
            content_type=ContentType.TEXT,
            source_info={"file_path": file_path, "encoding": encoding},
        )
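A quick self-contained check of the loader above (the temp file stands in for real input):

import asyncio
import tempfile

from cognee.infrastructure.loaders.core import TextLoader

async def demo():
    with tempfile.NamedTemporaryFile("w", suffix=".md", delete=False) as f:
        f.write("# Title\nBody text\n")
        path = f.name

    loader = TextLoader()
    assert loader.can_handle(path)
    result = await loader.load(path)
    print(result.metadata["lines"], result.metadata["encoding"])  # 2 utf-8

asyncio.run(demo())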
49  infrastructure/loaders/create_loader_engine.py  Normal file
@@ -0,0 +1,49 @@
from typing import List

from .LoaderEngine import LoaderEngine
from .supported_loaders import supported_loaders


def create_loader_engine(
    loader_directories: List[str],
    default_loader_priority: List[str],
    auto_discover: bool = True,
    fallback_loader: str = "text_loader",
    enable_dependency_validation: bool = True,
) -> LoaderEngine:
    """
    Create loader engine with given configuration.

    Follows cognee's pattern for engine creation functions used
    in database adapters.

    Args:
        loader_directories: Directories to search for loader implementations
        default_loader_priority: Priority order for loader selection
        auto_discover: Whether to auto-discover loaders from directories
        fallback_loader: Default loader to use when no other matches
        enable_dependency_validation: Whether to validate loader dependencies

    Returns:
        Configured LoaderEngine instance
    """
    engine = LoaderEngine(
        loader_directories=loader_directories,
        default_loader_priority=default_loader_priority,
        fallback_loader=fallback_loader,
        enable_dependency_validation=enable_dependency_validation,
    )

    # Register supported loaders from registry
    for loader_name, loader_class in supported_loaders.items():
        try:
            loader_instance = loader_class()
            engine.register_loader(loader_instance)
        except Exception as e:
            # Log but don't fail - allow engine to continue with other loaders
            engine.logger.warning(f"Failed to register loader {loader_name}: {e}")

    # Auto-discover loaders if enabled
    if auto_discover:
        engine.discover_loaders()

    return engine
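A minimal sketch of calling this factory directly; the directory path and priority list below are illustrative values, not configuration from the repository:

from cognee.infrastructure.loaders.create_loader_engine import create_loader_engine

engine = create_loader_engine(
    loader_directories=["./custom_loaders"],  # hypothetical directory
    default_loader_priority=["text_loader"],
    auto_discover=False,  # skip filesystem discovery for this sketch
)
print(engine.get_available_loaders())

Note that registration failures are deliberately non-fatal: one broken loader only logs a warning, so the remaining loaders stay usable.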
20  infrastructure/loaders/get_loader_engine.py  Normal file
@@ -0,0 +1,20 @@
from functools import lru_cache

from .config import get_loader_config
from .LoaderEngine import LoaderEngine
from .create_loader_engine import create_loader_engine


@lru_cache
def get_loader_engine() -> LoaderEngine:
    """
    Factory function to get loader engine.

    Follows cognee's pattern with @lru_cache for efficient reuse
    of engine instances. Configuration is loaded from environment
    variables and settings.

    Returns:
        Cached LoaderEngine instance configured with current settings
    """
    config = get_loader_config()
    return create_loader_engine(**config.to_dict())
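Because the factory takes no arguments and is wrapped in @lru_cache, every call returns the same engine instance, and functools gives the wrapper a cache_clear() method if the configuration needs to be re-read. A small illustration, assuming the module import path:

from cognee.infrastructure.loaders.get_loader_engine import get_loader_engine

engine_a = get_loader_engine()
engine_b = get_loader_engine()
assert engine_a is engine_b  # cached: both names point at the same engine

get_loader_engine.cache_clear()  # force the next call to rebuild from config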
47  infrastructure/loaders/models/LoaderResult.py  Normal file
@@ -0,0 +1,47 @@
from pydantic import BaseModel
from typing import Optional, Dict, Any, List
from enum import Enum


class ContentType(Enum):
    """Content type classification for loaded files"""

    TEXT = "text"
    STRUCTURED = "structured"
    BINARY = "binary"


class LoaderResult(BaseModel):
    """
    Standardized output format for all file loaders.

    This model ensures consistent data structure across all loader implementations,
    following cognee's pattern of using Pydantic models for data validation.
    """

    content: str  # Primary text content extracted from file
    metadata: Dict[str, Any]  # File metadata (name, size, type, loader info, etc.)
    content_type: ContentType  # Content classification
    chunks: Optional[List[str]] = None  # Pre-chunked content if available
    source_info: Optional[Dict[str, Any]] = None  # Source-specific information

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the loader result to a dictionary format.

        Returns:
            Dict containing all loader result data with string-serialized content_type
        """
        return {
            "content": self.content,
            "metadata": self.metadata,
            # use_enum_values may already have stored content_type as a plain string,
            # so only unwrap .value when an enum member is actually present
            "content_type": self.content_type.value
            if isinstance(self.content_type, ContentType)
            else self.content_type,
            "source_info": self.source_info or {},
            "chunks": self.chunks,
        }

    class Config:
        """Pydantic configuration following cognee patterns"""

        use_enum_values = True
        validate_assignment = True
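A short construction example for the model; the values are illustrative:

from cognee.infrastructure.loaders.models.LoaderResult import LoaderResult, ContentType

result = LoaderResult(
    content="hello world",
    metadata={"name": "hello.txt", "size": 11},
    content_type=ContentType.TEXT,
)
print(result.to_dict())
# {'content': 'hello world', 'metadata': {...}, 'content_type': 'text',
#  'source_info': {}, 'chunks': None}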
3  infrastructure/loaders/models/__init__.py  Normal file
@@ -0,0 +1,3 @@
from .LoaderResult import LoaderResult, ContentType

__all__ = ["LoaderResult", "ContentType"]
3  infrastructure/loaders/supported_loaders.py  Normal file
@@ -0,0 +1,3 @@
# Registry for loader implementations
# Follows cognee's pattern used in databases/vector/supported_databases.py
supported_loaders = {}
22  infrastructure/loaders/use_loader.py  Normal file
@@ -0,0 +1,22 @@
from .supported_loaders import supported_loaders


def use_loader(loader_name: str, loader_class):
    """
    Register a loader at runtime.

    Follows cognee's pattern used in databases for adapter registration.
    This allows external packages and custom loaders to be registered
    into the loader system.

    Args:
        loader_name: Unique name for the loader
        loader_class: Loader class implementing LoaderInterface

    Example:
        from cognee.infrastructure.loaders import use_loader
        from my_package import MyCustomLoader

        use_loader("my_custom_loader", MyCustomLoader)
    """
    supported_loaders[loader_name] = loader_class
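To make the docstring example concrete, here is a minimal custom loader sketch; CsvLoader is hypothetical, and its method set is inferred from the MockLoader used in the tests below:

import os

from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
from cognee.infrastructure.loaders.models.LoaderResult import LoaderResult, ContentType
from cognee.infrastructure.loaders.use_loader import use_loader


class CsvLoader(LoaderInterface):
    @property
    def supported_extensions(self):
        return [".csv"]

    @property
    def supported_mime_types(self):
        return ["text/csv"]

    @property
    def loader_name(self):
        return "csv_loader"

    def can_handle(self, file_path: str, mime_type: str = None) -> bool:
        return os.path.splitext(file_path)[1].lower() in self.supported_extensions

    async def load(self, file_path: str, **kwargs) -> LoaderResult:
        with open(file_path, "r", encoding="utf-8") as f:
            content = f.read()
        return LoaderResult(
            content=content,
            metadata={"name": os.path.basename(file_path), "loader": self.loader_name},
            content_type=ContentType.STRUCTURED,
        )


# Once registered, create_loader_engine() picks the loader up from the registry.
use_loader("csv_loader", CsvLoader)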
1  tests/__init__.py  Normal file
@@ -0,0 +1 @@
# Tests package
1  tests/unit/__init__.py  Normal file
@@ -0,0 +1 @@
# Unit tests package
1  tests/unit/infrastructure/__init__.py  Normal file
@@ -0,0 +1 @@
# Infrastructure tests package
1  tests/unit/infrastructure/loaders/__init__.py  Normal file
@@ -0,0 +1 @@
# Loaders tests package
252  tests/unit/infrastructure/loaders/test_loader_engine.py  Normal file
@@ -0,0 +1,252 @@
import pytest
import tempfile
import os
from unittest.mock import Mock, AsyncMock

from cognee.infrastructure.loaders.LoaderEngine import LoaderEngine
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
from cognee.infrastructure.loaders.models.LoaderResult import LoaderResult, ContentType


class MockLoader(LoaderInterface):
    """Mock loader for testing."""

    def __init__(self, name="mock_loader", extensions=None, mime_types=None, fail_deps=False):
        self._name = name
        self._extensions = extensions or [".mock"]
        self._mime_types = mime_types or ["application/mock"]
        self._fail_deps = fail_deps

    @property
    def supported_extensions(self):
        return self._extensions

    @property
    def supported_mime_types(self):
        return self._mime_types

    @property
    def loader_name(self):
        return self._name

    def can_handle(self, file_path: str, mime_type: str = None) -> bool:
        ext = os.path.splitext(file_path)[1].lower()
        return ext in self._extensions or mime_type in self._mime_types

    async def load(self, file_path: str, **kwargs) -> LoaderResult:
        return LoaderResult(
            content=f"Mock content from {self._name}",
            metadata={"loader": self._name, "name": os.path.basename(file_path)},
            content_type=ContentType.TEXT,
        )

    def validate_dependencies(self) -> bool:
        return not self._fail_deps


class TestLoaderEngine:
    """Test the LoaderEngine class."""

    @pytest.fixture
    def engine(self):
        """Create a LoaderEngine instance for testing."""
        return LoaderEngine(
            loader_directories=[],
            default_loader_priority=["loader1", "loader2"],
            fallback_loader="fallback",
            enable_dependency_validation=True,
        )

    def test_engine_initialization(self, engine):
        """Test LoaderEngine initialization."""
        assert engine.loader_directories == []
        assert engine.default_loader_priority == ["loader1", "loader2"]
        assert engine.fallback_loader == "fallback"
        assert engine.enable_dependency_validation is True
        assert len(engine.get_available_loaders()) == 0

    def test_register_loader_success(self, engine):
        """Test successful loader registration."""
        loader = MockLoader("test_loader", [".test"])

        success = engine.register_loader(loader)

        assert success is True
        assert "test_loader" in engine.get_available_loaders()
        assert engine._loaders["test_loader"] == loader
        assert ".test" in engine._extension_map
        assert "application/mock" in engine._mime_type_map

    def test_register_loader_with_failed_dependencies(self, engine):
        """Test loader registration with failed dependency validation."""
        loader = MockLoader("test_loader", [".test"], fail_deps=True)

        success = engine.register_loader(loader)

        assert success is False
        assert "test_loader" not in engine.get_available_loaders()

    def test_register_loader_without_dependency_validation(self):
        """Test loader registration without dependency validation."""
        engine = LoaderEngine(
            loader_directories=[], default_loader_priority=[], enable_dependency_validation=False
        )
        loader = MockLoader("test_loader", [".test"], fail_deps=True)

        success = engine.register_loader(loader)

        assert success is True
        assert "test_loader" in engine.get_available_loaders()

    def test_get_loader_by_extension(self, engine):
        """Test getting loader by file extension."""
        loader1 = MockLoader("loader1", [".txt"])
        loader2 = MockLoader("loader2", [".pdf"])

        engine.register_loader(loader1)
        engine.register_loader(loader2)

        result = engine.get_loader("test.txt")
        assert result == loader1

        result = engine.get_loader("test.pdf")
        assert result == loader2

        result = engine.get_loader("test.unknown")
        assert result is None

    def test_get_loader_by_mime_type(self, engine):
        """Test getting loader by MIME type."""
        loader = MockLoader("loader", [".txt"], ["text/plain"])
        engine.register_loader(loader)

        result = engine.get_loader("test.unknown", mime_type="text/plain")
        assert result == loader

        result = engine.get_loader("test.unknown", mime_type="application/pdf")
        assert result is None

    def test_get_loader_with_preferences(self, engine):
        """Test getting loader with preferred loaders."""
        loader1 = MockLoader("loader1", [".txt"])
        loader2 = MockLoader("loader2", [".txt"])

        engine.register_loader(loader1)
        engine.register_loader(loader2)

        # Should get preferred loader
        result = engine.get_loader("test.txt", preferred_loaders=["loader2"])
        assert result == loader2

        # Should fallback to first available if preferred not found
        result = engine.get_loader("test.txt", preferred_loaders=["nonexistent"])
        assert result in [loader1, loader2]  # One of them should be returned

    def test_get_loader_with_priority(self, engine):
        """Test loader selection with priority order."""
        engine.default_loader_priority = ["priority_loader", "other_loader"]

        priority_loader = MockLoader("priority_loader", [".txt"])
        other_loader = MockLoader("other_loader", [".txt"])

        # Register in reverse order
        engine.register_loader(other_loader)
        engine.register_loader(priority_loader)

        # Should get priority loader even though other was registered first
        result = engine.get_loader("test.txt")
        assert result == priority_loader

    def test_get_loader_fallback(self, engine):
        """Test fallback loader selection."""
        fallback_loader = MockLoader("fallback", [".txt"])
        other_loader = MockLoader("other", [".pdf"])

        engine.register_loader(fallback_loader)
        engine.register_loader(other_loader)
        engine.fallback_loader = "fallback"

        # For .txt file, fallback should be considered
        result = engine.get_loader("test.txt")
        assert result == fallback_loader

        # For unknown extension, should still get fallback if it can handle
        result = engine.get_loader("test.unknown")
        assert result == fallback_loader

    @pytest.mark.asyncio
    async def test_load_file_success(self, engine):
        """Test successful file loading."""
        loader = MockLoader("test_loader", [".txt"])
        engine.register_loader(loader)

        with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f:
            f.write(b"test content")
            temp_path = f.name

        try:
            result = await engine.load_file(temp_path)
            assert result.content == "Mock content from test_loader"
            assert result.metadata["loader"] == "test_loader"
        finally:
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    @pytest.mark.asyncio
    async def test_load_file_no_loader(self, engine):
        """Test file loading when no suitable loader is found."""
        with pytest.raises(ValueError, match="No loader found for file"):
            await engine.load_file("test.unknown")

    @pytest.mark.asyncio
    async def test_load_file_with_preferences(self, engine):
        """Test file loading with preferred loaders."""
        loader1 = MockLoader("loader1", [".txt"])
        loader2 = MockLoader("loader2", [".txt"])

        engine.register_loader(loader1)
        engine.register_loader(loader2)

        with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f:
            f.write(b"test content")
            temp_path = f.name

        try:
            result = await engine.load_file(temp_path, preferred_loaders=["loader2"])
            assert result.metadata["loader"] == "loader2"
        finally:
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    def test_get_loader_info(self, engine):
        """Test getting loader information."""
        loader = MockLoader("test_loader", [".txt"], ["text/plain"])
        engine.register_loader(loader)

        info = engine.get_loader_info("test_loader")

        assert info["name"] == "test_loader"
        assert info["extensions"] == [".txt"]
        assert info["mime_types"] == ["text/plain"]
        assert info["available"] is True

        # Test non-existent loader
        info = engine.get_loader_info("nonexistent")
        assert info == {}

    def test_discover_loaders_empty_directory(self, engine):
        """Test loader discovery with empty directory."""
        with tempfile.TemporaryDirectory() as temp_dir:
            engine.loader_directories = [temp_dir]
            engine.discover_loaders()

            # Should not find any loaders in empty directory
            assert len(engine.get_available_loaders()) == 0

    def test_discover_loaders_nonexistent_directory(self, engine):
        """Test loader discovery with non-existent directory."""
        engine.loader_directories = ["/nonexistent/directory"]

        # Should not raise exception, just log warning
        engine.discover_loaders()
        assert len(engine.get_available_loaders()) == 0
99  tests/unit/infrastructure/loaders/test_loader_interface.py  Normal file
@@ -0,0 +1,99 @@
import pytest
import tempfile
import os
from unittest.mock import AsyncMock

from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
from cognee.infrastructure.loaders.models.LoaderResult import LoaderResult, ContentType


class TestLoaderInterface:
    """Test the LoaderInterface abstract base class."""

    def test_loader_interface_is_abstract(self):
        """Test that LoaderInterface cannot be instantiated directly."""
        with pytest.raises(TypeError):
            LoaderInterface()

    def test_dependency_validation_with_no_dependencies(self):
        """Test dependency validation when no dependencies are required."""

        class MockLoader(LoaderInterface):
            @property
            def supported_extensions(self):
                return [".txt"]

            @property
            def supported_mime_types(self):
                return ["text/plain"]

            @property
            def loader_name(self):
                return "mock_loader"

            def can_handle(self, file_path: str, mime_type: str = None) -> bool:
                return True

            async def load(self, file_path: str, **kwargs) -> LoaderResult:
                return LoaderResult(content="test", metadata={}, content_type=ContentType.TEXT)

        loader = MockLoader()
        assert loader.validate_dependencies() is True
        assert loader.get_dependencies() == []

    def test_dependency_validation_with_missing_dependencies(self):
        """Test dependency validation with missing dependencies."""

        class MockLoaderWithDeps(LoaderInterface):
            @property
            def supported_extensions(self):
                return [".txt"]

            @property
            def supported_mime_types(self):
                return ["text/plain"]

            @property
            def loader_name(self):
                return "mock_loader_deps"

            def get_dependencies(self):
                return ["non_existent_package>=1.0.0"]

            def can_handle(self, file_path: str, mime_type: str = None) -> bool:
                return True

            async def load(self, file_path: str, **kwargs) -> LoaderResult:
                return LoaderResult(content="test", metadata={}, content_type=ContentType.TEXT)

        loader = MockLoaderWithDeps()
        assert loader.validate_dependencies() is False
        assert "non_existent_package>=1.0.0" in loader.get_dependencies()

    def test_dependency_validation_with_existing_dependencies(self):
        """Test dependency validation with existing dependencies."""

        class MockLoaderWithExistingDeps(LoaderInterface):
            @property
            def supported_extensions(self):
                return [".txt"]

            @property
            def supported_mime_types(self):
                return ["text/plain"]

            @property
            def loader_name(self):
                return "mock_loader_existing"

            def get_dependencies(self):
                return ["os"]  # Built-in module that always exists

            def can_handle(self, file_path: str, mime_type: str = None) -> bool:
                return True

            async def load(self, file_path: str, **kwargs) -> LoaderResult:
                return LoaderResult(content="test", metadata={}, content_type=ContentType.TEXT)

        loader = MockLoaderWithExistingDeps()
        assert loader.validate_dependencies() is True
157  tests/unit/infrastructure/loaders/test_text_loader.py  Normal file
@@ -0,0 +1,157 @@
import pytest
import tempfile
import os
from pathlib import Path

from cognee.infrastructure.loaders.core.text_loader import TextLoader
from cognee.infrastructure.loaders.models.LoaderResult import ContentType


class TestTextLoader:
    """Test the TextLoader implementation."""

    @pytest.fixture
    def text_loader(self):
        """Create a TextLoader instance for testing."""
        return TextLoader()

    @pytest.fixture
    def temp_text_file(self):
        """Create a temporary text file for testing."""
        with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
            f.write("This is a test file.\nIt has multiple lines.\n")
            temp_path = f.name

        yield temp_path

        # Cleanup
        if os.path.exists(temp_path):
            os.unlink(temp_path)

    @pytest.fixture
    def temp_binary_file(self):
        """Create a temporary binary file for testing."""
        with tempfile.NamedTemporaryFile(mode="wb", suffix=".bin", delete=False) as f:
            f.write(b"\x00\x01\x02\x03\x04\x05")
            temp_path = f.name

        yield temp_path

        # Cleanup
        if os.path.exists(temp_path):
            os.unlink(temp_path)

    def test_loader_properties(self, text_loader):
        """Test basic loader properties."""
        assert text_loader.loader_name == "text_loader"
        assert ".txt" in text_loader.supported_extensions
        assert ".md" in text_loader.supported_extensions
        assert "text/plain" in text_loader.supported_mime_types
        assert "application/json" in text_loader.supported_mime_types

    def test_can_handle_by_extension(self, text_loader):
        """Test file handling by extension."""
        assert text_loader.can_handle("test.txt")
        assert text_loader.can_handle("test.md")
        assert text_loader.can_handle("test.json")
        assert text_loader.can_handle("test.TXT")  # Case insensitive
        assert not text_loader.can_handle("test.pdf")

    def test_can_handle_by_mime_type(self, text_loader):
        """Test file handling by MIME type."""
        assert text_loader.can_handle("test.unknown", mime_type="text/plain")
        assert text_loader.can_handle("test.unknown", mime_type="application/json")
        assert not text_loader.can_handle("test.unknown", mime_type="application/pdf")

    def test_can_handle_text_file_heuristic(self, text_loader, temp_text_file):
        """Test handling of text files by content heuristic."""
        # Remove extension to force heuristic check
        no_ext_path = temp_text_file.replace(".txt", "")
        os.rename(temp_text_file, no_ext_path)

        try:
            assert text_loader.can_handle(no_ext_path)
        finally:
            if os.path.exists(no_ext_path):
                os.unlink(no_ext_path)

    def test_cannot_handle_binary_file(self, text_loader, temp_binary_file):
        """Test that binary files are not handled."""
        assert not text_loader.can_handle(temp_binary_file)

    @pytest.mark.asyncio
    async def test_load_text_file(self, text_loader, temp_text_file):
        """Test loading a text file."""
        result = await text_loader.load(temp_text_file)

        assert isinstance(result.content, str)
        assert "This is a test file." in result.content
        assert result.content_type == ContentType.TEXT
        assert result.metadata["loader"] == "text_loader"
        assert result.metadata["name"] == os.path.basename(temp_text_file)
        assert result.metadata["lines"] == 2
        assert result.metadata["encoding"] == "utf-8"
        assert result.source_info["file_path"] == temp_text_file

    @pytest.mark.asyncio
    async def test_load_with_custom_encoding(self, text_loader):
        """Test loading with custom encoding."""
        # Create a file with latin-1 encoding
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".txt", delete=False, encoding="latin-1"
        ) as f:
            f.write("Test with åéîøü characters")
            temp_path = f.name

        try:
            result = await text_loader.load(temp_path, encoding="latin-1")
            assert "åéîøü" in result.content
            assert result.metadata["encoding"] == "latin-1"
        finally:
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    @pytest.mark.asyncio
    async def test_load_with_fallback_encoding(self, text_loader):
        """Test automatic fallback to latin-1 encoding."""
        # Create a file with latin-1 content but try to read it as utf-8
        with tempfile.NamedTemporaryFile(mode="wb", suffix=".txt", delete=False) as f:
            # Write latin-1 encoded bytes that are invalid in utf-8
            f.write(b"Test with \xe5\xe9\xee\xf8\xfc characters")
            temp_path = f.name

        try:
            # Should automatically fallback to latin-1
            result = await text_loader.load(temp_path)
            assert result.metadata["encoding"] == "latin-1"
            assert len(result.content) > 0
        finally:
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    @pytest.mark.asyncio
    async def test_load_nonexistent_file(self, text_loader):
        """Test loading a file that doesn't exist."""
        with pytest.raises(FileNotFoundError):
            await text_loader.load("/nonexistent/file.txt")

    @pytest.mark.asyncio
    async def test_load_empty_file(self, text_loader):
        """Test loading an empty file."""
        with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
            # Create empty file
            temp_path = f.name

        try:
            result = await text_loader.load(temp_path)
            assert result.content == ""
            assert result.metadata["lines"] == 0
            assert result.metadata["characters"] == 0
        finally:
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    def test_no_dependencies(self, text_loader):
        """Test that TextLoader has no external dependencies."""
        assert text_loader.get_dependencies() == []
        assert text_loader.validate_dependencies() is True