From 4f07adee66b51411a09af7044e5aa5ab7de8c434 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Thu, 18 Dec 2025 16:10:05 +0100 Subject: [PATCH] chore: fixes get_raw_data endpoint and adds s3 support (#1916) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description This PR fixes get_raw_data endpoint in get_datasets_router - Fixes local path access - Adds s3 access - Covers new fixed functionality with unit tests ## Acceptance Criteria ## Type of Change - [x] Bug fix (non-breaking change that fixes an issue) - [ ] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) ## Pre-submission Checklist - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [x] My code follows the project's coding standards and style guidelines - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have added necessary documentation (if applicable) - [x] All new and existing tests pass - [x] I have searched existing PRs to ensure this change hasn't been submitted already - [x] I have linked any relevant issues in the description - [x] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. ## Summary by CodeRabbit * **New Features** * Streaming support for remote S3 data locations so large dataset files can be retrieved efficiently. * Improved handling of local and remote file paths for downloads. * **Improvements** * Standardized error responses for missing datasets or data files. 
* **Tests** * Added unit tests covering local file downloads and S3 streaming, including content and attachment header verification. ✏️ Tip: You can customize this high-level summary in your review settings. --- .../datasets/routers/get_datasets_router.py | 40 +++++- .../unit/api/test_get_raw_data_endpoint.py | 136 ++++++++++++++++++ 2 files changed, 174 insertions(+), 2 deletions(-) create mode 100644 cognee/tests/unit/api/test_get_raw_data_endpoint.py diff --git a/cognee/api/v1/datasets/routers/get_datasets_router.py b/cognee/api/v1/datasets/routers/get_datasets_router.py index 040ed14bf..afd2b2cce 100644 --- a/cognee/api/v1/datasets/routers/get_datasets_router.py +++ b/cognee/api/v1/datasets/routers/get_datasets_router.py @@ -7,7 +7,9 @@ from fastapi import status from fastapi import APIRouter from fastapi.encoders import jsonable_encoder from fastapi import HTTPException, Query, Depends -from fastapi.responses import JSONResponse, FileResponse +from fastapi.responses import JSONResponse, FileResponse, StreamingResponse +from urllib.parse import urlparse +from pathlib import Path from cognee.api.DTO import InDTO, OutDTO from cognee.infrastructure.databases.relational import get_relational_engine @@ -476,6 +478,40 @@ def get_datasets_router() -> APIRouter: message=f"Data ({data_id}) not found in dataset ({dataset_id})." 
) - return data.raw_data_location + raw_location = data.raw_data_location + + if raw_location.startswith("file://"): + from cognee.infrastructure.files.utils.get_data_file_path import get_data_file_path + + raw_location = get_data_file_path(raw_location) + + if raw_location.startswith("s3://"): + from cognee.infrastructure.files.utils.open_data_file import open_data_file + from cognee.infrastructure.utils.run_async import run_async + + parsed = urlparse(raw_location) + download_name = Path(parsed.path).name or data.name + media_type = data.mime_type or "application/octet-stream" + + async def file_iterator(chunk_size: int = 1024 * 1024): + async with open_data_file(raw_location, mode="rb") as file: + while True: + chunk = await run_async(file.read, chunk_size) + if not chunk: + break + yield chunk + + return StreamingResponse( + file_iterator(), + media_type=media_type, + headers={"Content-Disposition": f'attachment; filename="{download_name}"'}, + ) + + path = Path(raw_location) + + if not path.is_file(): + raise DataNotFoundError(message=f"Raw file not found on disk for data ({data_id}).") + + return FileResponse(path=path) return router diff --git a/cognee/tests/unit/api/test_get_raw_data_endpoint.py b/cognee/tests/unit/api/test_get_raw_data_endpoint.py new file mode 100644 index 000000000..392919755 --- /dev/null +++ b/cognee/tests/unit/api/test_get_raw_data_endpoint.py @@ -0,0 +1,136 @@ +import io +import uuid +from contextlib import asynccontextmanager +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from cognee.modules.users.methods import get_authenticated_user + + +@pytest.fixture(scope="session") +def test_client(): + from cognee.api.v1.datasets.routers.get_datasets_router import get_datasets_router + + app = FastAPI() + app.include_router(get_datasets_router(), prefix="/api/v1/datasets") + + with TestClient(app) as c: + yield c + + 
+@pytest.fixture +def client(test_client): + async def override_get_authenticated_user(): + return SimpleNamespace( + id=str(uuid.uuid4()), + email="default@example.com", + is_active=True, + tenant_id=str(uuid.uuid4()), + ) + + import importlib + + datasets_router_module = importlib.import_module( + "cognee.api.v1.datasets.routers.get_datasets_router" + ) + datasets_router_module.send_telemetry = lambda *args, **kwargs: None + + test_client.app.dependency_overrides[get_authenticated_user] = override_get_authenticated_user + yield test_client + test_client.app.dependency_overrides.pop(get_authenticated_user, None) + + +def _patch_raw_download_dependencies( + monkeypatch, *, dataset_id, data_id, raw_data_location, name, mime_type +): + """ + Patch the internal dataset/data lookups used by GET /datasets/{dataset_id}/data/{data_id}/raw. + Keeps the test focused on response behavior (FileResponse vs StreamingResponse). + """ + import importlib + + datasets_router_module = importlib.import_module( + "cognee.api.v1.datasets.routers.get_datasets_router" + ) + + monkeypatch.setattr( + datasets_router_module, + "get_authorized_existing_datasets", + AsyncMock(return_value=[SimpleNamespace(id=dataset_id)]), + ) + + import cognee.modules.data.methods as data_methods_module + + monkeypatch.setattr( + data_methods_module, + "get_dataset_data", + AsyncMock(return_value=[SimpleNamespace(id=data_id)]), + ) + monkeypatch.setattr( + data_methods_module, + "get_data", + AsyncMock( + return_value=SimpleNamespace( + id=data_id, + raw_data_location=raw_data_location, + name=name, + mime_type=mime_type, + ) + ), + ) + + +def test_get_raw_data_local_file_downloads_bytes(client, monkeypatch, tmp_path): + """Downloads bytes from a file:// raw_data_location.""" + dataset_id = uuid.uuid4() + data_id = uuid.uuid4() + + file_path = tmp_path / "example.txt" + content = b"hello from disk" + file_path.write_bytes(content) + + _patch_raw_download_dependencies( + monkeypatch, + dataset_id=dataset_id, 
+ data_id=data_id, + raw_data_location=file_path.as_uri(), + name="example.txt", + mime_type="text/plain", + ) + + response = client.get(f"/api/v1/datasets/{dataset_id}/data/{data_id}/raw") + assert response.status_code == 200 + assert response.content == content + + +def test_get_raw_data_s3_streams_bytes_without_s3_dependency(client, monkeypatch): + """Streams bytes from an s3:// raw_data_location (mocked).""" + dataset_id = uuid.uuid4() + data_id = uuid.uuid4() + + _patch_raw_download_dependencies( + monkeypatch, + dataset_id=dataset_id, + data_id=data_id, + raw_data_location="s3://bucket/path/to/file.txt", + name="file.txt", + mime_type="text/plain", + ) + + import cognee.infrastructure.files.utils.open_data_file as open_data_file_module + + @asynccontextmanager + async def fake_open_data_file(_file_path: str, mode: str = "rb", **_kwargs): + assert mode == "rb" + yield io.BytesIO(b"hello from s3") + + monkeypatch.setattr(open_data_file_module, "open_data_file", fake_open_data_file) + + response = client.get(f"/api/v1/datasets/{dataset_id}/data/{data_id}/raw") + assert response.status_code == 200 + assert response.content == b"hello from s3" + assert response.headers.get("content-disposition") == 'attachment; filename="file.txt"'