chore: fixes get_raw_data endpoint and adds s3 support (#1916)
<!-- .github/pull_request_template.md -->
## Description
This PR fixes the get_raw_data endpoint in get_datasets_router:
- Fixes local path access
- Adds S3 access
- Covers the fixed functionality with unit tests
## Acceptance Criteria
<!--
* Key requirements to the new feature or modification;
* Proof that the changes work and meet the requirements;
* Include instructions on how to verify the changes. Describe how to
test it locally;
* Proof that it's sufficiently tested.
-->
## Type of Change
<!-- Please check the relevant option -->
- [x] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):
## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->
## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Streaming support for remote S3 data locations so large dataset files
can be retrieved efficiently.
* Improved handling of local and remote file paths for downloads.
* **Improvements**
* Standardized error responses for missing datasets or data files.
* **Tests**
* Added unit tests covering local file downloads and S3 streaming,
including content and attachment header verification.
<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
parent
a70ce2785b
commit
4f07adee66
2 changed files with 174 additions and 2 deletions
|
|
@ -7,7 +7,9 @@ from fastapi import status
|
|||
from fastapi import APIRouter
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi import HTTPException, Query, Depends
|
||||
from fastapi.responses import JSONResponse, FileResponse
|
||||
from fastapi.responses import JSONResponse, FileResponse, StreamingResponse
|
||||
from urllib.parse import urlparse
|
||||
from pathlib import Path
|
||||
|
||||
from cognee.api.DTO import InDTO, OutDTO
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
|
@ -476,6 +478,40 @@ def get_datasets_router() -> APIRouter:
|
|||
message=f"Data ({data_id}) not found in dataset ({dataset_id})."
|
||||
)
|
||||
|
||||
return data.raw_data_location
|
||||
raw_location = data.raw_data_location
|
||||
|
||||
if raw_location.startswith("file://"):
|
||||
from cognee.infrastructure.files.utils.get_data_file_path import get_data_file_path
|
||||
|
||||
raw_location = get_data_file_path(raw_location)
|
||||
|
||||
if raw_location.startswith("s3://"):
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
from cognee.infrastructure.utils.run_async import run_async
|
||||
|
||||
parsed = urlparse(raw_location)
|
||||
download_name = Path(parsed.path).name or data.name
|
||||
media_type = data.mime_type or "application/octet-stream"
|
||||
|
||||
async def file_iterator(chunk_size: int = 1024 * 1024):
|
||||
async with open_data_file(raw_location, mode="rb") as file:
|
||||
while True:
|
||||
chunk = await run_async(file.read, chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(
|
||||
file_iterator(),
|
||||
media_type=media_type,
|
||||
headers={"Content-Disposition": f'attachment; filename="{download_name}"'},
|
||||
)
|
||||
|
||||
path = Path(raw_location)
|
||||
|
||||
if not path.is_file():
|
||||
raise DataNotFoundError(message=f"Raw file not found on disk for data ({data_id}).")
|
||||
|
||||
return FileResponse(path=path)
|
||||
|
||||
return router
|
||||
|
|
|
|||
136
cognee/tests/unit/api/test_get_raw_data_endpoint.py
Normal file
136
cognee/tests/unit/api/test_get_raw_data_endpoint.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
import io
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def test_client():
    """Yield a session-wide ``TestClient`` for an app that serves only the datasets router.

    The router is mounted under ``/api/v1/datasets``. The client is used as a
    context manager, so it is closed when the session ends.
    """
    # Imported inside the fixture, so the router module is only loaded when
    # the fixture is first requested.
    from cognee.api.v1.datasets.routers.get_datasets_router import get_datasets_router

    application = FastAPI()
    datasets_router = get_datasets_router()
    application.include_router(datasets_router, prefix="/api/v1/datasets")

    with TestClient(application) as http_client:
        yield http_client
|
||||
|
||||
|
||||
@pytest.fixture
def client(test_client, monkeypatch):
    """Yield the shared test client with authentication and telemetry stubbed out.

    The ``get_authenticated_user`` dependency is overridden with a throwaway
    user object, and the router module's ``send_telemetry`` is replaced by a
    no-op. Both substitutions are undone once the test finishes.
    """
    import importlib

    async def override_get_authenticated_user():
        # Fresh random identifiers on every call; only the attributes set
        # here are available on the fake user.
        return SimpleNamespace(
            id=str(uuid.uuid4()),
            email="default@example.com",
            is_active=True,
            tenant_id=str(uuid.uuid4()),
        )

    datasets_router_module = importlib.import_module(
        "cognee.api.v1.datasets.routers.get_datasets_router"
    )
    # Patch through monkeypatch instead of plain attribute assignment so the
    # stub is reverted at teardown rather than leaking into every later test
    # of the session. raising=False keeps the previous tolerance for the
    # attribute not existing on the module.
    monkeypatch.setattr(
        datasets_router_module,
        "send_telemetry",
        lambda *args, **kwargs: None,
        raising=False,
    )

    test_client.app.dependency_overrides[get_authenticated_user] = override_get_authenticated_user
    yield test_client
    # Remove the override so the session-scoped app is clean for the next test.
    test_client.app.dependency_overrides.pop(get_authenticated_user, None)
|
||||
|
||||
|
||||
def _patch_raw_download_dependencies(
    monkeypatch, *, dataset_id, data_id, raw_data_location, name, mime_type
):
    """
    Stub the dataset/data lookups behind GET /datasets/{dataset_id}/data/{data_id}/raw.

    Three callables are replaced with ``AsyncMock`` objects:
    ``get_authorized_existing_datasets`` on the router module, and
    ``get_dataset_data`` / ``get_data`` on ``cognee.modules.data.methods``.
    With the lookups faked, tests can concentrate on the response itself
    (FileResponse vs StreamingResponse).
    """
    import importlib

    router_module = importlib.import_module(
        "cognee.api.v1.datasets.routers.get_datasets_router"
    )

    import cognee.modules.data.methods as methods_module

    # The record handed back by the fake ``get_data`` lookup.
    fake_record = SimpleNamespace(
        id=data_id,
        raw_data_location=raw_data_location,
        name=name,
        mime_type=mime_type,
    )

    # (module to patch, attribute name, value the async stub resolves to)
    replacements = (
        (router_module, "get_authorized_existing_datasets", [SimpleNamespace(id=dataset_id)]),
        (methods_module, "get_dataset_data", [SimpleNamespace(id=data_id)]),
        (methods_module, "get_data", fake_record),
    )
    for target_module, attribute, resolved_value in replacements:
        monkeypatch.setattr(target_module, attribute, AsyncMock(return_value=resolved_value))
|
||||
|
||||
|
||||
def test_get_raw_data_local_file_downloads_bytes(client, monkeypatch, tmp_path):
    """A file:// raw_data_location is served back with the exact bytes stored on disk."""
    expected_bytes = b"hello from disk"
    source_file = tmp_path / "example.txt"
    source_file.write_bytes(expected_bytes)

    dataset_id, data_id = uuid.uuid4(), uuid.uuid4()

    _patch_raw_download_dependencies(
        monkeypatch,
        dataset_id=dataset_id,
        data_id=data_id,
        # as_uri() produces the file:// form of the temporary path.
        raw_data_location=source_file.as_uri(),
        name="example.txt",
        mime_type="text/plain",
    )

    url = f"/api/v1/datasets/{dataset_id}/data/{data_id}/raw"
    response = client.get(url)

    assert response.status_code == 200
    assert response.content == expected_bytes
|
||||
|
||||
|
||||
def test_get_raw_data_s3_streams_bytes_without_s3_dependency(client, monkeypatch):
    """An s3:// raw_data_location is streamed back; the file opener is faked."""
    dataset_id, data_id = uuid.uuid4(), uuid.uuid4()

    _patch_raw_download_dependencies(
        monkeypatch,
        dataset_id=dataset_id,
        data_id=data_id,
        raw_data_location="s3://bucket/path/to/file.txt",
        name="file.txt",
        mime_type="text/plain",
    )

    import cognee.infrastructure.files.utils.open_data_file as opener_module

    @asynccontextmanager
    async def fake_open_data_file(_location: str, mode: str = "rb", **_kwargs):
        # Stand-in for the real opener: checks the requested mode and hands
        # back an in-memory buffer in place of a remote object.
        assert mode == "rb"
        yield io.BytesIO(b"hello from s3")

    monkeypatch.setattr(opener_module, "open_data_file", fake_open_data_file)

    url = f"/api/v1/datasets/{dataset_id}/data/{data_id}/raw"
    response = client.get(url)

    assert response.status_code == 200
    assert response.content == b"hello from s3"
    disposition = response.headers.get("content-disposition")
    assert disposition == 'attachment; filename="file.txt"'
|
||||
Loading…
Add table
Reference in a new issue